mirror of https://github.com/telemt/telemt.git
Compare commits
5 Commits
d78360982c
...
4f55d08c51
| Author | SHA1 | Date |
|---|---|---|
|
|
4f55d08c51 | |
|
|
93caab1aec | |
|
|
0c6bb3a641 | |
|
|
b2e15327fe | |
|
|
2e8be87ccf |
|
|
@ -381,7 +381,7 @@ fn validate_tls_handshake_at_time_with_boot_cap(
|
||||||
let mut msg = handshake.to_vec();
|
let mut msg = handshake.to_vec();
|
||||||
msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
|
msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
|
||||||
|
|
||||||
let mut first_match: Option<TlsValidation> = None;
|
let mut first_match: Option<(&String, u32)> = None;
|
||||||
|
|
||||||
for (user, secret) in secrets {
|
for (user, secret) in secrets {
|
||||||
let computed = sha256_hmac(secret, &msg);
|
let computed = sha256_hmac(secret, &msg);
|
||||||
|
|
@ -421,16 +421,16 @@ fn validate_tls_handshake_at_time_with_boot_cap(
|
||||||
}
|
}
|
||||||
|
|
||||||
if first_match.is_none() {
|
if first_match.is_none() {
|
||||||
first_match = Some(TlsValidation {
|
first_match = Some((user, timestamp));
|
||||||
user: user.clone(),
|
|
||||||
session_id: session_id.clone(),
|
|
||||||
digest,
|
|
||||||
timestamp,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
first_match
|
first_match.map(|(user, timestamp)| TlsValidation {
|
||||||
|
user: user.clone(),
|
||||||
|
session_id,
|
||||||
|
digest,
|
||||||
|
timestamp,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn curve25519_prime() -> BigUint {
|
fn curve25519_prime() -> BigUint {
|
||||||
|
|
|
||||||
|
|
@ -9,12 +9,19 @@ use crate::crypto::sha256_hmac;
|
||||||
/// [TLS_DIGEST_POS..+32] : digest = HMAC XOR [0..0 || timestamp_le]
|
/// [TLS_DIGEST_POS..+32] : digest = HMAC XOR [0..0 || timestamp_le]
|
||||||
/// [TLS_DIGEST_POS+32] : session_id_len = 32
|
/// [TLS_DIGEST_POS+32] : session_id_len = 32
|
||||||
/// [TLS_DIGEST_POS+33..+65] : session_id filler (0x42)
|
/// [TLS_DIGEST_POS+33..+65] : session_id filler (0x42)
|
||||||
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
|
fn make_valid_tls_handshake_with_session_id(
|
||||||
let session_id_len: usize = 32;
|
secret: &[u8],
|
||||||
|
timestamp: u32,
|
||||||
|
session_id: &[u8],
|
||||||
|
) -> Vec<u8> {
|
||||||
|
let session_id_len = session_id.len();
|
||||||
|
assert!(session_id_len <= u8::MAX as usize);
|
||||||
let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len;
|
let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len;
|
||||||
let mut handshake = vec![0x42u8; len];
|
let mut handshake = vec![0x42u8; len];
|
||||||
|
|
||||||
handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8;
|
handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8;
|
||||||
|
let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1;
|
||||||
|
handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id);
|
||||||
// Zero the digest slot before computing HMAC (mirrors what validate does).
|
// Zero the digest slot before computing HMAC (mirrors what validate does).
|
||||||
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
|
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
|
||||||
|
|
||||||
|
|
@ -34,6 +41,10 @@ fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
|
||||||
handshake
|
handshake
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
|
||||||
|
make_valid_tls_handshake_with_session_id(secret, timestamp, &[0x42; 32])
|
||||||
|
}
|
||||||
|
|
||||||
// ------------------------------------------------------------------
|
// ------------------------------------------------------------------
|
||||||
// Happy-path sanity
|
// Happy-path sanity
|
||||||
// ------------------------------------------------------------------
|
// ------------------------------------------------------------------
|
||||||
|
|
@ -311,6 +322,20 @@ fn too_short_handshake_rejected_without_panic() {
|
||||||
assert!(validate_tls_handshake(&[], &secrets, true).is_none());
|
assert!(validate_tls_handshake(&[], &secrets, true).is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn all_prefix_lengths_below_minimum_rejected_without_panic() {
|
||||||
|
let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1;
|
||||||
|
let secrets = vec![("u".to_string(), b"s".to_vec())];
|
||||||
|
|
||||||
|
for len in 0..min_len {
|
||||||
|
let h = vec![0u8; len];
|
||||||
|
assert!(
|
||||||
|
validate_tls_handshake(&h, &secrets, true).is_none(),
|
||||||
|
"prefix length {len} below minimum must be rejected"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn claimed_session_id_overflows_buffer_rejected() {
|
fn claimed_session_id_overflows_buffer_rejected() {
|
||||||
let session_id_len: usize = 32;
|
let session_id_len: usize = 32;
|
||||||
|
|
@ -332,6 +357,30 @@ fn max_session_id_len_255_does_not_panic() {
|
||||||
assert!(validate_tls_handshake(&h, &secrets, true).is_none());
|
assert!(validate_tls_handshake(&h, &secrets, true).is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn one_byte_session_id_validates_and_is_preserved() {
|
||||||
|
let secret = b"sid_len_1_test";
|
||||||
|
let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &[0xAB]);
|
||||||
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
|
||||||
|
let result = validate_tls_handshake(&handshake, &secrets, true)
|
||||||
|
.expect("one-byte session_id handshake must validate");
|
||||||
|
assert_eq!(result.session_id, vec![0xAB]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn max_session_id_len_255_with_valid_digest_is_accepted() {
|
||||||
|
let secret = b"sid_len_255_test";
|
||||||
|
let session_id = vec![0xCCu8; 255];
|
||||||
|
let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &session_id);
|
||||||
|
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||||
|
|
||||||
|
let result = validate_tls_handshake(&handshake, &secrets, true)
|
||||||
|
.expect("session_id_len=255 with valid digest must validate");
|
||||||
|
assert_eq!(result.session_id.len(), 255);
|
||||||
|
assert_eq!(result.session_id, session_id);
|
||||||
|
}
|
||||||
|
|
||||||
// ------------------------------------------------------------------
|
// ------------------------------------------------------------------
|
||||||
// Adversarial digest values
|
// Adversarial digest values
|
||||||
// ------------------------------------------------------------------
|
// ------------------------------------------------------------------
|
||||||
|
|
@ -867,6 +916,23 @@ fn test_parse_tls_record_header() {
|
||||||
assert_eq!(result.1, 16384);
|
assert_eq!(result.1, 16384);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tls_record_header_rejects_invalid_versions() {
|
||||||
|
let invalid = [
|
||||||
|
[0x16, 0x03, 0x00, 0x00, 0x10],
|
||||||
|
[0x16, 0x02, 0x00, 0x00, 0x10],
|
||||||
|
[0x16, 0x03, 0x02, 0x00, 0x10],
|
||||||
|
[0x16, 0x04, 0x00, 0x00, 0x10],
|
||||||
|
];
|
||||||
|
for header in invalid {
|
||||||
|
assert!(
|
||||||
|
parse_tls_record_header(&header).is_none(),
|
||||||
|
"invalid TLS record version {:?} must be rejected",
|
||||||
|
[header[1], header[2]]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_gen_fake_x25519_key() {
|
fn test_gen_fake_x25519_key() {
|
||||||
let rng = crate::crypto::SecureRandom::new();
|
let rng = crate::crypto::SecureRandom::new();
|
||||||
|
|
@ -1168,6 +1234,47 @@ fn extract_sni_rejects_when_extension_block_is_truncated() {
|
||||||
assert!(extract_sni_from_client_hello(&ch).is_none());
|
assert!(extract_sni_from_client_hello(&ch).is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_sni_rejects_session_id_len_overflow() {
|
||||||
|
let mut ch = build_client_hello_with_exts(Vec::new(), "example.com");
|
||||||
|
let sid_len_pos = 5 + 4 + 2 + 32;
|
||||||
|
ch[sid_len_pos] = 255;
|
||||||
|
assert!(extract_sni_from_client_hello(&ch).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_sni_rejects_cipher_suites_len_overflow() {
|
||||||
|
let mut ch = build_client_hello_with_exts(Vec::new(), "example.com");
|
||||||
|
let sid_len_pos = 5 + 4 + 2 + 32;
|
||||||
|
let cipher_len_pos = sid_len_pos + 1 + ch[sid_len_pos] as usize;
|
||||||
|
ch[cipher_len_pos] = 0xFF;
|
||||||
|
ch[cipher_len_pos + 1] = 0xFF;
|
||||||
|
assert!(extract_sni_from_client_hello(&ch).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_sni_rejects_compression_methods_len_overflow() {
|
||||||
|
let mut ch = build_client_hello_with_exts(Vec::new(), "example.com");
|
||||||
|
let sid_len_pos = 5 + 4 + 2 + 32;
|
||||||
|
let cipher_len_pos = sid_len_pos + 1 + ch[sid_len_pos] as usize;
|
||||||
|
let cipher_len = u16::from_be_bytes([ch[cipher_len_pos], ch[cipher_len_pos + 1]]) as usize;
|
||||||
|
let comp_len_pos = cipher_len_pos + 2 + cipher_len;
|
||||||
|
ch[comp_len_pos] = 0xFF;
|
||||||
|
assert!(extract_sni_from_client_hello(&ch).is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_alpn_returns_empty_on_session_id_len_overflow() {
|
||||||
|
let mut alpn_data = Vec::new();
|
||||||
|
alpn_data.extend_from_slice(&3u16.to_be_bytes());
|
||||||
|
alpn_data.push(2);
|
||||||
|
alpn_data.extend_from_slice(b"h2");
|
||||||
|
let mut ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test");
|
||||||
|
let sid_len_pos = 5 + 4 + 2 + 32;
|
||||||
|
ch[sid_len_pos] = 255;
|
||||||
|
assert!(extract_alpn_from_client_hello(&ch).is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn extract_alpn_rejects_when_extension_block_is_truncated() {
|
fn extract_alpn_rejects_when_extension_block_is_truncated() {
|
||||||
let mut ext_blob = Vec::new();
|
let mut ext_blob = Vec::new();
|
||||||
|
|
|
||||||
|
|
@ -4,11 +4,14 @@
|
||||||
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::net::IpAddr;
|
use std::net::{IpAddr, Ipv6Addr};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::{Mutex, OnceLock};
|
use std::sync::{Mutex, OnceLock};
|
||||||
|
use std::collections::hash_map::DefaultHasher;
|
||||||
|
use std::hash::{Hash, Hasher};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
|
use dashmap::mapref::entry::Entry;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||||
use tracing::{debug, warn, trace};
|
use tracing::{debug, warn, trace};
|
||||||
use zeroize::Zeroize;
|
use zeroize::Zeroize;
|
||||||
|
|
@ -57,6 +60,16 @@ fn auth_probe_state_map() -> &'static DashMap<IpAddr, AuthProbeState> {
|
||||||
AUTH_PROBE_STATE.get_or_init(DashMap::new)
|
AUTH_PROBE_STATE.get_or_init(DashMap::new)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn normalize_auth_probe_ip(peer_ip: IpAddr) -> IpAddr {
|
||||||
|
match peer_ip {
|
||||||
|
IpAddr::V4(ip) => IpAddr::V4(ip),
|
||||||
|
IpAddr::V6(ip) => {
|
||||||
|
let [a, b, c, d, _, _, _, _] = ip.segments();
|
||||||
|
IpAddr::V6(Ipv6Addr::new(a, b, c, d, 0, 0, 0, 0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn auth_probe_backoff(fail_streak: u32) -> Duration {
|
fn auth_probe_backoff(fail_streak: u32) -> Duration {
|
||||||
if fail_streak < AUTH_PROBE_BACKOFF_START_FAILS {
|
if fail_streak < AUTH_PROBE_BACKOFF_START_FAILS {
|
||||||
return Duration::ZERO;
|
return Duration::ZERO;
|
||||||
|
|
@ -74,7 +87,15 @@ fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool {
|
||||||
now.duration_since(state.last_seen) > retention
|
now.duration_since(state.last_seen) > retention
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize {
|
||||||
|
let mut hasher = DefaultHasher::new();
|
||||||
|
peer_ip.hash(&mut hasher);
|
||||||
|
now.hash(&mut hasher);
|
||||||
|
hasher.finish() as usize
|
||||||
|
}
|
||||||
|
|
||||||
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
|
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
|
||||||
|
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||||
let state = auth_probe_state_map();
|
let state = auth_probe_state_map();
|
||||||
let Some(entry) = state.get(&peer_ip) else {
|
let Some(entry) = state.get(&peer_ip) else {
|
||||||
return false;
|
return false;
|
||||||
|
|
@ -88,6 +109,7 @@ fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) {
|
fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) {
|
||||||
|
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||||
let state = auth_probe_state_map();
|
let state = auth_probe_state_map();
|
||||||
auth_probe_record_failure_with_state(state, peer_ip, now);
|
auth_probe_record_failure_with_state(state, peer_ip, now);
|
||||||
}
|
}
|
||||||
|
|
@ -97,24 +119,35 @@ fn auth_probe_record_failure_with_state(
|
||||||
peer_ip: IpAddr,
|
peer_ip: IpAddr,
|
||||||
now: Instant,
|
now: Instant,
|
||||||
) {
|
) {
|
||||||
if let Some(mut entry) = state.get_mut(&peer_ip) {
|
let make_new_state = || AuthProbeState {
|
||||||
if auth_probe_state_expired(&entry, now) {
|
fail_streak: 1,
|
||||||
*entry = AuthProbeState {
|
blocked_until: now + auth_probe_backoff(1),
|
||||||
fail_streak: 1,
|
last_seen: now,
|
||||||
blocked_until: now + auth_probe_backoff(1),
|
};
|
||||||
last_seen: now,
|
|
||||||
};
|
let update_existing = |entry: &mut AuthProbeState| {
|
||||||
|
if auth_probe_state_expired(entry, now) {
|
||||||
|
*entry = make_new_state();
|
||||||
|
} else {
|
||||||
|
entry.fail_streak = entry.fail_streak.saturating_add(1);
|
||||||
|
entry.last_seen = now;
|
||||||
|
entry.blocked_until = now + auth_probe_backoff(entry.fail_streak);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match state.entry(peer_ip) {
|
||||||
|
Entry::Occupied(mut entry) => {
|
||||||
|
update_existing(entry.get_mut());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
entry.fail_streak = entry.fail_streak.saturating_add(1);
|
Entry::Vacant(_) => {}
|
||||||
entry.last_seen = now;
|
}
|
||||||
entry.blocked_until = now + auth_probe_backoff(entry.fail_streak);
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
|
|
||||||
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||||
let mut stale_keys = Vec::new();
|
let mut stale_keys = Vec::new();
|
||||||
|
let mut eviction_candidates = Vec::new();
|
||||||
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
|
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
|
||||||
|
eviction_candidates.push(*entry.key());
|
||||||
if auth_probe_state_expired(entry.value(), now) {
|
if auth_probe_state_expired(entry.value(), now) {
|
||||||
stale_keys.push(*entry.key());
|
stale_keys.push(*entry.key());
|
||||||
}
|
}
|
||||||
|
|
@ -123,23 +156,27 @@ fn auth_probe_record_failure_with_state(
|
||||||
state.remove(&stale_key);
|
state.remove(&stale_key);
|
||||||
}
|
}
|
||||||
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||||
return;
|
if eviction_candidates.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let idx = auth_probe_eviction_offset(peer_ip, now) % eviction_candidates.len();
|
||||||
|
let evict_key = eviction_candidates[idx];
|
||||||
|
state.remove(&evict_key);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
state.insert(peer_ip, AuthProbeState {
|
match state.entry(peer_ip) {
|
||||||
fail_streak: 0,
|
Entry::Occupied(mut entry) => {
|
||||||
blocked_until: now,
|
update_existing(entry.get_mut());
|
||||||
last_seen: now,
|
}
|
||||||
});
|
Entry::Vacant(entry) => {
|
||||||
|
entry.insert(make_new_state());
|
||||||
if let Some(mut entry) = state.get_mut(&peer_ip) {
|
}
|
||||||
entry.fail_streak = 1;
|
|
||||||
entry.blocked_until = now + auth_probe_backoff(1);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn auth_probe_record_success(peer_ip: IpAddr) {
|
fn auth_probe_record_success(peer_ip: IpAddr) {
|
||||||
|
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||||
let state = auth_probe_state_map();
|
let state = auth_probe_state_map();
|
||||||
state.remove(&peer_ip);
|
state.remove(&peer_ip);
|
||||||
}
|
}
|
||||||
|
|
@ -153,6 +190,7 @@ fn clear_auth_probe_state_for_testing() {
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
fn auth_probe_fail_streak_for_testing(peer_ip: IpAddr) -> Option<u32> {
|
fn auth_probe_fail_streak_for_testing(peer_ip: IpAddr) -> Option<u32> {
|
||||||
|
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||||
let state = AUTH_PROBE_STATE.get()?;
|
let state = AUTH_PROBE_STATE.get()?;
|
||||||
state.get(&peer_ip).map(|entry| entry.fail_streak)
|
state.get(&peer_ip).map(|entry| entry.fail_streak)
|
||||||
}
|
}
|
||||||
|
|
@ -177,6 +215,12 @@ fn clear_warned_secrets_for_testing() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn warned_secrets_test_lock() -> &'static Mutex<()> {
|
||||||
|
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||||
|
TEST_LOCK.get_or_init(|| Mutex::new(()))
|
||||||
|
}
|
||||||
|
|
||||||
fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option<usize>) {
|
fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option<usize>) {
|
||||||
let key = (name.to_string(), reason.to_string());
|
let key = (name.to_string(), reason.to_string());
|
||||||
let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new()));
|
let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new()));
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ use dashmap::DashMap;
|
||||||
use std::net::{IpAddr, Ipv4Addr};
|
use std::net::{IpAddr, Ipv4Addr};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio::sync::Barrier;
|
||||||
|
|
||||||
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
|
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
|
||||||
let session_id_len: usize = 32;
|
let session_id_len: usize = 32;
|
||||||
|
|
@ -84,7 +85,6 @@ fn make_valid_tls_client_hello_with_alpn(
|
||||||
}
|
}
|
||||||
|
|
||||||
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
|
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
|
||||||
clear_auth_probe_state_for_testing();
|
|
||||||
let mut cfg = ProxyConfig::default();
|
let mut cfg = ProxyConfig::default();
|
||||||
cfg.access.users.clear();
|
cfg.access.users.clear();
|
||||||
cfg.access
|
cfg.access
|
||||||
|
|
@ -369,6 +369,9 @@ async fn invalid_tls_probe_does_not_pollute_replay_cache() {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn empty_decoded_secret_is_rejected() {
|
async fn empty_decoded_secret_is_rejected() {
|
||||||
|
let _guard = warned_secrets_test_lock()
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||||
clear_warned_secrets_for_testing();
|
clear_warned_secrets_for_testing();
|
||||||
let config = test_config_with_secret_hex("");
|
let config = test_config_with_secret_hex("");
|
||||||
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
|
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
|
||||||
|
|
@ -393,6 +396,9 @@ async fn empty_decoded_secret_is_rejected() {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn wrong_length_decoded_secret_is_rejected() {
|
async fn wrong_length_decoded_secret_is_rejected() {
|
||||||
|
let _guard = warned_secrets_test_lock()
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||||
clear_warned_secrets_for_testing();
|
clear_warned_secrets_for_testing();
|
||||||
let config = test_config_with_secret_hex("aa");
|
let config = test_config_with_secret_hex("aa");
|
||||||
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
|
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
|
||||||
|
|
@ -443,6 +449,12 @@ async fn invalid_mtproto_probe_does_not_pollute_replay_cache() {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn mixed_secret_lengths_keep_valid_user_authenticating() {
|
async fn mixed_secret_lengths_keep_valid_user_authenticating() {
|
||||||
|
let _probe_guard = auth_probe_test_lock()
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||||
|
let _guard = warned_secrets_test_lock()
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||||
clear_warned_secrets_for_testing();
|
clear_warned_secrets_for_testing();
|
||||||
clear_auth_probe_state_for_testing();
|
clear_auth_probe_state_for_testing();
|
||||||
let good_secret = [0x22u8; 16];
|
let good_secret = [0x22u8; 16];
|
||||||
|
|
@ -708,6 +720,9 @@ fn mode_policy_matrix_is_stable_for_all_tag_transport_mode_combinations() {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn invalid_secret_warning_keys_do_not_collide_on_colon_boundaries() {
|
fn invalid_secret_warning_keys_do_not_collide_on_colon_boundaries() {
|
||||||
|
let _guard = warned_secrets_test_lock()
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||||
clear_warned_secrets_for_testing();
|
clear_warned_secrets_for_testing();
|
||||||
|
|
||||||
warn_invalid_secret_once("a:b", "c", ACCESS_SECRET_BYTES, Some(1));
|
warn_invalid_secret_once("a:b", "c", ACCESS_SECRET_BYTES, Some(1));
|
||||||
|
|
@ -755,8 +770,9 @@ async fn repeated_invalid_tls_probes_trigger_pre_auth_throttle() {
|
||||||
}
|
}
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
auth_probe_is_throttled_for_testing(peer.ip()),
|
auth_probe_fail_streak_for_testing(peer.ip())
|
||||||
"invalid probe burst must activate per-IP pre-auth throttle"
|
.is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS),
|
||||||
|
"invalid probe burst must grow pre-auth failure streak to backoff threshold"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -855,7 +871,7 @@ fn auth_probe_capacity_prunes_stale_entries_for_new_ips() {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn auth_probe_capacity_stays_fail_closed_when_map_is_fresh_and_full() {
|
fn auth_probe_capacity_forces_bounded_eviction_when_map_is_fresh_and_full() {
|
||||||
let state = DashMap::new();
|
let state = DashMap::new();
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
|
|
||||||
|
|
@ -880,12 +896,215 @@ fn auth_probe_capacity_stays_fail_closed_when_map_is_fresh_and_full() {
|
||||||
auth_probe_record_failure_with_state(&state, newcomer, now);
|
auth_probe_record_failure_with_state(&state, newcomer, now);
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
state.get(&newcomer).is_none(),
|
state.get(&newcomer).is_some(),
|
||||||
"when all entries are fresh and full, new probes must not be admitted"
|
"when all entries are fresh and full, one bounded eviction must admit a new probe source"
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
state.len(),
|
state.len(),
|
||||||
AUTH_PROBE_TRACK_MAX_ENTRIES,
|
AUTH_PROBE_TRACK_MAX_ENTRIES,
|
||||||
"auth probe map must stay at the configured cap"
|
"auth probe map must stay at the configured cap after forced eviction"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_probe_ipv6_is_bucketed_by_prefix_64() {
|
||||||
|
let state = DashMap::new();
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
let ip_a = IpAddr::V6("2001:db8:abcd:1234:1:2:3:4".parse().unwrap());
|
||||||
|
let ip_b = IpAddr::V6("2001:db8:abcd:1234:ffff:eeee:dddd:cccc".parse().unwrap());
|
||||||
|
|
||||||
|
auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_a), now);
|
||||||
|
auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_b), now);
|
||||||
|
|
||||||
|
let normalized = normalize_auth_probe_ip(ip_a);
|
||||||
|
assert_eq!(
|
||||||
|
state.len(),
|
||||||
|
1,
|
||||||
|
"IPv6 sources in the same /64 must share one pre-auth throttle bucket"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
state.get(&normalized).map(|entry| entry.fail_streak),
|
||||||
|
Some(2),
|
||||||
|
"failures from the same /64 must accumulate in one throttle state"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_probe_ipv6_different_prefixes_use_distinct_buckets() {
|
||||||
|
let state = DashMap::new();
|
||||||
|
let now = Instant::now();
|
||||||
|
|
||||||
|
let ip_a = IpAddr::V6("2001:db8:1111:2222:1:2:3:4".parse().unwrap());
|
||||||
|
let ip_b = IpAddr::V6("2001:db8:1111:3333:1:2:3:4".parse().unwrap());
|
||||||
|
|
||||||
|
auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_a), now);
|
||||||
|
auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_b), now);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
state.len(),
|
||||||
|
2,
|
||||||
|
"different IPv6 /64 prefixes must not share throttle buckets"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
state.get(&normalize_auth_probe_ip(ip_a)).map(|entry| entry.fail_streak),
|
||||||
|
Some(1)
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
state.get(&normalize_auth_probe_ip(ip_b)).map(|entry| entry.fail_streak),
|
||||||
|
Some(1)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_probe_success_clears_whole_ipv6_prefix_bucket() {
|
||||||
|
let _guard = auth_probe_test_lock()
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||||
|
clear_auth_probe_state_for_testing();
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
let ip_fail = IpAddr::V6("2001:db8:aaaa:bbbb:1:2:3:4".parse().unwrap());
|
||||||
|
let ip_success = IpAddr::V6("2001:db8:aaaa:bbbb:ffff:eeee:dddd:cccc".parse().unwrap());
|
||||||
|
|
||||||
|
auth_probe_record_failure(ip_fail, now);
|
||||||
|
assert_eq!(
|
||||||
|
auth_probe_fail_streak_for_testing(ip_fail),
|
||||||
|
Some(1),
|
||||||
|
"precondition: normalized prefix bucket must exist"
|
||||||
|
);
|
||||||
|
|
||||||
|
auth_probe_record_success(ip_success);
|
||||||
|
assert_eq!(
|
||||||
|
auth_probe_fail_streak_for_testing(ip_fail),
|
||||||
|
None,
|
||||||
|
"success from the same /64 must clear the shared bucket"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_probe_eviction_offset_varies_with_input() {
|
||||||
|
let now = Instant::now();
|
||||||
|
let ip1 = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 10));
|
||||||
|
let ip2 = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 11));
|
||||||
|
|
||||||
|
let a = auth_probe_eviction_offset(ip1, now);
|
||||||
|
let b = auth_probe_eviction_offset(ip1, now);
|
||||||
|
let c = auth_probe_eviction_offset(ip2, now);
|
||||||
|
|
||||||
|
assert_eq!(a, b, "same input must yield deterministic offset");
|
||||||
|
assert_ne!(a, c, "different peer IPs should not collapse to one offset");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||||
|
async fn auth_probe_concurrent_failures_do_not_lose_fail_streak_updates() {
|
||||||
|
let _guard = auth_probe_test_lock()
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||||
|
clear_auth_probe_state_for_testing();
|
||||||
|
|
||||||
|
let peer_ip: IpAddr = "198.51.100.90".parse().unwrap();
|
||||||
|
let tasks = 128usize;
|
||||||
|
let barrier = Arc::new(Barrier::new(tasks));
|
||||||
|
let mut handles = Vec::with_capacity(tasks);
|
||||||
|
|
||||||
|
for _ in 0..tasks {
|
||||||
|
let barrier = barrier.clone();
|
||||||
|
handles.push(tokio::spawn(async move {
|
||||||
|
barrier.wait().await;
|
||||||
|
auth_probe_record_failure(peer_ip, Instant::now());
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for handle in handles {
|
||||||
|
handle
|
||||||
|
.await
|
||||||
|
.expect("concurrent failure recording task must not panic");
|
||||||
|
}
|
||||||
|
|
||||||
|
let streak = auth_probe_fail_streak_for_testing(peer_ip)
|
||||||
|
.expect("tracked peer must exist after concurrent failure burst");
|
||||||
|
assert_eq!(
|
||||||
|
streak as usize,
|
||||||
|
tasks,
|
||||||
|
"concurrent failures for one source must account every attempt"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||||
|
async fn invalid_probe_noise_from_other_ips_does_not_break_valid_tls_handshake() {
|
||||||
|
let _guard = auth_probe_test_lock()
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||||
|
clear_auth_probe_state_for_testing();
|
||||||
|
|
||||||
|
let secret = [0x31u8; 16];
|
||||||
|
let config = Arc::new(test_config_with_secret_hex("31313131313131313131313131313131"));
|
||||||
|
let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60)));
|
||||||
|
let rng = Arc::new(SecureRandom::new());
|
||||||
|
let victim_peer: SocketAddr = "198.51.100.91:44391".parse().unwrap();
|
||||||
|
let valid = Arc::new(make_valid_tls_handshake(&secret, 0));
|
||||||
|
|
||||||
|
let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32];
|
||||||
|
invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32;
|
||||||
|
let invalid = Arc::new(invalid);
|
||||||
|
|
||||||
|
let mut noise_tasks = Vec::new();
|
||||||
|
for idx in 0..96u16 {
|
||||||
|
let config = config.clone();
|
||||||
|
let replay_checker = replay_checker.clone();
|
||||||
|
let rng = rng.clone();
|
||||||
|
let invalid = invalid.clone();
|
||||||
|
noise_tasks.push(tokio::spawn(async move {
|
||||||
|
let octet = ((idx % 200) + 1) as u8;
|
||||||
|
let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, octet)), 45000 + idx);
|
||||||
|
let result = handle_tls_handshake(
|
||||||
|
&invalid,
|
||||||
|
tokio::io::empty(),
|
||||||
|
tokio::io::sink(),
|
||||||
|
peer,
|
||||||
|
&config,
|
||||||
|
&replay_checker,
|
||||||
|
&rng,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
assert!(matches!(result, HandshakeResult::BadClient { .. }));
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
let victim_config = config.clone();
|
||||||
|
let victim_replay_checker = replay_checker.clone();
|
||||||
|
let victim_rng = rng.clone();
|
||||||
|
let victim_valid = valid.clone();
|
||||||
|
let victim_task = tokio::spawn(async move {
|
||||||
|
handle_tls_handshake(
|
||||||
|
&victim_valid,
|
||||||
|
tokio::io::empty(),
|
||||||
|
tokio::io::sink(),
|
||||||
|
victim_peer,
|
||||||
|
&victim_config,
|
||||||
|
&victim_replay_checker,
|
||||||
|
&victim_rng,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
});
|
||||||
|
|
||||||
|
for task in noise_tasks {
|
||||||
|
task.await.expect("noise task must not panic");
|
||||||
|
}
|
||||||
|
|
||||||
|
let victim_result = victim_task
|
||||||
|
.await
|
||||||
|
.expect("victim handshake task must not panic");
|
||||||
|
assert!(
|
||||||
|
matches!(victim_result, HandshakeResult::Success(_)),
|
||||||
|
"invalid probe noise from other IPs must not block a valid victim handshake"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
auth_probe_fail_streak_for_testing(victim_peer.ip()),
|
||||||
|
None,
|
||||||
|
"successful victim handshake must not retain pre-auth failure streak"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -223,10 +223,10 @@ async fn relay_to_mask<R, W, MR, MW>(
|
||||||
initial_data: &[u8],
|
initial_data: &[u8],
|
||||||
)
|
)
|
||||||
where
|
where
|
||||||
R: AsyncRead + Unpin + Send,
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
W: AsyncWrite + Unpin + Send,
|
W: AsyncWrite + Unpin + Send + 'static,
|
||||||
MR: AsyncRead + Unpin + Send,
|
MR: AsyncRead + Unpin + Send + 'static,
|
||||||
MW: AsyncWrite + Unpin + Send,
|
MW: AsyncWrite + Unpin + Send + 'static,
|
||||||
{
|
{
|
||||||
// Send initial data to mask host
|
// Send initial data to mask host
|
||||||
if mask_write.write_all(initial_data).await.is_err() {
|
if mask_write.write_all(initial_data).await.is_err() {
|
||||||
|
|
@ -236,39 +236,16 @@ where
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut client_buf = vec![0u8; MASK_BUFFER_SIZE];
|
let _ = tokio::join!(
|
||||||
let mut mask_buf = vec![0u8; MASK_BUFFER_SIZE];
|
async {
|
||||||
|
let _ = tokio::io::copy(&mut reader, &mut mask_write).await;
|
||||||
loop {
|
let _ = mask_write.shutdown().await;
|
||||||
tokio::select! {
|
},
|
||||||
client_read = reader.read(&mut client_buf) => {
|
async {
|
||||||
match client_read {
|
let _ = tokio::io::copy(&mut mask_read, &mut writer).await;
|
||||||
Ok(0) | Err(_) => {
|
let _ = writer.shutdown().await;
|
||||||
let _ = mask_write.shutdown().await;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
Ok(n) => {
|
|
||||||
if mask_write.write_all(&client_buf[..n]).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
mask_read_res = mask_read.read(&mut mask_buf) => {
|
|
||||||
match mask_read_res {
|
|
||||||
Ok(0) | Err(_) => {
|
|
||||||
let _ = writer.shutdown().await;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
Ok(n) => {
|
|
||||||
if writer.write_all(&mask_buf[..n]).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Just consume all data from client without responding
|
/// Just consume all data from client without responding
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,14 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll};
|
||||||
use tokio::io::{duplex, AsyncBufReadExt, BufReader};
|
use tokio::io::{duplex, AsyncBufReadExt, BufReader};
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
use tokio::net::UnixListener;
|
use tokio::net::UnixListener;
|
||||||
use tokio::time::{timeout, Duration};
|
use tokio::time::{sleep, timeout, Duration};
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() {
|
async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() {
|
||||||
|
|
@ -542,9 +544,188 @@ impl tokio::io::AsyncWrite for PendingWriter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct DropTrackedPendingReader {
|
||||||
|
dropped: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl tokio::io::AsyncRead for DropTrackedPendingReader {
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
_buf: &mut tokio::io::ReadBuf<'_>,
|
||||||
|
) -> Poll<std::io::Result<()>> {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for DropTrackedPendingReader {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.dropped.store(true, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct DropTrackedPendingWriter {
|
||||||
|
dropped: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl tokio::io::AsyncWrite for DropTrackedPendingWriter {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
_buf: &[u8],
|
||||||
|
) -> Poll<std::io::Result<usize>> {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for DropTrackedPendingWriter {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.dropped.store(true, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn proxy_header_write_timeout_returns_false() {
|
async fn proxy_header_write_timeout_returns_false() {
|
||||||
let mut writer = PendingWriter;
|
let mut writer = PendingWriter;
|
||||||
let ok = write_proxy_header_with_timeout(&mut writer, b"PROXY UNKNOWN\r\n").await;
|
let ok = write_proxy_header_with_timeout(&mut writer, b"PROXY UNKNOWN\r\n").await;
|
||||||
assert!(!ok, "Proxy header writes that never complete must time out");
|
assert!(!ok, "Proxy header writes that never complete must time out");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stalls() {
|
||||||
|
let (mut client_feed_writer, client_feed_reader) = duplex(64);
|
||||||
|
let (mut client_visible_reader, client_visible_writer) = duplex(64);
|
||||||
|
let (mut backend_feed_writer, backend_feed_reader) = duplex(64);
|
||||||
|
|
||||||
|
// Make client->mask direction immediately active so the c2m path blocks on PendingWriter.
|
||||||
|
client_feed_writer.write_all(b"X").await.unwrap();
|
||||||
|
|
||||||
|
let relay = tokio::spawn(async move {
|
||||||
|
relay_to_mask(
|
||||||
|
client_feed_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
backend_feed_reader,
|
||||||
|
PendingWriter,
|
||||||
|
b"",
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Allow relay tasks to start, then emulate mask backend response.
|
||||||
|
sleep(Duration::from_millis(20)).await;
|
||||||
|
backend_feed_writer.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap();
|
||||||
|
backend_feed_writer.shutdown().await.unwrap();
|
||||||
|
|
||||||
|
let mut observed = vec![0u8; 19];
|
||||||
|
timeout(Duration::from_secs(1), client_visible_reader.read_exact(&mut observed))
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(observed, b"HTTP/1.1 200 OK\r\n\r\n");
|
||||||
|
|
||||||
|
relay.abort();
|
||||||
|
let _ = relay.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_to_mask_preserves_backend_response_after_client_half_close() {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let backend_addr = listener.local_addr().unwrap();
|
||||||
|
let request = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
|
||||||
|
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
|
||||||
|
|
||||||
|
let backend_task = tokio::spawn({
|
||||||
|
let request = request.clone();
|
||||||
|
let response = response.clone();
|
||||||
|
async move {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
let mut observed_req = vec![0u8; request.len()];
|
||||||
|
stream.read_exact(&mut observed_req).await.unwrap();
|
||||||
|
assert_eq!(observed_req, request);
|
||||||
|
stream.write_all(&response).await.unwrap();
|
||||||
|
stream.shutdown().await.unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = backend_addr.port();
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "203.0.113.77:55001".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
let (mut client_write, client_read) = duplex(1024);
|
||||||
|
let (mut client_visible_reader, client_visible_writer) = duplex(2048);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let fallback_task = tokio::spawn(async move {
|
||||||
|
handle_bad_client(
|
||||||
|
client_read,
|
||||||
|
client_visible_writer,
|
||||||
|
&request,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
|
||||||
|
client_write.shutdown().await.unwrap();
|
||||||
|
|
||||||
|
let mut observed_resp = vec![0u8; response.len()];
|
||||||
|
timeout(Duration::from_secs(1), client_visible_reader.read_exact(&mut observed_resp))
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(observed_resp, response);
|
||||||
|
|
||||||
|
timeout(Duration::from_secs(1), fallback_task).await.unwrap().unwrap();
|
||||||
|
timeout(Duration::from_secs(1), backend_task).await.unwrap().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
|
||||||
|
let reader_dropped = Arc::new(AtomicBool::new(false));
|
||||||
|
let writer_dropped = Arc::new(AtomicBool::new(false));
|
||||||
|
let mask_reader_dropped = Arc::new(AtomicBool::new(false));
|
||||||
|
let mask_writer_dropped = Arc::new(AtomicBool::new(false));
|
||||||
|
|
||||||
|
let reader = DropTrackedPendingReader {
|
||||||
|
dropped: reader_dropped.clone(),
|
||||||
|
};
|
||||||
|
let writer = DropTrackedPendingWriter {
|
||||||
|
dropped: writer_dropped.clone(),
|
||||||
|
};
|
||||||
|
let mask_read = DropTrackedPendingReader {
|
||||||
|
dropped: mask_reader_dropped.clone(),
|
||||||
|
};
|
||||||
|
let mask_write = DropTrackedPendingWriter {
|
||||||
|
dropped: mask_writer_dropped.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let timed = timeout(
|
||||||
|
Duration::from_millis(40),
|
||||||
|
relay_to_mask(reader, writer, mask_read, mask_write, b""),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(timed.is_err(), "stalled relay must be bounded by timeout");
|
||||||
|
|
||||||
|
assert!(reader_dropped.load(Ordering::SeqCst));
|
||||||
|
assert!(writer_dropped.load(Ordering::SeqCst));
|
||||||
|
assert!(mask_reader_dropped.load(Ordering::SeqCst));
|
||||||
|
assert!(mask_writer_dropped.load(Ordering::SeqCst));
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ use std::time::{Duration, Instant};
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use std::sync::Mutex;
|
use std::sync::Mutex;
|
||||||
|
|
||||||
use bytes::{Bytes, BytesMut};
|
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
use tokio::sync::{mpsc, oneshot, watch};
|
use tokio::sync::{mpsc, oneshot, watch};
|
||||||
|
|
@ -24,11 +23,11 @@ use crate::proxy::route_mode::{
|
||||||
cutover_stagger_delay,
|
cutover_stagger_delay,
|
||||||
};
|
};
|
||||||
use crate::stats::Stats;
|
use crate::stats::Stats;
|
||||||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
|
||||||
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
|
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
|
||||||
|
|
||||||
enum C2MeCommand {
|
enum C2MeCommand {
|
||||||
Data { payload: Bytes, flags: u32 },
|
Data { payload: PooledBuffer, flags: u32 },
|
||||||
Close,
|
Close,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -107,7 +106,11 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
|
||||||
|
|
||||||
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
|
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
|
||||||
let mut stale_keys = Vec::new();
|
let mut stale_keys = Vec::new();
|
||||||
|
let mut eviction_candidate = None;
|
||||||
for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) {
|
for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) {
|
||||||
|
if eviction_candidate.is_none() {
|
||||||
|
eviction_candidate = Some(*entry.key());
|
||||||
|
}
|
||||||
if now.duration_since(*entry.value()) >= DESYNC_DEDUP_WINDOW {
|
if now.duration_since(*entry.value()) >= DESYNC_DEDUP_WINDOW {
|
||||||
stale_keys.push(*entry.key());
|
stale_keys.push(*entry.key());
|
||||||
}
|
}
|
||||||
|
|
@ -116,6 +119,11 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
|
||||||
dedup.remove(&stale_key);
|
dedup.remove(&stale_key);
|
||||||
}
|
}
|
||||||
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
|
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
|
||||||
|
let Some(evict_key) = eviction_candidate else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
dedup.remove(&evict_key);
|
||||||
|
dedup.insert(key, now);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -677,7 +685,7 @@ async fn read_client_payload<R>(
|
||||||
forensics: &RelayForensicsState,
|
forensics: &RelayForensicsState,
|
||||||
frame_counter: &mut u64,
|
frame_counter: &mut u64,
|
||||||
stats: &Stats,
|
stats: &Stats,
|
||||||
) -> Result<Option<(Bytes, bool)>>
|
) -> Result<Option<(PooledBuffer, bool)>>
|
||||||
where
|
where
|
||||||
R: AsyncRead + Unpin + Send + 'static,
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
{
|
{
|
||||||
|
|
@ -784,25 +792,21 @@ where
|
||||||
len
|
len
|
||||||
};
|
};
|
||||||
|
|
||||||
let chunk_cap = buffer_pool.buffer_size().max(1024);
|
let mut payload = buffer_pool.get();
|
||||||
let mut payload = BytesMut::with_capacity(len.min(chunk_cap));
|
payload.clear();
|
||||||
let mut remaining = len;
|
let current_cap = payload.capacity();
|
||||||
while remaining > 0 {
|
if current_cap < len {
|
||||||
let chunk_len = remaining.min(chunk_cap);
|
payload.reserve(len - current_cap);
|
||||||
let mut chunk = buffer_pool.get();
|
|
||||||
chunk.resize(chunk_len, 0);
|
|
||||||
read_exact_with_timeout(client_reader, &mut chunk[..chunk_len], frame_read_timeout)
|
|
||||||
.await?;
|
|
||||||
payload.extend_from_slice(&chunk[..chunk_len]);
|
|
||||||
remaining -= chunk_len;
|
|
||||||
}
|
}
|
||||||
|
payload.resize(len, 0);
|
||||||
|
read_exact_with_timeout(client_reader, &mut payload[..len], frame_read_timeout).await?;
|
||||||
|
|
||||||
// Secure Intermediate: strip validated trailing padding bytes.
|
// Secure Intermediate: strip validated trailing padding bytes.
|
||||||
if proto_tag == ProtoTag::Secure {
|
if proto_tag == ProtoTag::Secure {
|
||||||
payload.truncate(secure_payload_len);
|
payload.truncate(secure_payload_len);
|
||||||
}
|
}
|
||||||
*frame_counter += 1;
|
*frame_counter += 1;
|
||||||
return Ok(Some((payload.freeze(), quickack)));
|
return Ok(Some((payload, quickack)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use bytes::Bytes;
|
||||||
use crate::crypto::AesCtr;
|
use crate::crypto::AesCtr;
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
use crate::stats::Stats;
|
use crate::stats::Stats;
|
||||||
use crate::stream::{BufferPool, CryptoReader};
|
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::atomic::AtomicU64;
|
use std::sync::atomic::AtomicU64;
|
||||||
|
|
@ -9,6 +11,21 @@ use tokio::io::AsyncWriteExt;
|
||||||
use tokio::io::duplex;
|
use tokio::io::duplex;
|
||||||
use tokio::time::{Duration as TokioDuration, timeout};
|
use tokio::time::{Duration as TokioDuration, timeout};
|
||||||
|
|
||||||
|
fn make_pooled_payload(data: &[u8]) -> PooledBuffer {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4));
|
||||||
|
let mut payload = pool.get();
|
||||||
|
payload.resize(data.len(), 0);
|
||||||
|
payload[..data.len()].copy_from_slice(data);
|
||||||
|
payload
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_pooled_payload_from(pool: &Arc<BufferPool>, data: &[u8]) -> PooledBuffer {
|
||||||
|
let mut payload = pool.get();
|
||||||
|
payload.resize(data.len(), 0);
|
||||||
|
payload[..data.len()].copy_from_slice(data);
|
||||||
|
payload
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_yield_sender_only_on_budget_with_backlog() {
|
fn should_yield_sender_only_on_budget_with_backlog() {
|
||||||
assert!(!should_yield_c2me_sender(0, true));
|
assert!(!should_yield_c2me_sender(0, true));
|
||||||
|
|
@ -23,7 +40,7 @@ async fn enqueue_c2me_command_uses_try_send_fast_path() {
|
||||||
enqueue_c2me_command(
|
enqueue_c2me_command(
|
||||||
&tx,
|
&tx,
|
||||||
C2MeCommand::Data {
|
C2MeCommand::Data {
|
||||||
payload: Bytes::from_static(&[1, 2, 3]),
|
payload: make_pooled_payload(&[1, 2, 3]),
|
||||||
flags: 0,
|
flags: 0,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
@ -47,7 +64,7 @@ async fn enqueue_c2me_command_uses_try_send_fast_path() {
|
||||||
async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
|
async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
|
||||||
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
|
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
|
||||||
tx.send(C2MeCommand::Data {
|
tx.send(C2MeCommand::Data {
|
||||||
payload: Bytes::from_static(&[9]),
|
payload: make_pooled_payload(&[9]),
|
||||||
flags: 9,
|
flags: 9,
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
|
|
@ -58,7 +75,7 @@ async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
|
||||||
enqueue_c2me_command(
|
enqueue_c2me_command(
|
||||||
&tx2,
|
&tx2,
|
||||||
C2MeCommand::Data {
|
C2MeCommand::Data {
|
||||||
payload: Bytes::from_static(&[7, 7]),
|
payload: make_pooled_payload(&[7, 7]),
|
||||||
flags: 7,
|
flags: 7,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
@ -84,6 +101,74 @@ async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_c2me_command_closed_channel_recycles_payload() {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(64, 4));
|
||||||
|
let payload = make_pooled_payload_from(&pool, &[1, 2, 3, 4]);
|
||||||
|
let (tx, rx) = mpsc::channel::<C2MeCommand>(1);
|
||||||
|
drop(rx);
|
||||||
|
|
||||||
|
let result = enqueue_c2me_command(
|
||||||
|
&tx,
|
||||||
|
C2MeCommand::Data {
|
||||||
|
payload,
|
||||||
|
flags: 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(result.is_err(), "closed queue must fail enqueue");
|
||||||
|
drop(result);
|
||||||
|
assert!(
|
||||||
|
pool.stats().pooled >= 1,
|
||||||
|
"payload must return to pool when enqueue fails on closed channel"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_c2me_command_full_then_closed_recycles_waiting_payload() {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(64, 4));
|
||||||
|
let (tx, rx) = mpsc::channel::<C2MeCommand>(1);
|
||||||
|
|
||||||
|
tx.send(C2MeCommand::Data {
|
||||||
|
payload: make_pooled_payload_from(&pool, &[9]),
|
||||||
|
flags: 1,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let tx2 = tx.clone();
|
||||||
|
let pool2 = pool.clone();
|
||||||
|
let blocked_send = tokio::spawn(async move {
|
||||||
|
enqueue_c2me_command(
|
||||||
|
&tx2,
|
||||||
|
C2MeCommand::Data {
|
||||||
|
payload: make_pooled_payload_from(&pool2, &[7, 7, 7]),
|
||||||
|
flags: 2,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::time::sleep(TokioDuration::from_millis(10)).await;
|
||||||
|
drop(rx);
|
||||||
|
|
||||||
|
let result = timeout(TokioDuration::from_secs(1), blocked_send)
|
||||||
|
.await
|
||||||
|
.expect("blocked send task must finish")
|
||||||
|
.expect("blocked send task must not panic");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
result.is_err(),
|
||||||
|
"closing receiver while sender is blocked must fail enqueue"
|
||||||
|
);
|
||||||
|
drop(result);
|
||||||
|
assert!(
|
||||||
|
pool.stats().pooled >= 2,
|
||||||
|
"both queued and blocked payloads must return to pool after channel close"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn desync_dedup_cache_is_bounded() {
|
fn desync_dedup_cache_is_bounded() {
|
||||||
let _guard = desync_dedup_test_lock()
|
let _guard = desync_dedup_test_lock()
|
||||||
|
|
@ -101,7 +186,7 @@ fn desync_dedup_cache_is_bounded() {
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
!should_emit_full_desync(u64::MAX, false, now),
|
!should_emit_full_desync(u64::MAX, false, now),
|
||||||
"new key above cap must be suppressed to bound memory"
|
"new key above cap must remain suppressed to avoid log amplification"
|
||||||
);
|
);
|
||||||
|
|
||||||
assert!(
|
assert!(
|
||||||
|
|
@ -110,6 +195,26 @@ fn desync_dedup_cache_is_bounded() {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn desync_dedup_full_cache_churn_stays_suppressed() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("desync dedup test lock must be available");
|
||||||
|
clear_desync_dedup_for_testing();
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
|
||||||
|
assert!(should_emit_full_desync(key, false, now));
|
||||||
|
}
|
||||||
|
|
||||||
|
for offset in 0..2048u64 {
|
||||||
|
assert!(
|
||||||
|
!should_emit_full_desync(u64::MAX - offset, false, now),
|
||||||
|
"fresh full-cache churn must remain suppressed under pressure"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn make_forensics_state() -> RelayForensicsState {
|
fn make_forensics_state() -> RelayForensicsState {
|
||||||
RelayForensicsState {
|
RelayForensicsState {
|
||||||
trace_id: 1,
|
trace_id: 1,
|
||||||
|
|
@ -130,6 +235,12 @@ fn make_crypto_reader(reader: tokio::io::DuplexStream) -> CryptoReader<tokio::io
|
||||||
CryptoReader::new(reader, AesCtr::new(&key, iv))
|
CryptoReader::new(reader, AesCtr::new(&key, iv))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn make_crypto_writer(writer: tokio::io::DuplexStream) -> CryptoWriter<tokio::io::DuplexStream> {
|
||||||
|
let key = [0u8; 32];
|
||||||
|
let iv = 0u128;
|
||||||
|
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||||
|
}
|
||||||
|
|
||||||
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
|
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
|
||||||
let key = [0u8; 32];
|
let key = [0u8; 32];
|
||||||
let iv = 0u128;
|
let iv = 0u128;
|
||||||
|
|
@ -199,3 +310,472 @@ async fn read_client_payload_times_out_on_payload_stall() {
|
||||||
"stalled payload body read must time out"
|
"stalled payload body read must time out"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_large_intermediate_frame_is_exact() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
|
||||||
|
let (reader, mut writer) = duplex(262_144);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
let payload_len = buffer_pool.buffer_size().saturating_mul(3).max(65_537);
|
||||||
|
let mut plaintext = Vec::with_capacity(4 + payload_len);
|
||||||
|
plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes());
|
||||||
|
plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(31)));
|
||||||
|
|
||||||
|
let encrypted = encrypt_for_reader(&plaintext);
|
||||||
|
writer.write_all(&encrypted).await.unwrap();
|
||||||
|
|
||||||
|
let read = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
payload_len + 16,
|
||||||
|
TokioDuration::from_secs(1),
|
||||||
|
&buffer_pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("payload read must succeed")
|
||||||
|
.expect("frame must be present");
|
||||||
|
|
||||||
|
let (frame, quickack) = read;
|
||||||
|
assert!(!quickack, "quickack flag must be unset");
|
||||||
|
assert_eq!(frame.len(), payload_len, "payload size must match wire length");
|
||||||
|
for (idx, byte) in frame.iter().enumerate() {
|
||||||
|
assert_eq!(*byte, (idx as u8).wrapping_mul(31));
|
||||||
|
}
|
||||||
|
assert_eq!(frame_counter, 1, "exactly one frame must be counted");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_secure_strips_tail_padding_bytes() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
|
||||||
|
let (reader, mut writer) = duplex(1024);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
let payload = [0x11u8, 0x22, 0x33, 0x44, 0xaa, 0xbb, 0xcc, 0xdd];
|
||||||
|
let tail = [0xeeu8, 0xff, 0x99];
|
||||||
|
let wire_len = payload.len() + tail.len();
|
||||||
|
|
||||||
|
let mut plaintext = Vec::with_capacity(4 + wire_len);
|
||||||
|
plaintext.extend_from_slice(&(wire_len as u32).to_le_bytes());
|
||||||
|
plaintext.extend_from_slice(&payload);
|
||||||
|
plaintext.extend_from_slice(&tail);
|
||||||
|
let encrypted = encrypt_for_reader(&plaintext);
|
||||||
|
writer.write_all(&encrypted).await.unwrap();
|
||||||
|
|
||||||
|
let read = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Secure,
|
||||||
|
1024,
|
||||||
|
TokioDuration::from_secs(1),
|
||||||
|
&buffer_pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("secure payload read must succeed")
|
||||||
|
.expect("secure frame must be present");
|
||||||
|
|
||||||
|
let (frame, quickack) = read;
|
||||||
|
assert!(!quickack, "quickack flag must be unset");
|
||||||
|
assert_eq!(frame.as_ref(), &payload);
|
||||||
|
assert_eq!(frame_counter, 1, "one secure frame must be counted");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_secure_rejects_wire_len_below_4() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
|
||||||
|
let (reader, mut writer) = duplex(1024);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
let mut plaintext = Vec::with_capacity(7);
|
||||||
|
plaintext.extend_from_slice(&3u32.to_le_bytes());
|
||||||
|
plaintext.extend_from_slice(&[1u8, 2, 3]);
|
||||||
|
let encrypted = encrypt_for_reader(&plaintext);
|
||||||
|
writer.write_all(&encrypted).await.unwrap();
|
||||||
|
|
||||||
|
let result = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Secure,
|
||||||
|
1024,
|
||||||
|
TokioDuration::from_secs(1),
|
||||||
|
&buffer_pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
matches!(result, Err(ProxyError::Proxy(ref msg)) if msg.contains("Frame too small: 3")),
|
||||||
|
"secure wire length below 4 must be fail-closed by the frame-too-small guard"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_intermediate_skips_zero_len_frame() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
|
||||||
|
let (reader, mut writer) = duplex(1024);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
let payload = [7u8, 6, 5, 4, 3, 2, 1, 0];
|
||||||
|
let mut plaintext = Vec::with_capacity(4 + 4 + payload.len());
|
||||||
|
plaintext.extend_from_slice(&0u32.to_le_bytes());
|
||||||
|
plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes());
|
||||||
|
plaintext.extend_from_slice(&payload);
|
||||||
|
let encrypted = encrypt_for_reader(&plaintext);
|
||||||
|
writer.write_all(&encrypted).await.unwrap();
|
||||||
|
|
||||||
|
let read = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
1024,
|
||||||
|
TokioDuration::from_secs(1),
|
||||||
|
&buffer_pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("intermediate payload read must succeed")
|
||||||
|
.expect("frame must be present");
|
||||||
|
|
||||||
|
let (frame, quickack) = read;
|
||||||
|
assert!(!quickack, "quickack flag must be unset");
|
||||||
|
assert_eq!(frame.as_ref(), &payload);
|
||||||
|
assert_eq!(frame_counter, 1, "zero-length frame must be skipped");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_abridged_extended_len_sets_quickack() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
|
||||||
|
let (reader, mut writer) = duplex(4096);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
let payload_len = 4 * 130;
|
||||||
|
let len_words = (payload_len / 4) as u32;
|
||||||
|
let mut plaintext = Vec::with_capacity(1 + 3 + payload_len);
|
||||||
|
plaintext.push(0xff | 0x80);
|
||||||
|
let lw = len_words.to_le_bytes();
|
||||||
|
plaintext.extend_from_slice(&lw[..3]);
|
||||||
|
plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_add(17)));
|
||||||
|
|
||||||
|
let encrypted = encrypt_for_reader(&plaintext);
|
||||||
|
writer.write_all(&encrypted).await.unwrap();
|
||||||
|
|
||||||
|
let read = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Abridged,
|
||||||
|
payload_len + 16,
|
||||||
|
TokioDuration::from_secs(1),
|
||||||
|
&buffer_pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("abridged payload read must succeed")
|
||||||
|
.expect("frame must be present");
|
||||||
|
|
||||||
|
let (frame, quickack) = read;
|
||||||
|
assert!(quickack, "quickack bit must be propagated from abridged header");
|
||||||
|
assert_eq!(frame.len(), payload_len);
|
||||||
|
assert_eq!(frame_counter, 1, "one abridged frame must be counted");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_returns_buffer_to_pool_after_emit() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
|
||||||
|
let pool = Arc::new(BufferPool::with_config(64, 8));
|
||||||
|
pool.preallocate(1);
|
||||||
|
assert_eq!(pool.stats().pooled, 1, "precondition: one pooled buffer");
|
||||||
|
|
||||||
|
let (reader, mut writer) = duplex(4096);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
// Force growth beyond default pool buffer size to catch ownership-take regressions.
|
||||||
|
let payload_len = 257usize;
|
||||||
|
let mut plaintext = Vec::with_capacity(4 + payload_len);
|
||||||
|
plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes());
|
||||||
|
plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(13)));
|
||||||
|
|
||||||
|
let encrypted = encrypt_for_reader(&plaintext);
|
||||||
|
writer.write_all(&encrypted).await.unwrap();
|
||||||
|
|
||||||
|
let _ = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
payload_len + 8,
|
||||||
|
TokioDuration::from_secs(1),
|
||||||
|
&pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("payload read must succeed")
|
||||||
|
.expect("frame must be present");
|
||||||
|
|
||||||
|
assert_eq!(frame_counter, 1);
|
||||||
|
let pool_stats = pool.stats();
|
||||||
|
assert!(
|
||||||
|
pool_stats.pooled >= 1,
|
||||||
|
"emitted payload buffer must be returned to pool to avoid pool drain"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_keeps_pool_buffer_checked_out_until_frame_drop() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
|
||||||
|
let pool = Arc::new(BufferPool::with_config(64, 2));
|
||||||
|
pool.preallocate(1);
|
||||||
|
assert_eq!(pool.stats().pooled, 1, "one pooled buffer must be available");
|
||||||
|
|
||||||
|
let (reader, mut writer) = duplex(1024);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
let payload = [0x41u8, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48];
|
||||||
|
let mut plaintext = Vec::with_capacity(4 + payload.len());
|
||||||
|
plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes());
|
||||||
|
plaintext.extend_from_slice(&payload);
|
||||||
|
let encrypted = encrypt_for_reader(&plaintext);
|
||||||
|
writer.write_all(&encrypted).await.unwrap();
|
||||||
|
|
||||||
|
let (frame, quickack) = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
1024,
|
||||||
|
TokioDuration::from_secs(1),
|
||||||
|
&pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("payload read must succeed")
|
||||||
|
.expect("frame must be present");
|
||||||
|
|
||||||
|
assert!(!quickack);
|
||||||
|
assert_eq!(frame.as_ref(), &payload);
|
||||||
|
assert_eq!(
|
||||||
|
pool.stats().pooled,
|
||||||
|
0,
|
||||||
|
"buffer must stay checked out while frame payload is alive"
|
||||||
|
);
|
||||||
|
|
||||||
|
drop(frame);
|
||||||
|
assert!(
|
||||||
|
pool.stats().pooled >= 1,
|
||||||
|
"buffer must return to pool only after frame drop"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_c2me_close_unblocks_after_queue_drain() {
|
||||||
|
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
|
||||||
|
tx.send(C2MeCommand::Data {
|
||||||
|
payload: make_pooled_payload(&[0x41]),
|
||||||
|
flags: 0,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let tx2 = tx.clone();
|
||||||
|
let close_task = tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await });
|
||||||
|
|
||||||
|
tokio::time::sleep(TokioDuration::from_millis(10)).await;
|
||||||
|
|
||||||
|
let first = timeout(TokioDuration::from_millis(100), rx.recv())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.expect("first queued item must be present");
|
||||||
|
assert!(matches!(first, C2MeCommand::Data { .. }));
|
||||||
|
|
||||||
|
close_task.await.unwrap().expect("close enqueue must succeed after drain");
|
||||||
|
|
||||||
|
let second = timeout(TokioDuration::from_millis(100), rx.recv())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.expect("close command must follow after queue drain");
|
||||||
|
assert!(matches!(second, C2MeCommand::Close));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_c2me_close_full_then_receiver_drop_fails_cleanly() {
|
||||||
|
let (tx, rx) = mpsc::channel::<C2MeCommand>(1);
|
||||||
|
tx.send(C2MeCommand::Data {
|
||||||
|
payload: make_pooled_payload(&[0x42]),
|
||||||
|
flags: 0,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let tx2 = tx.clone();
|
||||||
|
let close_task = tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await });
|
||||||
|
|
||||||
|
tokio::time::sleep(TokioDuration::from_millis(10)).await;
|
||||||
|
drop(rx);
|
||||||
|
|
||||||
|
let result = timeout(TokioDuration::from_secs(1), close_task)
|
||||||
|
.await
|
||||||
|
.expect("close task must finish")
|
||||||
|
.expect("close task must not panic");
|
||||||
|
assert!(
|
||||||
|
result.is_err(),
|
||||||
|
"close enqueue must fail cleanly when receiver is dropped under pressure"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn process_me_writer_response_ack_obeys_flush_policy() {
|
||||||
|
let (writer_side, _reader_side) = duplex(1024);
|
||||||
|
let mut writer = make_crypto_writer(writer_side);
|
||||||
|
let rng = SecureRandom::new();
|
||||||
|
let mut frame_buf = Vec::new();
|
||||||
|
let stats = Stats::new();
|
||||||
|
let bytes_me2c = AtomicU64::new(0);
|
||||||
|
|
||||||
|
let immediate = process_me_writer_response(
|
||||||
|
MeResponse::Ack(0x11223344),
|
||||||
|
&mut writer,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
&rng,
|
||||||
|
&mut frame_buf,
|
||||||
|
&stats,
|
||||||
|
"user",
|
||||||
|
&bytes_me2c,
|
||||||
|
77,
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("ack response must be processed");
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
immediate,
|
||||||
|
MeWriterResponseOutcome::Continue {
|
||||||
|
frames: 1,
|
||||||
|
bytes: 4,
|
||||||
|
flush_immediately: true,
|
||||||
|
}
|
||||||
|
));
|
||||||
|
|
||||||
|
let delayed = process_me_writer_response(
|
||||||
|
MeResponse::Ack(0x55667788),
|
||||||
|
&mut writer,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
&rng,
|
||||||
|
&mut frame_buf,
|
||||||
|
&stats,
|
||||||
|
"user",
|
||||||
|
&bytes_me2c,
|
||||||
|
77,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("ack response must be processed");
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
delayed,
|
||||||
|
MeWriterResponseOutcome::Continue {
|
||||||
|
frames: 1,
|
||||||
|
bytes: 4,
|
||||||
|
flush_immediately: false,
|
||||||
|
}
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn process_me_writer_response_data_updates_byte_accounting() {
|
||||||
|
let (writer_side, _reader_side) = duplex(1024);
|
||||||
|
let mut writer = make_crypto_writer(writer_side);
|
||||||
|
let rng = SecureRandom::new();
|
||||||
|
let mut frame_buf = Vec::new();
|
||||||
|
let stats = Stats::new();
|
||||||
|
let bytes_me2c = AtomicU64::new(0);
|
||||||
|
|
||||||
|
let payload = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||||
|
let outcome = process_me_writer_response(
|
||||||
|
MeResponse::Data {
|
||||||
|
flags: 0,
|
||||||
|
data: Bytes::from(payload.clone()),
|
||||||
|
},
|
||||||
|
&mut writer,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
&rng,
|
||||||
|
&mut frame_buf,
|
||||||
|
&stats,
|
||||||
|
"user",
|
||||||
|
&bytes_me2c,
|
||||||
|
88,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("data response must be processed");
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
outcome,
|
||||||
|
MeWriterResponseOutcome::Continue {
|
||||||
|
frames: 1,
|
||||||
|
bytes,
|
||||||
|
flush_immediately: false,
|
||||||
|
} if bytes == payload.len()
|
||||||
|
));
|
||||||
|
assert_eq!(
|
||||||
|
bytes_me2c.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
payload.len() as u64,
|
||||||
|
"ME->C byte accounting must increase by emitted payload size"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -115,59 +115,109 @@ async fn reap_draining_writers(
|
||||||
pool: &Arc<MePool>,
|
pool: &Arc<MePool>,
|
||||||
warn_next_allowed: &mut HashMap<u64, Instant>,
|
warn_next_allowed: &mut HashMap<u64, Instant>,
|
||||||
) {
|
) {
|
||||||
|
if pool.draining_active_runtime() == 0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
let now_epoch_secs = MePool::now_epoch_secs();
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
let drain_ttl_secs = pool.me_pool_drain_ttl_secs.load(std::sync::atomic::Ordering::Relaxed);
|
let drain_ttl_secs = pool.me_pool_drain_ttl_secs.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
let drain_threshold = pool
|
let drain_threshold = pool
|
||||||
.me_pool_drain_threshold
|
.me_pool_drain_threshold
|
||||||
.load(std::sync::atomic::Ordering::Relaxed);
|
.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
let writers = pool.writers.read().await.clone();
|
let mut draining_writers = {
|
||||||
let mut draining_writers = Vec::new();
|
let writers = pool.writers.read().await;
|
||||||
for writer in writers {
|
let mut draining_writers = Vec::<DrainingWriterSnapshot>::new();
|
||||||
if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) {
|
for writer in writers.iter() {
|
||||||
continue;
|
if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
draining_writers.push(DrainingWriterSnapshot {
|
||||||
|
id: writer.id,
|
||||||
|
writer_dc: writer.writer_dc,
|
||||||
|
addr: writer.addr,
|
||||||
|
generation: writer.generation,
|
||||||
|
created_at: writer.created_at,
|
||||||
|
draining_started_at_epoch_secs: writer
|
||||||
|
.draining_started_at_epoch_secs
|
||||||
|
.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
drain_deadline_epoch_secs: writer
|
||||||
|
.drain_deadline_epoch_secs
|
||||||
|
.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
allow_drain_fallback: writer
|
||||||
|
.allow_drain_fallback
|
||||||
|
.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
let is_empty = pool.registry.is_writer_empty(writer.id).await;
|
draining_writers
|
||||||
if is_empty {
|
};
|
||||||
pool.remove_writer_and_close_clients(writer.id).await;
|
|
||||||
continue;
|
if draining_writers.is_empty() {
|
||||||
}
|
return;
|
||||||
draining_writers.push(writer);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize {
|
let draining_ids: Vec<u64> = draining_writers.iter().map(|writer| writer.id).collect();
|
||||||
draining_writers.sort_by(|left, right| {
|
let non_empty_writer_ids = pool.registry.non_empty_writer_ids(&draining_ids).await;
|
||||||
let left_started = left
|
let mut non_empty_draining_writers =
|
||||||
.draining_started_at_epoch_secs
|
Vec::<DrainingWriterSnapshot>::with_capacity(draining_writers.len());
|
||||||
.load(std::sync::atomic::Ordering::Relaxed);
|
for writer in draining_writers.drain(..) {
|
||||||
let right_started = right
|
if non_empty_writer_ids.contains(&writer.id) {
|
||||||
.draining_started_at_epoch_secs
|
non_empty_draining_writers.push(writer);
|
||||||
.load(std::sync::atomic::Ordering::Relaxed);
|
} else {
|
||||||
left_started
|
|
||||||
.cmp(&right_started)
|
|
||||||
.then_with(|| left.created_at.cmp(&right.created_at))
|
|
||||||
.then_with(|| left.id.cmp(&right.id))
|
|
||||||
});
|
|
||||||
let overflow = draining_writers.len().saturating_sub(drain_threshold as usize);
|
|
||||||
warn!(
|
|
||||||
draining_writers = draining_writers.len(),
|
|
||||||
me_pool_drain_threshold = drain_threshold,
|
|
||||||
removing_writers = overflow,
|
|
||||||
"ME draining writer threshold exceeded, force-closing oldest draining writers"
|
|
||||||
);
|
|
||||||
for writer in draining_writers.drain(..overflow) {
|
|
||||||
pool.stats.increment_pool_force_close_total();
|
|
||||||
pool.remove_writer_and_close_clients(writer.id).await;
|
pool.remove_writer_and_close_clients(writer.id).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
draining_writers = non_empty_draining_writers;
|
||||||
|
if draining_writers.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let overflow = if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize {
|
||||||
|
draining_writers.len().saturating_sub(drain_threshold as usize)
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
let has_deadline_expired = draining_writers.iter().any(|writer| {
|
||||||
|
writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs
|
||||||
|
});
|
||||||
|
let can_drop_with_replacement = if overflow > 0 || has_deadline_expired {
|
||||||
|
pool.has_non_draining_writer_per_desired_dc_group().await
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
};
|
||||||
|
|
||||||
|
if overflow > 0 {
|
||||||
|
if can_drop_with_replacement {
|
||||||
|
draining_writers.sort_by(|left, right| {
|
||||||
|
left.draining_started_at_epoch_secs
|
||||||
|
.cmp(&right.draining_started_at_epoch_secs)
|
||||||
|
.then_with(|| left.created_at.cmp(&right.created_at))
|
||||||
|
.then_with(|| left.id.cmp(&right.id))
|
||||||
|
});
|
||||||
|
warn!(
|
||||||
|
draining_writers = draining_writers.len(),
|
||||||
|
me_pool_drain_threshold = drain_threshold,
|
||||||
|
removing_writers = overflow,
|
||||||
|
"ME draining writer threshold exceeded, force-closing oldest draining writers"
|
||||||
|
);
|
||||||
|
for writer in draining_writers.drain(..overflow) {
|
||||||
|
pool.stats.increment_pool_force_close_total();
|
||||||
|
pool.remove_writer_and_close_clients(writer.id).await;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
draining_writers = draining_writers.len(),
|
||||||
|
me_pool_drain_threshold = drain_threshold,
|
||||||
|
overflow,
|
||||||
|
"ME draining threshold exceeded, but replacement coverage is incomplete; keeping draining writers"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for writer in draining_writers {
|
for writer in draining_writers {
|
||||||
let drain_started_at_epoch_secs = writer
|
|
||||||
.draining_started_at_epoch_secs
|
|
||||||
.load(std::sync::atomic::Ordering::Relaxed);
|
|
||||||
if drain_ttl_secs > 0
|
if drain_ttl_secs > 0
|
||||||
&& drain_started_at_epoch_secs != 0
|
&& writer.draining_started_at_epoch_secs != 0
|
||||||
&& now_epoch_secs.saturating_sub(drain_started_at_epoch_secs) > drain_ttl_secs
|
&& now_epoch_secs.saturating_sub(writer.draining_started_at_epoch_secs) > drain_ttl_secs
|
||||||
&& should_emit_writer_warn(
|
&& should_emit_writer_warn(
|
||||||
warn_next_allowed,
|
warn_next_allowed,
|
||||||
writer.id,
|
writer.id,
|
||||||
|
|
@ -182,21 +232,45 @@ async fn reap_draining_writers(
|
||||||
generation = writer.generation,
|
generation = writer.generation,
|
||||||
drain_ttl_secs,
|
drain_ttl_secs,
|
||||||
force_close_secs = pool.me_pool_force_close_secs.load(std::sync::atomic::Ordering::Relaxed),
|
force_close_secs = pool.me_pool_force_close_secs.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
allow_drain_fallback = writer.allow_drain_fallback.load(std::sync::atomic::Ordering::Relaxed),
|
allow_drain_fallback = writer.allow_drain_fallback,
|
||||||
"ME draining writer remains non-empty past drain TTL"
|
"ME draining writer remains non-empty past drain TTL"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
let deadline_epoch_secs = writer
|
if writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs
|
||||||
.drain_deadline_epoch_secs
|
{
|
||||||
.load(std::sync::atomic::Ordering::Relaxed);
|
if can_drop_with_replacement {
|
||||||
if deadline_epoch_secs != 0 && now_epoch_secs >= deadline_epoch_secs {
|
warn!(writer_id = writer.id, "Drain timeout, force-closing");
|
||||||
warn!(writer_id = writer.id, "Drain timeout, force-closing");
|
pool.stats.increment_pool_force_close_total();
|
||||||
pool.stats.increment_pool_force_close_total();
|
pool.remove_writer_and_close_clients(writer.id).await;
|
||||||
pool.remove_writer_and_close_clients(writer.id).await;
|
} else if should_emit_writer_warn(
|
||||||
|
warn_next_allowed,
|
||||||
|
writer.id,
|
||||||
|
now,
|
||||||
|
pool.warn_rate_limit_duration(),
|
||||||
|
) {
|
||||||
|
warn!(
|
||||||
|
writer_id = writer.id,
|
||||||
|
writer_dc = writer.writer_dc,
|
||||||
|
endpoint = %writer.addr,
|
||||||
|
"Drain timeout reached, but replacement coverage is incomplete; keeping draining writer"
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct DrainingWriterSnapshot {
|
||||||
|
id: u64,
|
||||||
|
writer_dc: i32,
|
||||||
|
addr: SocketAddr,
|
||||||
|
generation: u64,
|
||||||
|
created_at: Instant,
|
||||||
|
draining_started_at_epoch_secs: u64,
|
||||||
|
drain_deadline_epoch_secs: u64,
|
||||||
|
allow_drain_fallback: bool,
|
||||||
|
}
|
||||||
|
|
||||||
fn should_emit_writer_warn(
|
fn should_emit_writer_warn(
|
||||||
next_allowed: &mut HashMap<u64, Instant>,
|
next_allowed: &mut HashMap<u64, Instant>,
|
||||||
writer_id: u64,
|
writer_id: u64,
|
||||||
|
|
@ -1330,6 +1404,15 @@ mod tests {
|
||||||
me_pool_drain_threshold,
|
me_pool_drain_threshold,
|
||||||
..GeneralConfig::default()
|
..GeneralConfig::default()
|
||||||
};
|
};
|
||||||
|
let mut proxy_map_v4 = HashMap::new();
|
||||||
|
proxy_map_v4.insert(
|
||||||
|
2,
|
||||||
|
vec![(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 10)), 443)],
|
||||||
|
);
|
||||||
|
let decision = NetworkDecision {
|
||||||
|
ipv4_me: true,
|
||||||
|
..NetworkDecision::default()
|
||||||
|
};
|
||||||
MePool::new(
|
MePool::new(
|
||||||
None,
|
None,
|
||||||
vec![1u8; 32],
|
vec![1u8; 32],
|
||||||
|
|
@ -1341,10 +1424,10 @@ mod tests {
|
||||||
None,
|
None,
|
||||||
12,
|
12,
|
||||||
1200,
|
1200,
|
||||||
HashMap::new(),
|
proxy_map_v4,
|
||||||
HashMap::new(),
|
HashMap::new(),
|
||||||
None,
|
None,
|
||||||
NetworkDecision::default(),
|
decision,
|
||||||
None,
|
None,
|
||||||
Arc::new(SecureRandom::new()),
|
Arc::new(SecureRandom::new()),
|
||||||
Arc::new(Stats::default()),
|
Arc::new(Stats::default()),
|
||||||
|
|
@ -1438,6 +1521,7 @@ mod tests {
|
||||||
pool.writers.write().await.push(writer);
|
pool.writers.write().await.push(writer);
|
||||||
pool.registry.register_writer(writer_id, tx).await;
|
pool.registry.register_writer(writer_id, tx).await;
|
||||||
pool.conn_count.fetch_add(1, Ordering::Relaxed);
|
pool.conn_count.fetch_add(1, Ordering::Relaxed);
|
||||||
|
pool.increment_draining_active_runtime();
|
||||||
assert!(
|
assert!(
|
||||||
pool.registry
|
pool.registry
|
||||||
.bind_writer(
|
.bind_writer(
|
||||||
|
|
@ -1455,8 +1539,56 @@ mod tests {
|
||||||
conn_id
|
conn_id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn insert_live_writer(pool: &Arc<MePool>, writer_id: u64, writer_dc: i32) {
|
||||||
|
let (tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
|
||||||
|
let writer = MeWriter {
|
||||||
|
id: writer_id,
|
||||||
|
addr: SocketAddr::new(
|
||||||
|
IpAddr::V4(Ipv4Addr::new(203, 0, 113, (writer_id as u8).saturating_add(1))),
|
||||||
|
4000 + writer_id as u16,
|
||||||
|
),
|
||||||
|
source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
|
||||||
|
writer_dc,
|
||||||
|
generation: 2,
|
||||||
|
contour: Arc::new(AtomicU8::new(WriterContour::Active.as_u8())),
|
||||||
|
created_at: Instant::now(),
|
||||||
|
tx: tx.clone(),
|
||||||
|
cancel: CancellationToken::new(),
|
||||||
|
degraded: Arc::new(AtomicBool::new(false)),
|
||||||
|
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
|
||||||
|
draining: Arc::new(AtomicBool::new(false)),
|
||||||
|
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(0)),
|
||||||
|
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)),
|
||||||
|
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
|
||||||
|
};
|
||||||
|
pool.writers.write().await.push(writer);
|
||||||
|
pool.registry.register_writer(writer_id, tx).await;
|
||||||
|
pool.conn_count.fetch_add(1, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn reap_draining_writers_force_closes_oldest_over_threshold() {
|
async fn reap_draining_writers_force_closes_oldest_over_threshold() {
|
||||||
|
let pool = make_pool(2).await;
|
||||||
|
insert_live_writer(&pool, 1, 2).await;
|
||||||
|
assert!(pool.has_non_draining_writer_per_desired_dc_group().await);
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
|
||||||
|
let conn_b = insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(20)).await;
|
||||||
|
let conn_c = insert_draining_writer(&pool, 30, now_epoch_secs.saturating_sub(10)).await;
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
|
let mut writer_ids: Vec<u64> = pool.writers.read().await.iter().map(|writer| writer.id).collect();
|
||||||
|
writer_ids.sort_unstable();
|
||||||
|
assert_eq!(writer_ids, vec![1, 20, 30]);
|
||||||
|
assert!(pool.registry.get_writer(conn_a).await.is_none());
|
||||||
|
assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20);
|
||||||
|
assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_does_not_force_close_overflow_without_replacement() {
|
||||||
let pool = make_pool(2).await;
|
let pool = make_pool(2).await;
|
||||||
let now_epoch_secs = MePool::now_epoch_secs();
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
|
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
|
||||||
|
|
@ -1466,9 +1598,10 @@ mod tests {
|
||||||
|
|
||||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
let writer_ids: Vec<u64> = pool.writers.read().await.iter().map(|writer| writer.id).collect();
|
let mut writer_ids: Vec<u64> = pool.writers.read().await.iter().map(|writer| writer.id).collect();
|
||||||
assert_eq!(writer_ids, vec![20, 30]);
|
writer_ids.sort_unstable();
|
||||||
assert!(pool.registry.get_writer(conn_a).await.is_none());
|
assert_eq!(writer_ids, vec![10, 20, 30]);
|
||||||
|
assert_eq!(pool.registry.get_writer(conn_a).await.unwrap().writer_id, 10);
|
||||||
assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20);
|
assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20);
|
||||||
assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30);
|
assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -160,6 +160,7 @@ pub struct MePool {
|
||||||
pub(super) refill_inflight: Arc<Mutex<HashSet<RefillEndpointKey>>>,
|
pub(super) refill_inflight: Arc<Mutex<HashSet<RefillEndpointKey>>>,
|
||||||
pub(super) refill_inflight_dc: Arc<Mutex<HashSet<RefillDcKey>>>,
|
pub(super) refill_inflight_dc: Arc<Mutex<HashSet<RefillDcKey>>>,
|
||||||
pub(super) conn_count: AtomicUsize,
|
pub(super) conn_count: AtomicUsize,
|
||||||
|
pub(super) draining_active_runtime: AtomicU64,
|
||||||
pub(super) stats: Arc<crate::stats::Stats>,
|
pub(super) stats: Arc<crate::stats::Stats>,
|
||||||
pub(super) generation: AtomicU64,
|
pub(super) generation: AtomicU64,
|
||||||
pub(super) active_generation: AtomicU64,
|
pub(super) active_generation: AtomicU64,
|
||||||
|
|
@ -438,6 +439,7 @@ impl MePool {
|
||||||
refill_inflight: Arc::new(Mutex::new(HashSet::new())),
|
refill_inflight: Arc::new(Mutex::new(HashSet::new())),
|
||||||
refill_inflight_dc: Arc::new(Mutex::new(HashSet::new())),
|
refill_inflight_dc: Arc::new(Mutex::new(HashSet::new())),
|
||||||
conn_count: AtomicUsize::new(0),
|
conn_count: AtomicUsize::new(0),
|
||||||
|
draining_active_runtime: AtomicU64::new(0),
|
||||||
generation: AtomicU64::new(1),
|
generation: AtomicU64::new(1),
|
||||||
active_generation: AtomicU64::new(1),
|
active_generation: AtomicU64::new(1),
|
||||||
warm_generation: AtomicU64::new(0),
|
warm_generation: AtomicU64::new(0),
|
||||||
|
|
@ -690,6 +692,32 @@ impl MePool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(super) fn draining_active_runtime(&self) -> u64 {
|
||||||
|
self.draining_active_runtime.load(Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn increment_draining_active_runtime(&self) {
|
||||||
|
self.draining_active_runtime.fetch_add(1, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn decrement_draining_active_runtime(&self) {
|
||||||
|
let mut current = self.draining_active_runtime.load(Ordering::Relaxed);
|
||||||
|
loop {
|
||||||
|
if current == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
match self.draining_active_runtime.compare_exchange_weak(
|
||||||
|
current,
|
||||||
|
current - 1,
|
||||||
|
Ordering::Relaxed,
|
||||||
|
Ordering::Relaxed,
|
||||||
|
) {
|
||||||
|
Ok(_) => break,
|
||||||
|
Err(actual) => current = actual,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(super) async fn key_selector(&self) -> u32 {
|
pub(super) async fn key_selector(&self) -> u32 {
|
||||||
self.proxy_secret.read().await.key_selector
|
self.proxy_secret.read().await.key_selector
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -141,6 +141,38 @@ impl MePool {
|
||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(super) async fn has_non_draining_writer_per_desired_dc_group(&self) -> bool {
|
||||||
|
let desired_by_dc = self.desired_dc_endpoints().await;
|
||||||
|
let required_dcs: HashSet<i32> = desired_by_dc
|
||||||
|
.iter()
|
||||||
|
.filter_map(|(dc, endpoints)| {
|
||||||
|
if endpoints.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(*dc)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
if required_dcs.is_empty() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
let ws = self.writers.read().await;
|
||||||
|
let mut covered_dcs = HashSet::<i32>::with_capacity(required_dcs.len());
|
||||||
|
for writer in ws.iter() {
|
||||||
|
if writer.draining.load(Ordering::Relaxed) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if required_dcs.contains(&writer.writer_dc) {
|
||||||
|
covered_dcs.insert(writer.writer_dc);
|
||||||
|
if covered_dcs.len() == required_dcs.len() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
fn hardswap_warmup_connect_delay_ms(&self) -> u64 {
|
fn hardswap_warmup_connect_delay_ms(&self) -> u64 {
|
||||||
let min_ms = self.me_hardswap_warmup_delay_min_ms.load(Ordering::Relaxed);
|
let min_ms = self.me_hardswap_warmup_delay_min_ms.load(Ordering::Relaxed);
|
||||||
let max_ms = self.me_hardswap_warmup_delay_max_ms.load(Ordering::Relaxed);
|
let max_ms = self.me_hardswap_warmup_delay_max_ms.load(Ordering::Relaxed);
|
||||||
|
|
@ -475,12 +507,30 @@ impl MePool {
|
||||||
coverage_ratio = format_args!("{coverage_ratio:.3}"),
|
coverage_ratio = format_args!("{coverage_ratio:.3}"),
|
||||||
min_ratio = format_args!("{min_ratio:.3}"),
|
min_ratio = format_args!("{min_ratio:.3}"),
|
||||||
drain_timeout_secs,
|
drain_timeout_secs,
|
||||||
"ME reinit cycle covered; draining stale writers"
|
"ME reinit cycle covered; processing stale writers"
|
||||||
);
|
);
|
||||||
self.stats.increment_pool_swap_total();
|
self.stats.increment_pool_swap_total();
|
||||||
|
let can_drop_with_replacement = self
|
||||||
|
.has_non_draining_writer_per_desired_dc_group()
|
||||||
|
.await;
|
||||||
|
if can_drop_with_replacement {
|
||||||
|
info!(
|
||||||
|
stale_writers = stale_writer_ids.len(),
|
||||||
|
"ME reinit stale writers: replacement coverage ready, force-closing clients for fast rebind"
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
stale_writers = stale_writer_ids.len(),
|
||||||
|
"ME reinit stale writers: replacement coverage incomplete, keeping draining fallback"
|
||||||
|
);
|
||||||
|
}
|
||||||
for writer_id in stale_writer_ids {
|
for writer_id in stale_writer_ids {
|
||||||
self.mark_writer_draining_with_timeout(writer_id, drain_timeout, !hardswap)
|
self.mark_writer_draining_with_timeout(writer_id, drain_timeout, !hardswap)
|
||||||
.await;
|
.await;
|
||||||
|
if can_drop_with_replacement {
|
||||||
|
self.stats.increment_pool_force_close_total();
|
||||||
|
self.remove_writer_and_close_clients(writer_id).await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if hardswap {
|
if hardswap {
|
||||||
self.clear_pending_hardswap_state();
|
self.clear_pending_hardswap_state();
|
||||||
|
|
|
||||||
|
|
@ -514,6 +514,7 @@ impl MePool {
|
||||||
let was_draining = w.draining.load(Ordering::Relaxed);
|
let was_draining = w.draining.load(Ordering::Relaxed);
|
||||||
if was_draining {
|
if was_draining {
|
||||||
self.stats.decrement_pool_drain_active();
|
self.stats.decrement_pool_drain_active();
|
||||||
|
self.decrement_draining_active_runtime();
|
||||||
}
|
}
|
||||||
self.stats.increment_me_writer_removed_total();
|
self.stats.increment_me_writer_removed_total();
|
||||||
w.cancel.cancel();
|
w.cancel.cancel();
|
||||||
|
|
@ -572,6 +573,7 @@ impl MePool {
|
||||||
.store(drain_deadline_epoch_secs, Ordering::Relaxed);
|
.store(drain_deadline_epoch_secs, Ordering::Relaxed);
|
||||||
if !already_draining {
|
if !already_draining {
|
||||||
self.stats.increment_pool_drain_active();
|
self.stats.increment_pool_drain_active();
|
||||||
|
self.increment_draining_active_runtime();
|
||||||
}
|
}
|
||||||
w.contour
|
w.contour
|
||||||
.store(WriterContour::Draining.as_u8(), Ordering::Relaxed);
|
.store(WriterContour::Draining.as_u8(), Ordering::Relaxed);
|
||||||
|
|
|
||||||
|
|
@ -436,6 +436,19 @@ impl ConnRegistry {
|
||||||
.map(|s| s.is_empty())
|
.map(|s| s.is_empty())
|
||||||
.unwrap_or(true)
|
.unwrap_or(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet<u64> {
|
||||||
|
let inner = self.inner.read().await;
|
||||||
|
let mut out = HashSet::<u64>::with_capacity(writer_ids.len());
|
||||||
|
for writer_id in writer_ids {
|
||||||
|
if let Some(conns) = inner.conns_for_writer.get(writer_id)
|
||||||
|
&& !conns.is_empty()
|
||||||
|
{
|
||||||
|
out.insert(*writer_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -634,4 +647,35 @@ mod tests {
|
||||||
);
|
);
|
||||||
assert!(registry.get_writer(conn_id).await.is_none());
|
assert!(registry.get_writer(conn_id).await.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn non_empty_writer_ids_returns_only_writers_with_bound_clients() {
|
||||||
|
let registry = ConnRegistry::new();
|
||||||
|
let (conn_id, _rx) = registry.register().await;
|
||||||
|
let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8);
|
||||||
|
let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8);
|
||||||
|
registry.register_writer(10, writer_tx_a).await;
|
||||||
|
registry.register_writer(20, writer_tx_b).await;
|
||||||
|
|
||||||
|
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443);
|
||||||
|
assert!(
|
||||||
|
registry
|
||||||
|
.bind_writer(
|
||||||
|
conn_id,
|
||||||
|
10,
|
||||||
|
ConnMeta {
|
||||||
|
target_dc: 2,
|
||||||
|
client_addr: addr,
|
||||||
|
our_addr: addr,
|
||||||
|
proto_flags: 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
);
|
||||||
|
|
||||||
|
let non_empty = registry.non_empty_writer_ids(&[10, 20, 30]).await;
|
||||||
|
assert!(non_empty.contains(&10));
|
||||||
|
assert!(!non_empty.contains(&20));
|
||||||
|
assert!(!non_empty.contains(&30));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue