Enhance TLS Emulator with ALPN Support and Add Adversarial Tests

- Modified `build_emulated_server_hello` to accept ALPN (Application-Layer Protocol Negotiation) as an optional parameter, so ALPN markers can be embedded in the application data payload (see the sketch after this list).
- Added handling for oversized ALPN values so they cannot interfere with the application data payload.
- Added security tests in `emulator_security_tests.rs` validating the ALPN embedding, including oversized ALPN values and the preference for certificate payloads over ALPN markers.
- Introduced `send_adversarial_tests.rs` to cover edge cases and failure modes in the middle proxy's send path.
- Updated the `middle_proxy` module to include the new test modules and to handle writer commands correctly during data transmission.
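A minimal sketch of the embedding behavior described above, assuming a length-prefixed marker and the TLS cap of 255 bytes per ALPN protocol name; the helper name and wire shape are illustrative assumptions, not taken from this commit:

```rust
/// Illustrative only: one way an ALPN marker could be embedded into an
/// application-data payload. The name `embed_alpn_marker` and the
/// length-prefixed layout are assumptions, not the emulator's actual code.
fn embed_alpn_marker(payload: &mut Vec<u8>, alpn: Option<&[u8]>) {
    // TLS ALPN protocol names are at most 255 bytes (opaque<1..2^8-1>).
    const MAX_ALPN_LEN: usize = 255;
    if let Some(alpn) = alpn {
        if alpn.is_empty() || alpn.len() > MAX_ALPN_LEN {
            // Oversized (or empty) values are dropped so they cannot
            // interfere with the application data payload.
            return;
        }
        payload.push(alpn.len() as u8); // length-prefixed marker
        payload.extend_from_slice(alpn);
    }
}
```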
Author: David Osipov
Date: 2026-03-18 17:04:50 +04:00
parent 97d4a1c5c8
commit 20e205189c
20 changed files with 2935 additions and 113 deletions

View File

@@ -31,19 +31,16 @@ struct UserConnectionReservation {
user: String,
ip: IpAddr,
active: bool,
runtime_handle: Option<tokio::runtime::Handle>,
}
impl UserConnectionReservation {
fn new(stats: Arc<Stats>, ip_tracker: Arc<UserIpTracker>, user: String, ip: IpAddr) -> Self {
let runtime_handle = tokio::runtime::Handle::try_current().ok();
Self {
stats,
ip_tracker,
user,
ip,
active: true,
runtime_handle,
}
}
@@ -64,29 +61,7 @@ impl Drop for UserConnectionReservation {
}
self.active = false;
self.stats.decrement_user_curr_connects(&self.user);
if let Some(handle) = &self.runtime_handle {
let ip_tracker = self.ip_tracker.clone();
let user = self.user.clone();
let ip = self.ip;
let handle = handle.clone();
handle.spawn(async move {
ip_tracker.remove_ip(&user, ip).await;
});
} else if let Ok(handle) = tokio::runtime::Handle::try_current() {
let ip_tracker = self.ip_tracker.clone();
let user = self.user.clone();
let ip = self.ip;
handle.spawn(async move {
ip_tracker.remove_ip(&user, ip).await;
});
} else {
warn!(
user = %self.user,
ip = %self.ip,
"UserConnectionReservation dropped without Tokio runtime; IP reservation cleanup skipped"
);
}
self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip);
}
}
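For context, a minimal sketch of the synchronous cleanup-queue pattern this `Drop` now relies on. The names `cleanup_queue`, `enqueue_cleanup`, and `drain_cleanup_queue` match the test further down; the concrete types are assumptions:

```rust
use std::collections::VecDeque;
use std::net::IpAddr;
use std::sync::Mutex;

// Assumed shape of the tracker's queue; the real UserIpTracker is not
// shown in this hunk.
struct UserIpTrackerSketch {
    cleanup_queue: Mutex<VecDeque<(String, IpAddr)>>,
}

impl UserIpTrackerSketch {
    // Synchronous: safe to call from Drop, no Tokio runtime required.
    fn enqueue_cleanup(&self, user: String, ip: IpAddr) {
        self.cleanup_queue
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .push_back((user, ip));
    }

    // The async drain runs later on a runtime and performs the actual removal.
    async fn drain_cleanup_queue(&self) {
        let drained: Vec<(String, IpAddr)> = {
            let mut queue = self
                .cleanup_queue
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            queue.drain(..).collect()
        };
        for (_user, _ip) in drained {
            // ip_tracker.remove_ip(&user, ip).await in the real tracker
        }
    }
}
```

Because the enqueue is synchronous, a reservation dropped outside a Tokio runtime no longer skips IP cleanup; the queue is simply drained later on the runtime.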

View File

@@ -42,6 +42,35 @@ where
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}
#[tokio::test]
async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() {
let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new());
let stats = Arc::new(crate::stats::Stats::new());
let user = "sync-drop-user".to_string();
let ip: std::net::IpAddr = "192.168.1.1".parse().unwrap();
ip_tracker.set_user_limit(&user, 1).await;
ip_tracker.check_and_add(&user, ip).await.unwrap();
stats.increment_user_curr_connects(&user);
assert_eq!(ip_tracker.get_active_ip_count(&user).await, 1);
assert_eq!(stats.get_user_curr_connects(&user), 1);
let reservation = UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip);
// Drop the reservation synchronously, with no tokio::spawn or await to yield.
drop(reservation);
// The IP is now in the cleanup_queue; verify the queue has length 1.
let queue_len = ip_tracker.cleanup_queue.lock().unwrap().len();
assert_eq!(queue_len, 1, "Reservation drop must push directly to synchronized IP queue");
assert_eq!(stats.get_user_curr_connects(&user), 0, "Stats must decrement immediately");
ip_tracker.drain_cleanup_queue().await;
assert_eq!(ip_tracker.get_active_ip_count(&user).await, 0);
}
#[tokio::test]
async fn relay_task_abort_releases_user_gate_and_ip_reservation() {
let tg_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();

View File

@@ -132,7 +132,11 @@ fn open_unknown_dc_log_append(path: &Path) -> std::io::Result<std::fs::File> {
}
#[cfg(not(unix))]
{
OpenOptions::new().create(true).append(true).open(path)
let _ = path;
Err(std::io::Error::new(
std::io::ErrorKind::PermissionDenied,
"unknown_dc_file_log_enabled requires unix O_NOFOLLOW support",
))
}
}
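For reference, the unix branch this fallback alludes to presumably opens the log with `O_NOFOLLOW`; a sketch under that assumption (uses the `libc` crate):

```rust
#[cfg(unix)]
fn open_nofollow_append(path: &std::path::Path) -> std::io::Result<std::fs::File> {
    use std::fs::OpenOptions;
    use std::os::unix::fs::OpenOptionsExt;

    // O_NOFOLLOW makes open() fail when the final path component is a
    // symlink, defeating symlink-swap attacks on the log file. Platforms
    // without this guarantee now fail closed, as in the branch above.
    OpenOptions::new()
        .create(true)
        .append(true)
        .custom_flags(libc::O_NOFOLLOW)
        .open(path)
}
```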
@@ -204,6 +208,7 @@ where
config.general.direct_relay_copy_buf_s2c_bytes,
user,
Arc::clone(&stats),
config.access.user_data_quota.get(user).copied(),
buffer_pool,
);
tokio::pin!(relay_result);

View File

@@ -241,7 +241,26 @@ fn auth_probe_record_failure_with_state(
rounds += 1;
if rounds > 8 {
auth_probe_note_saturation(now);
return;
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) =>
{
}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
}
let Some((evict_key, _, _)) = eviction_candidate else {
return;
};
state.remove(&evict_key);
break;
}
let mut stale_keys = Vec::new();
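The empty match arm in the scan above is easy to misread; restated as a predicate (a sketch mirroring the guard, with the capacity-preference test below as the oracle):

```rust
use std::time::Instant;

// The scanned entry replaces the current candidate only when this returns
// false, so the surviving eviction candidate is the lowest-fail-streak,
// oldest entry within the bounded scan window.
fn keeps_current_candidate(
    candidate: (u32, Instant), // (current_fail, current_seen)
    scanned: (u32, Instant),   // (fail_streak, last_seen)
) -> bool {
    scanned.0 > candidate.0 || (scanned.0 == candidate.0 && scanned.1 >= candidate.1)
}
```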
@@ -518,6 +537,7 @@ pub struct HandshakeSuccess {
/// Client address
pub peer: SocketAddr,
/// Whether TLS was used
pub is_tls: bool,
}
@@ -716,7 +736,11 @@ where
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
{
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
trace!(
peer = %peer,
handshake_head = %hex::encode(&handshake[..8]),
"MTProto handshake prefix"
);
let throttle_now = Instant::now();
if auth_probe_should_apply_preauth_throttle(peer.ip(), throttle_now) {
@@ -916,6 +940,7 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, A
}
/// Encrypt nonce for sending to Telegram (legacy function for compatibility)
pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(nonce);
encrypted

View File

@@ -1584,6 +1584,47 @@ fn stress_auth_probe_full_map_churn_keeps_bound_and_tracks_newcomers() {
}
}
#[test]
fn auth_probe_over_cap_churn_still_tracks_newcomer_after_round_limit() {
let _guard = auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_auth_probe_state_for_testing();
let state = DashMap::new();
let now = Instant::now();
let initial = AUTH_PROBE_TRACK_MAX_ENTRIES + 32;
for idx in 0..initial {
let ip = IpAddr::V4(Ipv4Addr::new(
10,
6,
((idx >> 8) & 0xff) as u8,
(idx & 0xff) as u8,
));
state.insert(
ip,
AuthProbeState {
fail_streak: 1,
blocked_until: now,
last_seen: now + Duration::from_millis((idx % 1024) as u64),
},
);
}
let newcomer = IpAddr::V4(Ipv4Addr::new(203, 0, 114, 77));
auth_probe_record_failure_with_state(&state, newcomer, now + Duration::from_secs(1));
assert!(
state.get(&newcomer).is_some(),
"new probe source must still be tracked even when map starts above hard cap"
);
assert!(
state.len() < initial + 1,
"round-limited eviction path must still reclaim capacity under over-cap churn"
);
}
#[test]
fn auth_probe_capacity_prefers_evicting_low_fail_streak_entries_first() {
let _guard = auth_probe_test_lock()

View File

@@ -2,15 +2,13 @@ use std::collections::hash_map::RandomState;
use std::hash::BuildHasher;
use std::hash::{Hash, Hasher};
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, OnceLock};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::time::{Duration, Instant};
#[cfg(test)]
use std::sync::Mutex;
use dashmap::DashMap;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot, watch};
use tokio::sync::{mpsc, oneshot, watch, Mutex as AsyncMutex};
use tokio::time::timeout;
use tracing::{debug, trace, warn};
@@ -35,14 +33,22 @@ enum C2MeCommand {
const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60);
const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536;
const DESYNC_DEDUP_PRUNE_SCAN_LIMIT: usize = 1024;
const DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL: Duration = Duration::from_millis(1000);
const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
#[cfg(test)]
const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50);
#[cfg(not(test))]
const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5);
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new();
static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();
static DESYNC_DEDUP_EVER_SATURATED: OnceLock<AtomicBool> = OnceLock::new();
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<AsyncMutex<()>>>> = OnceLock::new();
struct RelayForensicsState {
trace_id: u64,
@@ -98,6 +104,11 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
}
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
let saturated_before = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES;
let ever_saturated = DESYNC_DEDUP_EVER_SATURATED.get_or_init(|| AtomicBool::new(false));
if saturated_before {
ever_saturated.store(true, Ordering::Relaxed);
}
if let Some(mut seen_at) = dedup.get_mut(&key) {
if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW {
@@ -132,12 +143,52 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
};
dedup.remove(&evict_key);
dedup.insert(key, now);
return false;
return should_emit_full_desync_full_cache(now);
}
}
dedup.insert(key, now);
true
let saturated_after = dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES;
// Preserve the first sequential insert that reaches capacity as a normal
// emit, while still gating concurrent newcomer churn after the cache has
// ever been observed at saturation.
let was_ever_saturated = if saturated_after {
ever_saturated.swap(true, Ordering::Relaxed)
} else {
ever_saturated.load(Ordering::Relaxed)
};
if saturated_before || (saturated_after && was_ever_saturated) {
should_emit_full_desync_full_cache(now)
} else {
true
}
}
fn should_emit_full_desync_full_cache(now: Instant) -> bool {
let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None));
let Ok(mut last_emit_at) = gate.lock() else {
return false;
};
match *last_emit_at {
None => {
*last_emit_at = Some(now);
true
}
Some(last) => {
let Some(elapsed) = now.checked_duration_since(last) else {
*last_emit_at = Some(now);
return true;
};
if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL {
*last_emit_at = Some(now);
true
} else {
false
}
}
}
}
#[cfg(test)]
@@ -145,6 +196,21 @@ fn clear_desync_dedup_for_testing() {
if let Some(dedup) = DESYNC_DEDUP.get() {
dedup.clear();
}
if let Some(ever_saturated) = DESYNC_DEDUP_EVER_SATURATED.get() {
ever_saturated.store(false, Ordering::Relaxed);
}
if let Some(last_emit_at) = DESYNC_FULL_CACHE_LAST_EMIT_AT.get() {
match last_emit_at.lock() {
Ok(mut guard) => {
*guard = None;
}
Err(poisoned) => {
let mut guard = poisoned.into_inner();
*guard = None;
last_emit_at.clear_poison();
}
}
}
}
#[cfg(test)]
@@ -248,6 +314,38 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
}
fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option<u64>) -> bool {
quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota)
}
fn quota_would_be_exceeded_for_user(
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
bytes: u64,
) -> bool {
quota_limit.is_some_and(|quota| {
let used = stats.get_user_total_octets(user);
used >= quota || bytes > quota.saturating_sub(used)
})
}
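A quick worked example of the pre-write check: with 3 octets already charged against a quota of 4, only 1 remains, so a 4-byte message is rejected whole rather than partially forwarded (the scenario the partial-payload tests below exercise):

```rust
// used = 3, quota = 4 → remaining = quota - used = 1; a 4-byte write
// exceeds the remainder, so the whole message is rejected up front.
let (used, quota, bytes): (u64, u64, u64) = (3, 4, 4);
assert!(used < quota); // not yet over quota...
assert!(bytes > quota.saturating_sub(used)); // ...but 4 > 1, so reject
```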
fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
if let Some(existing) = locks.get(user) {
return Arc::clone(existing.value());
}
let created = Arc::new(AsyncMutex::new(()));
match locks.entry(user.to_string()) {
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
dashmap::mapref::entry::Entry::Vacant(entry) => {
entry.insert(Arc::clone(&created));
created
}
}
}
async fn enqueue_c2me_command(
tx: &mpsc::Sender<C2MeCommand>,
cmd: C2MeCommand,
@@ -260,7 +358,14 @@ async fn enqueue_c2me_command(
if tx.capacity() <= C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS {
tokio::task::yield_now().await;
}
tx.send(cmd).await
match timeout(C2ME_SEND_TIMEOUT, tx.reserve()).await {
Ok(Ok(permit)) => {
permit.send(cmd);
Ok(())
}
Ok(Err(_)) => Err(mpsc::error::SendError(cmd)),
Err(_) => Err(mpsc::error::SendError(cmd)),
}
}
}
}
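A note on the design choice above: wrapping `tx.send(cmd)` in a timeout would consume `cmd` when the timer fires, because `send` takes the value by value. `reserve()` borrows only the channel, so on timeout or closure the caller still owns the command and can hand it back in `SendError`. A minimal generic sketch of the same pattern:

```rust
use tokio::sync::mpsc;
use tokio::time::{Duration, timeout};

// reserve() is awaited without consuming the message, so a timeout or a
// closed channel leaves `msg` intact for the caller to recycle.
async fn send_bounded<T>(
    tx: &mpsc::Sender<T>,
    msg: T,
) -> Result<(), mpsc::error::SendError<T>> {
    match timeout(Duration::from_millis(50), tx.reserve()).await {
        Ok(Ok(permit)) => {
            permit.send(msg);
            Ok(())
        }
        // Channel closed, or the timer elapsed: either way we still own msg.
        Ok(Err(_)) | Err(_) => Err(mpsc::error::SendError(msg)),
    }
}
```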
@@ -284,6 +389,7 @@ where
W: AsyncWrite + Unpin + Send + 'static,
{
let user = success.user.clone();
let quota_limit = config.access.user_data_quota.get(&user).copied();
let peer = success.peer;
let proto_tag = success.proto_tag;
let pool_generation = me_pool.current_generation();
@@ -432,6 +538,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_limit,
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@@ -464,6 +571,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_limit,
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@@ -496,6 +604,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_limit,
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@@ -528,6 +637,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_limit,
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@@ -609,7 +719,19 @@ where
forensics.bytes_c2me = forensics
.bytes_c2me
.saturating_add(payload.len() as u64);
stats.add_user_octets_from(&user, payload.len() as u64);
if let Some(limit) = quota_limit {
let quota_lock = quota_user_lock(&user);
let _quota_guard = quota_lock.lock().await;
stats.add_user_octets_from(&user, payload.len() as u64);
if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) {
main_result = Err(ProxyError::DataQuotaExceeded {
user: user.clone(),
});
break;
}
} else {
stats.add_user_octets_from(&user, payload.len() as u64);
}
let mut flags = proto_flags;
if quickack {
flags |= RPC_FLAG_QUICKACK;
@@ -833,6 +955,7 @@ async fn process_me_writer_response<W>(
frame_buf: &mut Vec<u8>,
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
bytes_me2c: &AtomicU64,
conn_id: u64,
ack_flush_immediate: bool,
@@ -848,17 +971,47 @@ where
} else {
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
}
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64);
write_client_payload(
client_writer,
proto_tag,
flags,
&data,
rng,
frame_buf,
)
.await?;
let data_len = data.len() as u64;
if let Some(limit) = quota_limit {
let quota_lock = quota_user_lock(user);
let _quota_guard = quota_lock.lock().await;
if quota_would_be_exceeded_for_user(stats, user, Some(limit), data_len) {
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
write_client_payload(
client_writer,
proto_tag,
flags,
&data,
rng,
frame_buf,
)
.await?;
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64);
if quota_exceeded_for_user(stats, user, Some(limit)) {
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
} else {
write_client_payload(
client_writer,
proto_tag,
flags,
&data,
rng,
frame_buf,
)
.await?;
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64);
}
Ok(MeWriterResponseOutcome::Continue {
frames: 1,

View File

@@ -13,8 +13,9 @@ use rand::{Rng, SeedableRng};
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use tokio::io::AsyncWriteExt;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::thread;
use tokio::io::AsyncReadExt;
use tokio::io::duplex;
use tokio::time::{Duration as TokioDuration, timeout};
@@ -176,6 +177,36 @@ async fn enqueue_c2me_command_full_then_closed_recycles_waiting_payload() {
);
}
#[tokio::test]
async fn enqueue_c2me_command_full_queue_times_out_without_receiver_progress() {
let (tx, _rx) = mpsc::channel::<C2MeCommand>(1);
tx.send(C2MeCommand::Data {
payload: make_pooled_payload(&[1]),
flags: 0,
})
.await
.unwrap();
let started = Instant::now();
let result = enqueue_c2me_command(
&tx,
C2MeCommand::Data {
payload: make_pooled_payload(&[2, 2]),
flags: 1,
},
)
.await;
assert!(
result.is_err(),
"enqueue must fail when queue stays full beyond bounded timeout"
);
assert!(
started.elapsed() < TokioDuration::from_millis(400),
"full-queue timeout must resolve promptly"
);
}
#[test]
fn desync_dedup_cache_is_bounded() {
let _guard = desync_dedup_test_lock()
@@ -192,12 +223,12 @@ fn desync_dedup_cache_is_bounded() {
}
assert!(
!should_emit_full_desync(u64::MAX, false, now),
"new key above cap must remain suppressed to avoid log amplification"
should_emit_full_desync(u64::MAX, false, now),
"new key above cap must emit once after bounded eviction for forensic visibility"
);
assert!(
!should_emit_full_desync(7, false, now),
!should_emit_full_desync(u64::MAX, false, now),
"already tracked key inside dedup window must stay suppressed"
);
}
@@ -215,10 +246,18 @@ fn desync_dedup_full_cache_churn_stays_suppressed() {
}
for offset in 0..2048u64 {
assert!(
!should_emit_full_desync(u64::MAX - offset, false, now),
"fresh full-cache churn must remain suppressed under pressure"
);
let emitted = should_emit_full_desync(u64::MAX - offset, false, now);
if offset == 0 {
assert!(
emitted,
"first full-cache newcomer should emit for forensic visibility"
);
} else {
assert!(
!emitted,
"full-cache newcomer churn inside emit interval must stay suppressed"
);
}
}
}
@@ -296,18 +335,20 @@ fn stress_desync_dedup_churn_keeps_cache_hard_bounded() {
let now = Instant::now();
let total = DESYNC_DEDUP_MAX_ENTRIES + 8192;
let mut emitted_count = 0usize;
for key in 0..total as u64 {
let emitted = should_emit_full_desync(key, false, now);
if key < DESYNC_DEDUP_MAX_ENTRIES as u64 {
assert!(emitted, "keys below cap must be admitted initially");
} else {
assert!(
!emitted,
"new keys above cap must stay suppressed under sustained churn"
);
if emitted {
emitted_count += 1;
}
}
assert_eq!(
emitted_count,
DESYNC_DEDUP_MAX_ENTRIES + 1,
"after capacity is reached, same-tick newcomer churn must be rate-limited"
);
let len = DESYNC_DEDUP
.get()
.expect("dedup cache must be initialized by stress run")
@@ -318,6 +359,282 @@ fn stress_desync_dedup_churn_keeps_cache_hard_bounded() {
);
}
#[test]
fn full_cache_newcomer_emission_is_rate_limited_but_periodic() {
let _guard = desync_dedup_test_lock()
.lock()
.expect("desync dedup test lock must be available");
clear_desync_dedup_for_testing();
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
let base_now = Instant::now();
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
dedup.insert(key, base_now - TokioDuration::from_millis(10));
}
// Same-tick newcomer storm: only the first should emit a full forensic record.
let mut burst_emits = 0usize;
for i in 0..1024u64 {
if should_emit_full_desync(10_000_000 + i, false, base_now) {
burst_emits += 1;
}
}
assert_eq!(
burst_emits, 1,
"full-cache newcomer burst must be bounded to a single full emit per interval"
);
// After each interval elapses, one newcomer may emit again.
for step in 1..=6u64 {
let t = base_now + DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL * step as u32;
assert!(
should_emit_full_desync(20_000_000 + step, false, t),
"full-cache newcomer should re-emit once interval has elapsed"
);
assert!(
!should_emit_full_desync(30_000_000 + step, false, t),
"additional newcomers in the same interval tick must remain suppressed"
);
}
}
#[test]
fn full_cache_mode_override_emits_every_event() {
let _guard = desync_dedup_test_lock()
.lock()
.expect("desync dedup test lock must be available");
clear_desync_dedup_for_testing();
let now = Instant::now();
for i in 0..10_000u64 {
assert!(
should_emit_full_desync(100_000_000 + i, true, now),
"desync_all_full override must bypass dedup and rate-limit suppression"
);
}
}
#[test]
fn report_desync_stats_follow_rate_limited_full_cache_policy() {
let _guard = desync_dedup_test_lock()
.lock()
.expect("desync dedup test lock must be available");
clear_desync_dedup_for_testing();
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
let base_now = Instant::now();
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
dedup.insert(key, base_now - TokioDuration::from_millis(10));
}
let stats = Stats::new();
let mut state = make_forensics_state();
state.started_at = base_now;
for i in 0..128u64 {
state.peer_hash = 0xABC0_0000_0000_0000u64 ^ i;
let _ = report_desync_frame_too_large(
&state,
ProtoTag::Secure,
3,
1024,
4096,
Some([0x16, 0x03, 0x03, 0x00]),
&stats,
);
}
assert_eq!(
stats.get_desync_total(),
128,
"every detected desync must increment total counter"
);
assert_eq!(
stats.get_desync_full_logged(),
1,
"same-interval full-cache newcomer storm must allow only one full forensic emit"
);
assert_eq!(
stats.get_desync_suppressed(),
127,
"remaining same-interval full-cache newcomer events must be suppressed"
);
// After one full interval of real wall-clock time, a newcomer should emit again.
thread::sleep(DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL + TokioDuration::from_millis(20));
state.peer_hash = 0xDEAD_BEEF_DEAD_BEEFu64;
let _ = report_desync_frame_too_large(
&state,
ProtoTag::Secure,
4,
1024,
4097,
Some([0x16, 0x03, 0x03, 0x01]),
&stats,
);
assert_eq!(
stats.get_desync_full_logged(),
2,
"full forensic emission must recover after rate-limit interval"
);
}
#[test]
fn concurrent_full_cache_newcomer_storm_is_single_emit_per_interval() {
let _guard = desync_dedup_test_lock()
.lock()
.expect("desync dedup test lock must be available");
clear_desync_dedup_for_testing();
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
let base_now = Instant::now();
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
dedup.insert(key, base_now - TokioDuration::from_millis(10));
}
let emits = Arc::new(AtomicUsize::new(0));
let mut workers = Vec::new();
for worker_id in 0..32u64 {
let emits = Arc::clone(&emits);
workers.push(thread::spawn(move || {
for i in 0..512u64 {
let key = 0x7000_0000_0000_0000u64 ^ (worker_id << 20) ^ i;
if should_emit_full_desync(key, false, base_now) {
emits.fetch_add(1, Ordering::Relaxed);
}
}
}));
}
for worker in workers {
worker.join().expect("worker thread must not panic");
}
assert_eq!(
emits.load(Ordering::Relaxed),
1,
"concurrent same-interval full-cache storm must allow only one full forensic emit"
);
}
#[test]
fn light_fuzz_full_cache_rate_limit_oracle_matches_model() {
let _guard = desync_dedup_test_lock()
.lock()
.expect("desync dedup test lock must be available");
clear_desync_dedup_for_testing();
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
let base_now = Instant::now();
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
dedup.insert(key, base_now - TokioDuration::from_millis(10));
}
let mut rng = StdRng::seed_from_u64(0xD15EA5E5_F00DBAAD);
let mut model_last_emit: Option<Instant> = None;
for i in 0..4096u64 {
let jitter_ms: u64 = rng.random_range(0..=3000);
let t = base_now + TokioDuration::from_millis(jitter_ms);
let key = 0x55AA_0000_0000_0000u64 ^ i ^ rng.random::<u64>();
let actual = should_emit_full_desync(key, false, t);
let expected = match model_last_emit {
None => {
model_last_emit = Some(t);
true
}
Some(last) => {
match t.checked_duration_since(last) {
Some(elapsed) if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL => {
model_last_emit = Some(t);
true
}
Some(_) => false,
None => {
// Match production fail-open behavior for non-monotonic synthetic input.
model_last_emit = Some(t);
true
}
}
}
};
assert_eq!(
actual, expected,
"full-cache rate-limit gate diverged from reference model under light fuzz"
);
}
}
#[test]
fn full_cache_gate_lock_poison_is_fail_closed_without_panic() {
let _guard = desync_dedup_test_lock()
.lock()
.expect("desync dedup test lock must be available");
clear_desync_dedup_for_testing();
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
let base_now = Instant::now();
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
dedup.insert(key, base_now - TokioDuration::from_millis(10));
}
// Poison the full-cache gate lock intentionally.
let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None));
let _ = std::panic::catch_unwind(|| {
let _lock = gate.lock().expect("gate lock must be lockable before poison");
panic!("intentional gate poison for fail-closed regression");
});
let emitted = should_emit_full_desync(0xFACE_0000_0000_0001, false, base_now);
assert!(
!emitted,
"poisoned full-cache gate must fail-closed (suppress) instead of panic or fail-open"
);
assert!(
dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES,
"dedup cache must remain bounded even when gate lock is poisoned"
);
}
#[test]
fn full_cache_non_monotonic_time_emits_and_resets_gate_safely() {
let _guard = desync_dedup_test_lock()
.lock()
.expect("desync dedup test lock must be available");
clear_desync_dedup_for_testing();
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
let base_now = Instant::now();
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
dedup.insert(key, base_now - TokioDuration::from_millis(10));
}
// First event seeds the gate.
assert!(should_emit_full_desync(
0xABCD_0000_0000_0001,
false,
base_now + TokioDuration::from_millis(900)
));
// A synthetic earlier timestamp must not panic; it should fail-open and reset the gate.
assert!(should_emit_full_desync(
0xABCD_0000_0000_0002,
false,
base_now + TokioDuration::from_millis(100)
));
// Same instant again remains suppressed after reset.
assert!(!should_emit_full_desync(
0xABCD_0000_0000_0003,
false,
base_now + TokioDuration::from_millis(100)
));
}
#[test]
fn desync_dedup_full_cache_inserts_new_key_with_bounded_single_key_churn() {
let _guard = desync_dedup_test_lock()
@@ -338,8 +655,8 @@ fn desync_dedup_full_cache_inserts_new_key_with_bounded_single_key_churn() {
let newcomer_key = u64::MAX;
let emitted = should_emit_full_desync(newcomer_key, false, base_now);
assert!(
!emitted,
"new entry under full fresh cache must stay suppressed"
emitted,
"new entry under full fresh cache must emit after bounded eviction"
);
assert!(
dedup.get(&newcomer_key).is_some(),
@@ -406,6 +723,24 @@ fn light_fuzz_desync_dedup_temporal_gate_behavior_is_stable() {
panic!("expected at least one post-window sample to re-emit forensic record");
}
#[test]
#[ignore = "Tracking for M-04: Verify should_emit_full_desync returns true on first occurrence and false on duplicate within window"]
fn should_emit_full_desync_filters_duplicates() {
unimplemented!("Stub for M-04");
}
#[test]
#[ignore = "Tracking for M-04: Verify desync dedup eviction behaves correctly under map-full condition"]
fn desync_dedup_eviction_under_map_full_condition() {
unimplemented!("Stub for M-04");
}
#[tokio::test]
#[ignore = "Tracking for M-05: Verify C2ME channel full path yields then sends under backpressure"]
async fn c2me_channel_full_path_yields_then_sends() {
unimplemented!("Stub for M-05");
}
fn make_forensics_state() -> RelayForensicsState {
RelayForensicsState {
trace_id: 1,
@@ -974,6 +1309,7 @@ async fn process_me_writer_response_ack_obeys_flush_policy() {
&mut frame_buf,
&stats,
"user",
None,
&bytes_me2c,
77,
true,
@@ -999,6 +1335,7 @@ async fn process_me_writer_response_ack_obeys_flush_policy() {
&mut frame_buf,
&stats,
"user",
None,
&bytes_me2c,
77,
false,
@@ -1038,6 +1375,7 @@ async fn process_me_writer_response_data_updates_byte_accounting() {
&mut frame_buf,
&stats,
"user",
None,
&bytes_me2c,
88,
false,
@@ -1061,6 +1399,162 @@ async fn process_me_writer_response_data_updates_byte_accounting() {
);
}
#[tokio::test]
async fn process_me_writer_response_data_enforces_live_user_quota() {
let (writer_side, mut reader_side) = duplex(1024);
let mut writer = make_crypto_writer(writer_side);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let stats = Stats::new();
let bytes_me2c = AtomicU64::new(0);
stats.add_user_octets_from("quota-user", 10);
let result = process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from(vec![1u8, 2, 3, 4]),
},
&mut writer,
ProtoTag::Intermediate,
&rng,
&mut frame_buf,
&stats,
"quota-user",
Some(12),
&bytes_me2c,
89,
false,
false,
)
.await;
assert!(
matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "quota-user"),
"ME->client runtime path must terminate when live user quota is crossed"
);
let mut raw = [0u8; 1];
assert!(
timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw))
.await
.is_err(),
"quota exhaustion must not write any ciphertext to the client stream"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn process_me_writer_response_concurrent_same_user_quota_does_not_overshoot_limit() {
let stats = Stats::new();
let bytes_me2c = AtomicU64::new(0);
let user = "quota-race-user";
let (writer_side_a, _reader_side_a) = duplex(1024);
let (writer_side_b, _reader_side_b) = duplex(1024);
let mut writer_a = make_crypto_writer(writer_side_a);
let mut writer_b = make_crypto_writer(writer_side_b);
let mut frame_buf_a = Vec::new();
let mut frame_buf_b = Vec::new();
let rng_a = SecureRandom::new();
let rng_b = SecureRandom::new();
let fut_a = process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x11]),
},
&mut writer_a,
ProtoTag::Intermediate,
&rng_a,
&mut frame_buf_a,
&stats,
user,
Some(1),
&bytes_me2c,
91,
false,
false,
);
let fut_b = process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x22]),
},
&mut writer_b,
ProtoTag::Intermediate,
&rng_b,
&mut frame_buf_b,
&stats,
user,
Some(1),
&bytes_me2c,
92,
false,
false,
);
let (result_a, result_b) = tokio::join!(fut_a, fut_b);
assert!(
matches!(result_a, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user")
|| matches!(result_a, Ok(_)),
"concurrent quota test must complete without panicking"
);
assert!(
matches!(result_b, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user")
|| matches!(result_b, Ok(_)),
"concurrent quota test must complete without panicking"
);
assert!(
stats.get_user_total_octets(user) <= 1,
"same-user concurrent middle-relay responses must not overshoot the configured quota"
);
}
#[tokio::test]
async fn process_me_writer_response_data_does_not_forward_partial_payload_when_remaining_quota_is_smaller_than_message() {
let (writer_side, mut reader_side) = duplex(1024);
let mut writer = make_crypto_writer(writer_side);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let stats = Stats::new();
let bytes_me2c = AtomicU64::new(0);
stats.add_user_octets_to("partial-quota-user", 3);
let result = process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from(vec![1u8, 2, 3, 4]),
},
&mut writer,
ProtoTag::Intermediate,
&rng,
&mut frame_buf,
&stats,
"partial-quota-user",
Some(4),
&bytes_me2c,
90,
false,
false,
)
.await;
assert!(
matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "partial-quota-user"),
"ME->client runtime path must reject oversized payloads before writing"
);
let mut raw = [0u8; 1];
assert!(
timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw))
.await
.is_err(),
"oversized payloads must not leak any partial ciphertext to the client stream"
);
}
#[tokio::test]
async fn middle_relay_abort_midflight_releases_route_gauge() {
let stats = Arc::new(Stats::new());

View File

@@ -53,16 +53,17 @@
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;
use dashmap::DashMap;
use tokio::io::{
AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes,
};
use tokio::time::Instant;
use tracing::{debug, trace, warn};
use crate::error::Result;
use crate::error::{ProxyError, Result};
use crate::stats::Stats;
use crate::stream::BufferPool;
@@ -205,6 +206,8 @@ struct StatsIo<S> {
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
epoch: Instant,
}
@@ -214,11 +217,62 @@ impl<S> StatsIo<S> {
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
epoch: Instant,
) -> Self {
// Mark initial activity so the watchdog doesn't fire before data flows
counters.touch(Instant::now(), epoch);
Self { inner, counters, stats, user, epoch }
Self {
inner,
counters,
stats,
user,
quota_limit,
quota_exceeded,
epoch,
}
}
}
#[derive(Debug)]
struct QuotaIoSentinel;
impl std::fmt::Display for QuotaIoSentinel {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("user data quota exceeded")
}
}
impl std::error::Error for QuotaIoSentinel {}
fn quota_io_error() -> io::Error {
io::Error::new(io::ErrorKind::PermissionDenied, QuotaIoSentinel)
}
fn is_quota_io_error(err: &io::Error) -> bool {
err.kind() == io::ErrorKind::PermissionDenied
&& err
.get_ref()
.and_then(|source| source.downcast_ref::<QuotaIoSentinel>())
.is_some()
}
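The sentinel is detected by downcasting the error's source rather than by `ErrorKind` alone, so quota aborts stay distinguishable from ordinary `PermissionDenied` failures:

```rust
// Round-trip check: only the sentinel-carrying error is recognized.
let quota_err = quota_io_error();
assert!(is_quota_io_error(&quota_err));

let plain_denied = io::Error::new(io::ErrorKind::PermissionDenied, "unrelated");
assert!(!is_quota_io_error(&plain_denied));
```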
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
if let Some(existing) = locks.get(user) {
return Arc::clone(existing.value());
}
let created = Arc::new(Mutex::new(()));
match locks.entry(user.to_string()) {
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
dashmap::mapref::entry::Entry::Vacant(entry) => {
entry.insert(Arc::clone(&created));
created
}
}
}
@@ -229,6 +283,32 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
if this.quota_exceeded.load(Ordering::Relaxed) {
return Poll::Ready(Err(quota_io_error()));
}
let quota_lock = this
.quota_limit
.is_some()
.then(|| quota_user_lock(&this.user));
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => Some(guard),
Err(_) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
}
} else {
None
};
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
let before = buf.filled().len();
match Pin::new(&mut this.inner).poll_read(cx, buf) {
@@ -243,6 +323,13 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
this.stats.add_user_octets_from(&this.user, n as u64);
this.stats.increment_user_msgs_from(&this.user);
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
trace!(user = %this.user, bytes = n, "C->S");
}
Poll::Ready(Ok(()))
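The `try_lock` plus `wake_by_ref` pattern above never blocks inside `poll_read`: on contention the task schedules an immediate re-poll and returns `Pending`, trading a brief busy-wait for executor safety. A minimal sketch of that pattern in isolation:

```rust
use std::sync::Mutex;
use std::task::{Context, Poll};

// Never block inside poll_*; on contention, request an immediate re-poll.
fn with_try_lock<T, R>(
    lock: &Mutex<T>,
    cx: &mut Context<'_>,
    f: impl FnOnce(&mut T) -> R,
) -> Poll<R> {
    match lock.try_lock() {
        Ok(mut guard) => Poll::Ready(f(&mut *guard)),
        Err(_) => {
            cx.waker().wake_by_ref(); // brief busy-wait, but no deadlock
            Poll::Pending
        }
    }
}
```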
@@ -259,8 +346,46 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
if this.quota_exceeded.load(Ordering::Relaxed) {
return Poll::Ready(Err(quota_io_error()));
}
match Pin::new(&mut this.inner).poll_write(cx, buf) {
let quota_lock = this
.quota_limit
.is_some()
.then(|| quota_user_lock(&this.user));
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => Some(guard),
Err(_) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
}
} else {
None
};
let write_buf = if let Some(limit) = this.quota_limit {
let used = this.stats.get_user_total_octets(&this.user);
if used >= limit {
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
let remaining = (limit - used) as usize;
if buf.len() > remaining {
// Fail closed: do not emit partial S->C payload when remaining
// quota cannot accommodate the pending write request.
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
buf
} else {
buf
};
match Pin::new(&mut this.inner).poll_write(cx, write_buf) {
Poll::Ready(Ok(n)) => {
if n > 0 {
// S→C: data written to client
@@ -271,6 +396,13 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
this.stats.add_user_octets_to(&this.user, n as u64);
this.stats.increment_user_msgs_to(&this.user);
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
trace!(user = %this.user, bytes = n, "S->C");
}
Poll::Ready(Ok(n))
@@ -307,7 +439,8 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
/// - Per-user stats: bytes and ops counted per direction
/// - Periodic rate logging: every 10 seconds when active
/// - Clean shutdown: both write sides are shut down on exit
/// - Error propagation: I/O errors are returned as `ProxyError::Io`
/// - Error propagation: quota exits return `ProxyError::DataQuotaExceeded`,
/// other I/O failures are returned as `ProxyError::Io`
pub async fn relay_bidirectional<CR, CW, SR, SW>(
client_reader: CR,
client_writer: CW,
@@ -317,6 +450,7 @@ pub async fn relay_bidirectional<CR, CW, SR, SW>(
s2c_buf_size: usize,
user: &str,
stats: Arc<Stats>,
quota_limit: Option<u64>,
_buffer_pool: Arc<BufferPool>,
) -> Result<()>
where
@@ -327,6 +461,7 @@ where
{
let epoch = Instant::now();
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let user_owned = user.to_string();
// ── Combine split halves into bidirectional streams ──────────────
@@ -339,12 +474,15 @@ where
Arc::clone(&counters),
Arc::clone(&stats),
user_owned.clone(),
quota_limit,
Arc::clone(&quota_exceeded),
epoch,
);
// ── Watchdog: activity timeout + periodic rate logging ──────────
let wd_counters = Arc::clone(&counters);
let wd_user = user_owned.clone();
let wd_quota_exceeded = Arc::clone(&quota_exceeded);
let watchdog = async {
let mut prev_c2s: u64 = 0;
@@ -356,6 +494,11 @@ where
let now = Instant::now();
let idle = wd_counters.idle_duration(now, epoch);
if wd_quota_exceeded.load(Ordering::Relaxed) {
warn!(user = %wd_user, "User data quota reached, closing relay");
return;
}
// ── Activity timeout ────────────────────────────────────
if idle >= ACTIVITY_TIMEOUT {
let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
@@ -439,6 +582,22 @@ where
);
Ok(())
}
Some(Err(e)) if is_quota_io_error(&e) => {
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
let s2c = counters.s2c_bytes.load(Ordering::Relaxed);
warn!(
user = %user_owned,
c2s_bytes = c2s,
s2c_bytes = s2c,
c2s_msgs = c2s_ops,
s2c_msgs = s2c_ops,
duration_secs = duration.as_secs(),
"Data quota reached, closing relay"
);
Err(ProxyError::DataQuotaExceeded {
user: user_owned.clone(),
})
}
Some(Err(e)) => {
// I/O error in one of the directions
let c2s = counters.c2s_bytes.load(Ordering::Relaxed);
@@ -472,3 +631,7 @@ where
}
}
}
#[cfg(test)]
#[path = "relay_security_tests.rs"]
mod security_tests;

View File

@@ -0,0 +1,972 @@
use super::relay_bidirectional;
use crate::error::ProxyError;
use crate::stats::Stats;
use crate::stream::BufferPool;
use std::future::poll_fn;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;
use std::task::{Context, Poll};
use std::task::Waker;
use tokio::io::{AsyncRead, ReadBuf};
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex};
use tokio::time::{Duration, timeout};
#[tokio::test]
async fn relay_bidirectional_enforces_live_user_quota() {
let stats = Arc::new(Stats::new());
let user = "quota-user";
stats.add_user_octets_from(user, 6);
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
user,
Arc::clone(&stats),
Some(8),
Arc::new(BufferPool::new()),
));
client_peer
.write_all(&[0x10, 0x20, 0x30, 0x40])
.await
.expect("client write must succeed");
let mut forwarded = [0u8; 4];
let _ = timeout(
Duration::from_millis(200),
server_peer.read_exact(&mut forwarded),
)
.await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish under quota cutoff")
.expect("relay task must not panic");
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-user"),
"relay must surface a typed quota error once live quota is exceeded"
);
}
#[tokio::test]
async fn relay_bidirectional_does_not_forward_server_bytes_after_quota_is_exhausted() {
let stats = Arc::new(Stats::new());
let quota_user = "quota-exhausted-user";
stats.add_user_octets_from(quota_user, 1);
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(1),
Arc::new(BufferPool::new()),
));
server_peer
.write_all(&[0xde, 0xad, 0xbe, 0xef])
.await
.expect("server write must succeed");
let mut observed = [0u8; 4];
let forwarded = timeout(
Duration::from_millis(200),
client_peer.read_exact(&mut observed),
)
.await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish under quota cutoff")
.expect("relay task must not panic");
assert!(
!matches!(forwarded, Ok(Ok(n)) if n == observed.len()),
"no full server payload should be forwarded once quota is already exhausted"
);
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
"relay must still terminate with a typed quota error"
);
}
#[tokio::test]
async fn relay_bidirectional_does_not_leak_partial_server_payload_when_remaining_quota_is_smaller_than_write() {
let stats = Arc::new(Stats::new());
let quota_user = "partial-leak-user";
stats.add_user_octets_from(quota_user, 3);
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(4),
Arc::new(BufferPool::new()),
));
server_peer
.write_all(&[0x11, 0x22, 0x33, 0x44])
.await
.expect("server write must succeed");
let mut observed = [0u8; 8];
let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish under quota cutoff")
.expect("relay task must not panic");
assert!(
!matches!(forwarded, Ok(Ok(n)) if n > 0),
"quota exhaustion must not leak any partial server payload when remaining quota is smaller than the write"
);
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
"relay must still terminate with a typed quota error"
);
}
#[tokio::test]
async fn relay_bidirectional_zero_quota_remains_fail_closed_for_server_payloads_under_stress() {
let stats = Arc::new(Stats::new());
let quota_user = "zero-quota-user";
for payload_len in [1usize, 16, 512, 4096] {
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(0),
Arc::new(BufferPool::new()),
));
let payload = vec![0x7f; payload_len];
let _ = server_peer.write_all(&payload).await;
let mut observed = vec![0u8; payload_len];
let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish under zero-quota cutoff")
.expect("relay task must not panic");
assert!(
!matches!(forwarded, Ok(Ok(n)) if n > 0),
"zero quota must not forward any server bytes for payload_len={payload_len}"
);
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
"zero quota must terminate with the typed quota error for payload_len={payload_len}"
);
}
}
#[tokio::test]
async fn relay_bidirectional_allows_exact_server_payload_at_quota_boundary() {
let stats = Arc::new(Stats::new());
let quota_user = "exact-boundary-user";
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(4),
Arc::new(BufferPool::new()),
));
server_peer
.write_all(&[0x91, 0x92, 0x93, 0x94])
.await
.expect("server write must succeed at exact quota boundary");
let mut observed = [0u8; 4];
client_peer
.read_exact(&mut observed)
.await
.expect("client must receive the full payload at the exact quota boundary");
assert_eq!(observed, [0x91, 0x92, 0x93, 0x94]);
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish after exact boundary delivery")
.expect("relay task must not panic");
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
"relay must close with a typed quota error after reaching the exact boundary"
);
}
#[tokio::test]
async fn relay_bidirectional_does_not_forward_client_bytes_after_quota_is_exhausted() {
let stats = Arc::new(Stats::new());
let quota_user = "client-exhausted-user";
stats.add_user_octets_from(quota_user, 1);
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(1),
Arc::new(BufferPool::new()),
));
client_peer
.write_all(&[0x51, 0x52, 0x53, 0x54])
.await
.expect("client write must succeed even when quota is already exhausted");
let mut observed = [0u8; 4];
let forwarded = timeout(
Duration::from_millis(200),
server_peer.read_exact(&mut observed),
)
.await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish under quota cutoff")
.expect("relay task must not panic");
assert!(
!matches!(forwarded, Ok(Ok(n)) if n == observed.len()),
"client payload must not be fully forwarded once quota is already exhausted"
);
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
"relay must still terminate with a typed quota error"
);
}
#[tokio::test]
async fn relay_bidirectional_server_bytes_remain_blocked_even_under_multiple_payload_sizes() {
let stats = Arc::new(Stats::new());
let quota_user = "quota-fuzz-user";
stats.add_user_octets_from(quota_user, 2);
for payload_len in [1usize, 32, 1024, 8192] {
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
quota_user,
Arc::clone(&stats),
Some(2),
Arc::new(BufferPool::new()),
));
let payload = vec![0xaa; payload_len];
let _ = server_peer.write_all(&payload).await;
let mut observed = vec![0u8; payload_len];
let forwarded = timeout(
Duration::from_millis(200),
client_peer.read_exact(&mut observed),
)
.await;
let relay_result = timeout(Duration::from_secs(2), relay_task)
.await
.expect("relay task must finish under quota cutoff")
.expect("relay task must not panic");
assert!(
!matches!(forwarded, Ok(Ok(n)) if n == payload_len),
"quota exhaustion must block full server-to-client forwarding for payload_len={payload_len}"
);
assert!(
matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user),
"relay must keep returning the typed quota error for payload_len={payload_len}"
);
}
}
#[tokio::test]
async fn relay_bidirectional_terminates_on_activity_timeout() {
tokio::time::pause();
let stats = Arc::new(Stats::new());
let user = "timeout-user";
let (client_peer, relay_client) = duplex(4096);
let (relay_server, server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
user,
Arc::clone(&stats),
None, // No quota
Arc::new(BufferPool::new()),
));
// Wait past the activity timeout threshold (1800 seconds) + buffer
tokio::time::sleep(Duration::from_secs(1805)).await;
// Resume time to process timeouts
tokio::time::resume();
let relay_result = timeout(Duration::from_secs(1), relay_task)
.await
.expect("relay task must finish inside bounded timeout due to inactivity cutoff")
.expect("relay task must not panic");
assert!(
relay_result.is_ok(),
"relay should complete successfully on scheduled inactivity timeout"
);
// Verify client/server sockets are closed
drop(client_peer);
drop(server_peer);
}
#[tokio::test]
async fn relay_bidirectional_watchdog_resists_premature_execution() {
tokio::time::pause();
let stats = Arc::new(Stats::new());
let user = "activity-user";
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let mut relay_task = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
user,
Arc::clone(&stats),
None,
Arc::new(BufferPool::new()),
));
// Advance by half the timeout
tokio::time::sleep(Duration::from_secs(900)).await;
// Provide activity
client_peer
.write_all(&[0xaa, 0xbb])
.await
.expect("client write must succeed");
client_peer.flush().await.unwrap();
// Advance by another half (1800s total since start, but only 900s since the last activity)
tokio::time::sleep(Duration::from_secs(900)).await;
tokio::time::resume();
// Re-check the task: it must NOT have timed out and should still be pending
let relay_result = timeout(Duration::from_millis(100), &mut relay_task).await;
assert!(
relay_result.is_err(),
"Relay must not exit prematurely as long as activity was received before timeout"
);
// Explicitly drop sockets to cleanly shut down relay loop
drop(client_peer);
drop(server_peer);
let completion = timeout(Duration::from_secs(1), relay_task).await
.expect("relay task must complete securely after client disconnection")
.expect("relay task must not panic");
assert!(completion.is_ok(), "relay exits clean");
}
#[tokio::test]
async fn relay_bidirectional_half_closure_terminates_cleanly() {
let stats = Arc::new(Stats::new());
let (client_peer, relay_client) = duplex(4096);
let (relay_server, server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader, client_writer, server_reader, server_writer, 1024, 1024, "half-close", stats, None, Arc::new(BufferPool::new()),
));
// Half closure: drop the client completely but leave the server active.
drop(client_peer);
// Check that we don't immediately crash. Bidirectional relay stays open for the server -> client flush.
// Eventually dropping the server cleanly closes the task.
drop(server_peer);
timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap();
}
#[tokio::test]
async fn relay_bidirectional_zero_length_noise_fuzzing() {
let stats = Arc::new(Stats::new());
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader, client_writer, server_reader, server_writer, 1024, 1024, "fuzz", stats, None, Arc::new(BufferPool::new()),
));
// Flood with zero-length payloads (an edge case where stream framing logic can loop)
for _ in 0..100 {
client_peer.write_all(&[]).await.unwrap();
}
client_peer.write_all(&[1, 2, 3]).await.unwrap();
client_peer.flush().await.unwrap();
let mut buf = [0u8; 3];
server_peer.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, &[1, 2, 3]);
drop(client_peer);
drop(server_peer);
timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap();
}
#[tokio::test]
async fn relay_bidirectional_asymmetric_backpressure() {
let stats = Arc::new(Stats::new());
// Give the client stream an extremely narrow throughput limit explicitly
let (client_peer, relay_client) = duplex(1024);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_task = tokio::spawn(relay_bidirectional(
client_reader, client_writer, server_reader, server_writer, 1024, 1024, "slowloris", stats, None, Arc::new(BufferPool::new()),
));
let payload = vec![0xba; 65536]; // 64k payload
// Server attempts to shove 64KB into a relay whose client pipe only holds 1KB!
let write_res = tokio::time::timeout(Duration::from_millis(50), server_peer.write_all(&payload)).await;
assert!(
write_res.is_err(),
"Relay backpressure MUST halt the server writer from unbounded buffering when client stream is full!"
);
drop(client_peer);
drop(server_peer);
let completion = timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap();
// Either Ok or a BrokenPipe-style Err is acceptable; reaching this point
// at all shows the task unwound reliably despite active backpressure locks.
let _ = completion;
}
use rand::{Rng, SeedableRng, rngs::StdRng};
#[tokio::test]
async fn relay_bidirectional_light_fuzzing_temporal_jitter() {
tokio::time::pause();
let stats = Arc::new(Stats::new());
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let mut relay_task = tokio::spawn(relay_bidirectional(
client_reader, client_writer, server_reader, server_writer, 1024, 1024, "fuzz-user", stats, None, Arc::new(BufferPool::new()),
));
let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
for _ in 0..10 {
// Vary timing significantly up to 1600 seconds (limit is 1800s)
let jitter = rng.random_range(100..1600);
tokio::time::sleep(Duration::from_secs(jitter)).await;
client_peer.write_all(&[0x11]).await.unwrap();
client_peer.flush().await.unwrap();
// Ensure task has not died
let res = timeout(Duration::from_millis(10), &mut relay_task).await;
assert!(res.is_err(), "Relay must remain open indefinitely under light temporal fuzzing with active jitter pulses");
}
drop(client_peer);
drop(server_peer);
timeout(Duration::from_secs(1), relay_task).await.unwrap().unwrap().unwrap();
}
struct FaultyReader {
error_once: Option<io::Error>,
}
struct TwoPartyGate {
arrivals: AtomicUsize,
total_bytes: AtomicUsize,
wakers: Mutex<Vec<Waker>>,
}
impl TwoPartyGate {
fn new() -> Self {
Self {
arrivals: AtomicUsize::new(0),
total_bytes: AtomicUsize::new(0),
wakers: Mutex::new(Vec::new()),
}
}
fn arrive_or_park(&self, cx: &mut Context<'_>) -> bool {
if self.arrivals.load(Ordering::Relaxed) >= 2 {
return true;
}
let prev = self.arrivals.fetch_add(1, Ordering::AcqRel);
if prev + 1 >= 2 {
let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner());
for waker in wakers.drain(..) {
waker.wake();
}
true
} else {
let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner());
wakers.push(cx.waker().clone());
false
}
}
fn total_bytes(&self) -> usize {
self.total_bytes.load(Ordering::Relaxed)
}
}
struct GateWriter {
gate: Arc<TwoPartyGate>,
entered: bool,
}
impl GateWriter {
fn new(gate: Arc<TwoPartyGate>) -> Self {
Self {
gate,
entered: false,
}
}
}
impl AsyncWrite for GateWriter {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
        // Count this writer's arrival exactly once, even across repeated polls.
        let first_arrival = !self.entered;
        self.entered = true;
        if !self.gate.arrive_or_park(first_arrival, cx) {
            return Poll::Pending;
        }
self.gate
.total_bytes
.fetch_add(buf.len(), Ordering::Relaxed);
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
struct GateReader {
gate: Arc<TwoPartyGate>,
entered: bool,
emitted: bool,
}
impl GateReader {
fn new(gate: Arc<TwoPartyGate>) -> Self {
Self {
gate,
entered: false,
emitted: false,
}
}
}
impl AsyncRead for GateReader {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
        if self.emitted {
            // Ready(Ok) with no bytes written signals EOF after the single byte.
            return Poll::Ready(Ok(()));
        }
        // Count this reader's arrival exactly once, even across repeated polls.
        let first_arrival = !self.entered;
        self.entered = true;
        if !self.gate.arrive_or_park(first_arrival, cx) {
            return Poll::Pending;
        }
buf.put_slice(&[0x42]);
self.gate.total_bytes.fetch_add(1, Ordering::Relaxed);
self.emitted = true;
Poll::Ready(Ok(()))
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() {
let stats = Arc::new(Stats::new());
let gate = Arc::new(TwoPartyGate::new());
let user = "concurrent-quota-write".to_string();
let writer_a = super::StatsIo::new(
GateWriter::new(Arc::clone(&gate)),
Arc::new(super::SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(1),
Arc::new(std::sync::atomic::AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let writer_b = super::StatsIo::new(
GateWriter::new(Arc::clone(&gate)),
Arc::new(super::SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(1),
Arc::new(std::sync::atomic::AtomicBool::new(false)),
tokio::time::Instant::now(),
);
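    // Each StatsIo wrapper carries a one-byte quota for the same user and the
    // same Stats sink; the gate forces their first polls to overlap so quota
    // enforcement is exercised under a genuine race.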
let task_a = tokio::spawn(async move {
let mut w = writer_a;
AsyncWriteExt::write_all(&mut w, &[0x01]).await
});
let task_b = tokio::spawn(async move {
let mut w = writer_b;
AsyncWriteExt::write_all(&mut w, &[0x02]).await
});
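    // Either write may legitimately fail if the quota trips first; only the
    // task joins themselves must succeed.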
let (res_a, res_b) = tokio::join!(task_a, task_b);
let _ = res_a.expect("task a must join");
let _ = res_b.expect("task b must join");
assert!(
gate.total_bytes() <= 1,
"concurrent same-user writes must not forward more than one byte under quota=1"
);
assert!(
stats.get_user_total_octets(&user) <= 1,
"concurrent same-user writes must not account over limit"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() {
let stats = Arc::new(Stats::new());
let gate = Arc::new(TwoPartyGate::new());
let user = "concurrent-quota-read".to_string();
let reader_a = super::StatsIo::new(
GateReader::new(Arc::clone(&gate)),
Arc::new(super::SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(1),
Arc::new(std::sync::atomic::AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let reader_b = super::StatsIo::new(
GateReader::new(Arc::clone(&gate)),
Arc::new(super::SharedCounters::new()),
Arc::clone(&stats),
user.clone(),
Some(1),
Arc::new(std::sync::atomic::AtomicBool::new(false)),
tokio::time::Instant::now(),
);
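    // Mirror of the write race above: both reads arm, park on the gate, and
    // then contend for a single byte of read quota at the same instant.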
let task_a = tokio::spawn(async move {
let mut r = reader_a;
let mut one = [0u8; 1];
AsyncReadExt::read_exact(&mut r, &mut one).await
});
let task_b = tokio::spawn(async move {
let mut r = reader_b;
let mut one = [0u8; 1];
AsyncReadExt::read_exact(&mut r, &mut one).await
});
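    // read_exact may fail on whichever reader loses the quota race (an
    // UnexpectedEof or a quota error); only the task joins must succeed.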
let (res_a, res_b) = tokio::join!(task_a, task_b);
let _ = res_a.expect("task a must join");
let _ = res_b.expect("task b must join");
assert!(
gate.total_bytes() <= 1,
"concurrent same-user reads must not consume more than one byte under quota=1"
);
assert!(
stats.get_user_total_octets(&user) <= 1,
"concurrent same-user reads must not account over limit"
);
}
#[tokio::test]
async fn stress_same_user_quota_parallel_relays_never_exceed_limit() {
let stats = Arc::new(Stats::new());
let user = "parallel-quota-user";
for _ in 0..128 {
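        // Two full relays for the same user run side by side, each fed from
        // both ends; the shared Stats accounting must never exceed the
        // one-byte quota, no matter how the four writes interleave.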
let (mut client_peer_a, relay_client_a) = duplex(256);
let (relay_server_a, mut server_peer_a) = duplex(256);
let (mut client_peer_b, relay_client_b) = duplex(256);
let (relay_server_b, mut server_peer_b) = duplex(256);
let (client_reader_a, client_writer_a) = tokio::io::split(relay_client_a);
let (server_reader_a, server_writer_a) = tokio::io::split(relay_server_a);
let (client_reader_b, client_writer_b) = tokio::io::split(relay_client_b);
let (server_reader_b, server_writer_b) = tokio::io::split(relay_server_b);
let relay_a = tokio::spawn(relay_bidirectional(
client_reader_a,
client_writer_a,
server_reader_a,
server_writer_a,
64,
64,
user,
Arc::clone(&stats),
Some(1),
Arc::new(BufferPool::new()),
));
let relay_b = tokio::spawn(relay_bidirectional(
client_reader_b,
client_writer_b,
server_reader_b,
server_writer_b,
64,
64,
user,
Arc::clone(&stats),
Some(1),
Arc::new(BufferPool::new()),
));
let _ = tokio::join!(
client_peer_a.write_all(&[0x01]),
server_peer_a.write_all(&[0x02]),
client_peer_b.write_all(&[0x03]),
server_peer_b.write_all(&[0x04]),
);
        // Nudge the relay by polling one client-side read; the poll_fn resolves
        // immediately regardless of readiness, so no timeout wrapper is needed.
        poll_fn(|cx| {
            let mut one = [0u8; 1];
            let _ = Pin::new(&mut client_peer_a).poll_read(cx, &mut ReadBuf::new(&mut one));
            Poll::Ready(())
        })
        .await;
drop(client_peer_a);
drop(server_peer_a);
drop(client_peer_b);
drop(server_peer_b);
let _ = timeout(Duration::from_secs(1), relay_a).await;
let _ = timeout(Duration::from_secs(1), relay_b).await;
assert!(
stats.get_user_total_octets(user) <= 1,
"parallel relays must not exceed configured quota"
);
}
}
impl FaultyReader {
fn permission_denied_with_message(message: impl Into<String>) -> Self {
Self {
error_once: Some(io::Error::new(io::ErrorKind::PermissionDenied, message.into())),
}
}
}
impl AsyncRead for FaultyReader {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if let Some(err) = self.error_once.take() {
return Poll::Ready(Err(err));
}
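        // After the single injected failure, signal EOF (Ready with no bytes).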
Poll::Ready(Ok(()))
}
}
#[tokio::test]
async fn relay_bidirectional_does_not_misclassify_transport_permission_denied_as_quota() {
let stats = Arc::new(Stats::new());
let (client_peer, relay_client) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
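    // The server-side reader fails with ErrorKind::PermissionDenied whose
    // message happens to be the legacy quota string; classification must not
    // key off the message text.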
let relay_result = relay_bidirectional(
client_reader,
client_writer,
FaultyReader::permission_denied_with_message("user data quota exceeded"),
tokio::io::sink(),
1024,
1024,
"non-quota-permission-denied",
Arc::clone(&stats),
None,
Arc::new(BufferPool::new()),
)
.await;
drop(client_peer);
assert!(
matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied),
"non-quota transport PermissionDenied errors must remain IO errors"
);
}
#[tokio::test]
async fn relay_bidirectional_light_fuzz_permission_denied_messages_remain_io_errors() {
let mut rng = StdRng::seed_from_u64(0xA11CE0B5);
for i in 0..128u64 {
let stats = Arc::new(Stats::new());
let (client_peer, relay_client) = duplex(1024);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let random_len = rng.random_range(1..=48);
let mut msg = String::with_capacity(random_len);
for _ in 0..random_len {
let ch = (b'a' + (rng.random::<u8>() % 26)) as char;
msg.push(ch);
}
// Include the legacy quota string in a subset of fuzz cases to validate
// collision resistance against message-based classification.
if i % 7 == 0 {
msg = "user data quota exceeded".to_string();
}
let relay_result = relay_bidirectional(
client_reader,
client_writer,
FaultyReader::permission_denied_with_message(msg),
tokio::io::sink(),
1024,
1024,
"fuzz-perm-denied",
Arc::clone(&stats),
None,
Arc::new(BufferPool::new()),
)
.await;
drop(client_peer);
assert!(
matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied),
"transport PermissionDenied case must stay typed as IO regardless of message content"
);
}
}