mirror of https://github.com/telemt/telemt.git
feat(proxy): refactor auth probe failure handling and add concurrent failure tests
This commit is contained in:
parent
0c6bb3a641
commit
93caab1aec
|
|
@ -11,6 +11,7 @@ use std::collections::hash_map::DefaultHasher;
|
||||||
use std::hash::{Hash, Hasher};
|
use std::hash::{Hash, Hasher};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
|
use dashmap::mapref::entry::Entry;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||||
use tracing::{debug, warn, trace};
|
use tracing::{debug, warn, trace};
|
||||||
use zeroize::Zeroize;
|
use zeroize::Zeroize;
|
||||||
|
|
@ -118,20 +119,29 @@ fn auth_probe_record_failure_with_state(
|
||||||
peer_ip: IpAddr,
|
peer_ip: IpAddr,
|
||||||
now: Instant,
|
now: Instant,
|
||||||
) {
|
) {
|
||||||
if let Some(mut entry) = state.get_mut(&peer_ip) {
|
let make_new_state = || AuthProbeState {
|
||||||
if auth_probe_state_expired(&entry, now) {
|
fail_streak: 1,
|
||||||
*entry = AuthProbeState {
|
blocked_until: now + auth_probe_backoff(1),
|
||||||
fail_streak: 1,
|
last_seen: now,
|
||||||
blocked_until: now + auth_probe_backoff(1),
|
};
|
||||||
last_seen: now,
|
|
||||||
};
|
let update_existing = |entry: &mut AuthProbeState| {
|
||||||
|
if auth_probe_state_expired(entry, now) {
|
||||||
|
*entry = make_new_state();
|
||||||
|
} else {
|
||||||
|
entry.fail_streak = entry.fail_streak.saturating_add(1);
|
||||||
|
entry.last_seen = now;
|
||||||
|
entry.blocked_until = now + auth_probe_backoff(entry.fail_streak);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match state.entry(peer_ip) {
|
||||||
|
Entry::Occupied(mut entry) => {
|
||||||
|
update_existing(entry.get_mut());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
entry.fail_streak = entry.fail_streak.saturating_add(1);
|
Entry::Vacant(_) => {}
|
||||||
entry.last_seen = now;
|
}
|
||||||
entry.blocked_until = now + auth_probe_backoff(entry.fail_streak);
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
|
|
||||||
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||||
let mut stale_keys = Vec::new();
|
let mut stale_keys = Vec::new();
|
||||||
|
|
@ -155,11 +165,14 @@ fn auth_probe_record_failure_with_state(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
state.insert(peer_ip, AuthProbeState {
|
match state.entry(peer_ip) {
|
||||||
fail_streak: 1,
|
Entry::Occupied(mut entry) => {
|
||||||
blocked_until: now + auth_probe_backoff(1),
|
update_existing(entry.get_mut());
|
||||||
last_seen: now,
|
}
|
||||||
});
|
Entry::Vacant(entry) => {
|
||||||
|
entry.insert(make_new_state());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn auth_probe_record_success(peer_ip: IpAddr) {
|
fn auth_probe_record_success(peer_ip: IpAddr) {
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ use dashmap::DashMap;
|
||||||
use std::net::{IpAddr, Ipv4Addr};
|
use std::net::{IpAddr, Ipv4Addr};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
use tokio::sync::Barrier;
|
||||||
|
|
||||||
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
|
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
|
||||||
let session_id_len: usize = 32;
|
let session_id_len: usize = 32;
|
||||||
|
|
@ -994,3 +995,116 @@ fn auth_probe_eviction_offset_varies_with_input() {
|
||||||
assert_eq!(a, b, "same input must yield deterministic offset");
|
assert_eq!(a, b, "same input must yield deterministic offset");
|
||||||
assert_ne!(a, c, "different peer IPs should not collapse to one offset");
|
assert_ne!(a, c, "different peer IPs should not collapse to one offset");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||||
|
async fn auth_probe_concurrent_failures_do_not_lose_fail_streak_updates() {
|
||||||
|
let _guard = auth_probe_test_lock()
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||||
|
clear_auth_probe_state_for_testing();
|
||||||
|
|
||||||
|
let peer_ip: IpAddr = "198.51.100.90".parse().unwrap();
|
||||||
|
let tasks = 128usize;
|
||||||
|
let barrier = Arc::new(Barrier::new(tasks));
|
||||||
|
let mut handles = Vec::with_capacity(tasks);
|
||||||
|
|
||||||
|
for _ in 0..tasks {
|
||||||
|
let barrier = barrier.clone();
|
||||||
|
handles.push(tokio::spawn(async move {
|
||||||
|
barrier.wait().await;
|
||||||
|
auth_probe_record_failure(peer_ip, Instant::now());
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for handle in handles {
|
||||||
|
handle
|
||||||
|
.await
|
||||||
|
.expect("concurrent failure recording task must not panic");
|
||||||
|
}
|
||||||
|
|
||||||
|
let streak = auth_probe_fail_streak_for_testing(peer_ip)
|
||||||
|
.expect("tracked peer must exist after concurrent failure burst");
|
||||||
|
assert_eq!(
|
||||||
|
streak as usize,
|
||||||
|
tasks,
|
||||||
|
"concurrent failures for one source must account every attempt"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||||
|
async fn invalid_probe_noise_from_other_ips_does_not_break_valid_tls_handshake() {
|
||||||
|
let _guard = auth_probe_test_lock()
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||||
|
clear_auth_probe_state_for_testing();
|
||||||
|
|
||||||
|
let secret = [0x31u8; 16];
|
||||||
|
let config = Arc::new(test_config_with_secret_hex("31313131313131313131313131313131"));
|
||||||
|
let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60)));
|
||||||
|
let rng = Arc::new(SecureRandom::new());
|
||||||
|
let victim_peer: SocketAddr = "198.51.100.91:44391".parse().unwrap();
|
||||||
|
let valid = Arc::new(make_valid_tls_handshake(&secret, 0));
|
||||||
|
|
||||||
|
let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32];
|
||||||
|
invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32;
|
||||||
|
let invalid = Arc::new(invalid);
|
||||||
|
|
||||||
|
let mut noise_tasks = Vec::new();
|
||||||
|
for idx in 0..96u16 {
|
||||||
|
let config = config.clone();
|
||||||
|
let replay_checker = replay_checker.clone();
|
||||||
|
let rng = rng.clone();
|
||||||
|
let invalid = invalid.clone();
|
||||||
|
noise_tasks.push(tokio::spawn(async move {
|
||||||
|
let octet = ((idx % 200) + 1) as u8;
|
||||||
|
let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, octet)), 45000 + idx);
|
||||||
|
let result = handle_tls_handshake(
|
||||||
|
&invalid,
|
||||||
|
tokio::io::empty(),
|
||||||
|
tokio::io::sink(),
|
||||||
|
peer,
|
||||||
|
&config,
|
||||||
|
&replay_checker,
|
||||||
|
&rng,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
assert!(matches!(result, HandshakeResult::BadClient { .. }));
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
let victim_config = config.clone();
|
||||||
|
let victim_replay_checker = replay_checker.clone();
|
||||||
|
let victim_rng = rng.clone();
|
||||||
|
let victim_valid = valid.clone();
|
||||||
|
let victim_task = tokio::spawn(async move {
|
||||||
|
handle_tls_handshake(
|
||||||
|
&victim_valid,
|
||||||
|
tokio::io::empty(),
|
||||||
|
tokio::io::sink(),
|
||||||
|
victim_peer,
|
||||||
|
&victim_config,
|
||||||
|
&victim_replay_checker,
|
||||||
|
&victim_rng,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
});
|
||||||
|
|
||||||
|
for task in noise_tasks {
|
||||||
|
task.await.expect("noise task must not panic");
|
||||||
|
}
|
||||||
|
|
||||||
|
let victim_result = victim_task
|
||||||
|
.await
|
||||||
|
.expect("victim handshake task must not panic");
|
||||||
|
assert!(
|
||||||
|
matches!(victim_result, HandshakeResult::Success(_)),
|
||||||
|
"invalid probe noise from other IPs must not block a valid victim handshake"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
auth_probe_fail_streak_for_testing(victim_peer.ip()),
|
||||||
|
None,
|
||||||
|
"successful victim handshake must not retain pre-auth failure streak"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ use std::time::{Duration, Instant};
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use std::sync::Mutex;
|
use std::sync::Mutex;
|
||||||
|
|
||||||
use bytes::Bytes;
|
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
use tokio::sync::{mpsc, oneshot, watch};
|
use tokio::sync::{mpsc, oneshot, watch};
|
||||||
|
|
@ -24,11 +23,11 @@ use crate::proxy::route_mode::{
|
||||||
cutover_stagger_delay,
|
cutover_stagger_delay,
|
||||||
};
|
};
|
||||||
use crate::stats::Stats;
|
use crate::stats::Stats;
|
||||||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
|
||||||
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
|
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
|
||||||
|
|
||||||
enum C2MeCommand {
|
enum C2MeCommand {
|
||||||
Data { payload: Bytes, flags: u32 },
|
Data { payload: PooledBuffer, flags: u32 },
|
||||||
Close,
|
Close,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -686,7 +685,7 @@ async fn read_client_payload<R>(
|
||||||
forensics: &RelayForensicsState,
|
forensics: &RelayForensicsState,
|
||||||
frame_counter: &mut u64,
|
frame_counter: &mut u64,
|
||||||
stats: &Stats,
|
stats: &Stats,
|
||||||
) -> Result<Option<(Bytes, bool)>>
|
) -> Result<Option<(PooledBuffer, bool)>>
|
||||||
where
|
where
|
||||||
R: AsyncRead + Unpin + Send + 'static,
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
{
|
{
|
||||||
|
|
@ -807,8 +806,7 @@ where
|
||||||
payload.truncate(secure_payload_len);
|
payload.truncate(secure_payload_len);
|
||||||
}
|
}
|
||||||
*frame_counter += 1;
|
*frame_counter += 1;
|
||||||
let payload_bytes = Bytes::copy_from_slice(&payload[..]);
|
return Ok(Some((payload, quickack)));
|
||||||
return Ok(Some((payload_bytes, quickack)));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use bytes::Bytes;
|
||||||
use crate::crypto::AesCtr;
|
use crate::crypto::AesCtr;
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
use crate::stats::Stats;
|
use crate::stats::Stats;
|
||||||
use crate::stream::{BufferPool, CryptoReader};
|
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::atomic::AtomicU64;
|
use std::sync::atomic::AtomicU64;
|
||||||
|
|
@ -9,6 +11,21 @@ use tokio::io::AsyncWriteExt;
|
||||||
use tokio::io::duplex;
|
use tokio::io::duplex;
|
||||||
use tokio::time::{Duration as TokioDuration, timeout};
|
use tokio::time::{Duration as TokioDuration, timeout};
|
||||||
|
|
||||||
|
fn make_pooled_payload(data: &[u8]) -> PooledBuffer {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4));
|
||||||
|
let mut payload = pool.get();
|
||||||
|
payload.resize(data.len(), 0);
|
||||||
|
payload[..data.len()].copy_from_slice(data);
|
||||||
|
payload
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_pooled_payload_from(pool: &Arc<BufferPool>, data: &[u8]) -> PooledBuffer {
|
||||||
|
let mut payload = pool.get();
|
||||||
|
payload.resize(data.len(), 0);
|
||||||
|
payload[..data.len()].copy_from_slice(data);
|
||||||
|
payload
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_yield_sender_only_on_budget_with_backlog() {
|
fn should_yield_sender_only_on_budget_with_backlog() {
|
||||||
assert!(!should_yield_c2me_sender(0, true));
|
assert!(!should_yield_c2me_sender(0, true));
|
||||||
|
|
@ -23,7 +40,7 @@ async fn enqueue_c2me_command_uses_try_send_fast_path() {
|
||||||
enqueue_c2me_command(
|
enqueue_c2me_command(
|
||||||
&tx,
|
&tx,
|
||||||
C2MeCommand::Data {
|
C2MeCommand::Data {
|
||||||
payload: Bytes::from_static(&[1, 2, 3]),
|
payload: make_pooled_payload(&[1, 2, 3]),
|
||||||
flags: 0,
|
flags: 0,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
@ -47,7 +64,7 @@ async fn enqueue_c2me_command_uses_try_send_fast_path() {
|
||||||
async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
|
async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
|
||||||
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
|
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
|
||||||
tx.send(C2MeCommand::Data {
|
tx.send(C2MeCommand::Data {
|
||||||
payload: Bytes::from_static(&[9]),
|
payload: make_pooled_payload(&[9]),
|
||||||
flags: 9,
|
flags: 9,
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
|
|
@ -58,7 +75,7 @@ async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
|
||||||
enqueue_c2me_command(
|
enqueue_c2me_command(
|
||||||
&tx2,
|
&tx2,
|
||||||
C2MeCommand::Data {
|
C2MeCommand::Data {
|
||||||
payload: Bytes::from_static(&[7, 7]),
|
payload: make_pooled_payload(&[7, 7]),
|
||||||
flags: 7,
|
flags: 7,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
@ -84,6 +101,74 @@ async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_c2me_command_closed_channel_recycles_payload() {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(64, 4));
|
||||||
|
let payload = make_pooled_payload_from(&pool, &[1, 2, 3, 4]);
|
||||||
|
let (tx, rx) = mpsc::channel::<C2MeCommand>(1);
|
||||||
|
drop(rx);
|
||||||
|
|
||||||
|
let result = enqueue_c2me_command(
|
||||||
|
&tx,
|
||||||
|
C2MeCommand::Data {
|
||||||
|
payload,
|
||||||
|
flags: 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(result.is_err(), "closed queue must fail enqueue");
|
||||||
|
drop(result);
|
||||||
|
assert!(
|
||||||
|
pool.stats().pooled >= 1,
|
||||||
|
"payload must return to pool when enqueue fails on closed channel"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_c2me_command_full_then_closed_recycles_waiting_payload() {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(64, 4));
|
||||||
|
let (tx, rx) = mpsc::channel::<C2MeCommand>(1);
|
||||||
|
|
||||||
|
tx.send(C2MeCommand::Data {
|
||||||
|
payload: make_pooled_payload_from(&pool, &[9]),
|
||||||
|
flags: 1,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let tx2 = tx.clone();
|
||||||
|
let pool2 = pool.clone();
|
||||||
|
let blocked_send = tokio::spawn(async move {
|
||||||
|
enqueue_c2me_command(
|
||||||
|
&tx2,
|
||||||
|
C2MeCommand::Data {
|
||||||
|
payload: make_pooled_payload_from(&pool2, &[7, 7, 7]),
|
||||||
|
flags: 2,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::time::sleep(TokioDuration::from_millis(10)).await;
|
||||||
|
drop(rx);
|
||||||
|
|
||||||
|
let result = timeout(TokioDuration::from_secs(1), blocked_send)
|
||||||
|
.await
|
||||||
|
.expect("blocked send task must finish")
|
||||||
|
.expect("blocked send task must not panic");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
result.is_err(),
|
||||||
|
"closing receiver while sender is blocked must fail enqueue"
|
||||||
|
);
|
||||||
|
drop(result);
|
||||||
|
assert!(
|
||||||
|
pool.stats().pooled >= 2,
|
||||||
|
"both queued and blocked payloads must return to pool after channel close"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn desync_dedup_cache_is_bounded() {
|
fn desync_dedup_cache_is_bounded() {
|
||||||
let _guard = desync_dedup_test_lock()
|
let _guard = desync_dedup_test_lock()
|
||||||
|
|
@ -150,6 +235,12 @@ fn make_crypto_reader(reader: tokio::io::DuplexStream) -> CryptoReader<tokio::io
|
||||||
CryptoReader::new(reader, AesCtr::new(&key, iv))
|
CryptoReader::new(reader, AesCtr::new(&key, iv))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn make_crypto_writer(writer: tokio::io::DuplexStream) -> CryptoWriter<tokio::io::DuplexStream> {
|
||||||
|
let key = [0u8; 32];
|
||||||
|
let iv = 0u128;
|
||||||
|
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||||
|
}
|
||||||
|
|
||||||
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
|
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
|
||||||
let key = [0u8; 32];
|
let key = [0u8; 32];
|
||||||
let iv = 0u128;
|
let iv = 0u128;
|
||||||
|
|
@ -476,3 +567,215 @@ async fn read_client_payload_returns_buffer_to_pool_after_emit() {
|
||||||
"emitted payload buffer must be returned to pool to avoid pool drain"
|
"emitted payload buffer must be returned to pool to avoid pool drain"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_keeps_pool_buffer_checked_out_until_frame_drop() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
|
||||||
|
let pool = Arc::new(BufferPool::with_config(64, 2));
|
||||||
|
pool.preallocate(1);
|
||||||
|
assert_eq!(pool.stats().pooled, 1, "one pooled buffer must be available");
|
||||||
|
|
||||||
|
let (reader, mut writer) = duplex(1024);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
let payload = [0x41u8, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48];
|
||||||
|
let mut plaintext = Vec::with_capacity(4 + payload.len());
|
||||||
|
plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes());
|
||||||
|
plaintext.extend_from_slice(&payload);
|
||||||
|
let encrypted = encrypt_for_reader(&plaintext);
|
||||||
|
writer.write_all(&encrypted).await.unwrap();
|
||||||
|
|
||||||
|
let (frame, quickack) = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
1024,
|
||||||
|
TokioDuration::from_secs(1),
|
||||||
|
&pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("payload read must succeed")
|
||||||
|
.expect("frame must be present");
|
||||||
|
|
||||||
|
assert!(!quickack);
|
||||||
|
assert_eq!(frame.as_ref(), &payload);
|
||||||
|
assert_eq!(
|
||||||
|
pool.stats().pooled,
|
||||||
|
0,
|
||||||
|
"buffer must stay checked out while frame payload is alive"
|
||||||
|
);
|
||||||
|
|
||||||
|
drop(frame);
|
||||||
|
assert!(
|
||||||
|
pool.stats().pooled >= 1,
|
||||||
|
"buffer must return to pool only after frame drop"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_c2me_close_unblocks_after_queue_drain() {
|
||||||
|
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
|
||||||
|
tx.send(C2MeCommand::Data {
|
||||||
|
payload: make_pooled_payload(&[0x41]),
|
||||||
|
flags: 0,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let tx2 = tx.clone();
|
||||||
|
let close_task = tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await });
|
||||||
|
|
||||||
|
tokio::time::sleep(TokioDuration::from_millis(10)).await;
|
||||||
|
|
||||||
|
let first = timeout(TokioDuration::from_millis(100), rx.recv())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.expect("first queued item must be present");
|
||||||
|
assert!(matches!(first, C2MeCommand::Data { .. }));
|
||||||
|
|
||||||
|
close_task.await.unwrap().expect("close enqueue must succeed after drain");
|
||||||
|
|
||||||
|
let second = timeout(TokioDuration::from_millis(100), rx.recv())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.expect("close command must follow after queue drain");
|
||||||
|
assert!(matches!(second, C2MeCommand::Close));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_c2me_close_full_then_receiver_drop_fails_cleanly() {
|
||||||
|
let (tx, rx) = mpsc::channel::<C2MeCommand>(1);
|
||||||
|
tx.send(C2MeCommand::Data {
|
||||||
|
payload: make_pooled_payload(&[0x42]),
|
||||||
|
flags: 0,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let tx2 = tx.clone();
|
||||||
|
let close_task = tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await });
|
||||||
|
|
||||||
|
tokio::time::sleep(TokioDuration::from_millis(10)).await;
|
||||||
|
drop(rx);
|
||||||
|
|
||||||
|
let result = timeout(TokioDuration::from_secs(1), close_task)
|
||||||
|
.await
|
||||||
|
.expect("close task must finish")
|
||||||
|
.expect("close task must not panic");
|
||||||
|
assert!(
|
||||||
|
result.is_err(),
|
||||||
|
"close enqueue must fail cleanly when receiver is dropped under pressure"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn process_me_writer_response_ack_obeys_flush_policy() {
|
||||||
|
let (writer_side, _reader_side) = duplex(1024);
|
||||||
|
let mut writer = make_crypto_writer(writer_side);
|
||||||
|
let rng = SecureRandom::new();
|
||||||
|
let mut frame_buf = Vec::new();
|
||||||
|
let stats = Stats::new();
|
||||||
|
let bytes_me2c = AtomicU64::new(0);
|
||||||
|
|
||||||
|
let immediate = process_me_writer_response(
|
||||||
|
MeResponse::Ack(0x11223344),
|
||||||
|
&mut writer,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
&rng,
|
||||||
|
&mut frame_buf,
|
||||||
|
&stats,
|
||||||
|
"user",
|
||||||
|
&bytes_me2c,
|
||||||
|
77,
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("ack response must be processed");
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
immediate,
|
||||||
|
MeWriterResponseOutcome::Continue {
|
||||||
|
frames: 1,
|
||||||
|
bytes: 4,
|
||||||
|
flush_immediately: true,
|
||||||
|
}
|
||||||
|
));
|
||||||
|
|
||||||
|
let delayed = process_me_writer_response(
|
||||||
|
MeResponse::Ack(0x55667788),
|
||||||
|
&mut writer,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
&rng,
|
||||||
|
&mut frame_buf,
|
||||||
|
&stats,
|
||||||
|
"user",
|
||||||
|
&bytes_me2c,
|
||||||
|
77,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("ack response must be processed");
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
delayed,
|
||||||
|
MeWriterResponseOutcome::Continue {
|
||||||
|
frames: 1,
|
||||||
|
bytes: 4,
|
||||||
|
flush_immediately: false,
|
||||||
|
}
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn process_me_writer_response_data_updates_byte_accounting() {
|
||||||
|
let (writer_side, _reader_side) = duplex(1024);
|
||||||
|
let mut writer = make_crypto_writer(writer_side);
|
||||||
|
let rng = SecureRandom::new();
|
||||||
|
let mut frame_buf = Vec::new();
|
||||||
|
let stats = Stats::new();
|
||||||
|
let bytes_me2c = AtomicU64::new(0);
|
||||||
|
|
||||||
|
let payload = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||||
|
let outcome = process_me_writer_response(
|
||||||
|
MeResponse::Data {
|
||||||
|
flags: 0,
|
||||||
|
data: Bytes::from(payload.clone()),
|
||||||
|
},
|
||||||
|
&mut writer,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
&rng,
|
||||||
|
&mut frame_buf,
|
||||||
|
&stats,
|
||||||
|
"user",
|
||||||
|
&bytes_me2c,
|
||||||
|
88,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("data response must be processed");
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
outcome,
|
||||||
|
MeWriterResponseOutcome::Continue {
|
||||||
|
frames: 1,
|
||||||
|
bytes,
|
||||||
|
flush_immediately: false,
|
||||||
|
} if bytes == payload.len()
|
||||||
|
));
|
||||||
|
assert_eq!(
|
||||||
|
bytes_me2c.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
payload.len() as u64,
|
||||||
|
"ME->C byte accounting must increase by emitted payload size"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue