diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 142495c..dbd50d5 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -11,6 +11,7 @@ use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use std::time::{Duration, Instant}; use dashmap::DashMap; +use dashmap::mapref::entry::Entry; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, warn, trace}; use zeroize::Zeroize; @@ -118,20 +119,29 @@ fn auth_probe_record_failure_with_state( peer_ip: IpAddr, now: Instant, ) { - if let Some(mut entry) = state.get_mut(&peer_ip) { - if auth_probe_state_expired(&entry, now) { - *entry = AuthProbeState { - fail_streak: 1, - blocked_until: now + auth_probe_backoff(1), - last_seen: now, - }; + let make_new_state = || AuthProbeState { + fail_streak: 1, + blocked_until: now + auth_probe_backoff(1), + last_seen: now, + }; + + let update_existing = |entry: &mut AuthProbeState| { + if auth_probe_state_expired(entry, now) { + *entry = make_new_state(); + } else { + entry.fail_streak = entry.fail_streak.saturating_add(1); + entry.last_seen = now; + entry.blocked_until = now + auth_probe_backoff(entry.fail_streak); + } + }; + + match state.entry(peer_ip) { + Entry::Occupied(mut entry) => { + update_existing(entry.get_mut()); return; } - entry.fail_streak = entry.fail_streak.saturating_add(1); - entry.last_seen = now; - entry.blocked_until = now + auth_probe_backoff(entry.fail_streak); - return; - }; + Entry::Vacant(_) => {} + } if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { let mut stale_keys = Vec::new(); @@ -155,11 +165,14 @@ fn auth_probe_record_failure_with_state( } } - state.insert(peer_ip, AuthProbeState { - fail_streak: 1, - blocked_until: now + auth_probe_backoff(1), - last_seen: now, - }); + match state.entry(peer_ip) { + Entry::Occupied(mut entry) => { + update_existing(entry.get_mut()); + } + Entry::Vacant(entry) => { + entry.insert(make_new_state()); + } + } } fn auth_probe_record_success(peer_ip: IpAddr) { diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs index d8d8d3b..6bdc345 100644 --- a/src/proxy/handshake_security_tests.rs +++ b/src/proxy/handshake_security_tests.rs @@ -4,6 +4,7 @@ use dashmap::DashMap; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; use std::time::{Duration, Instant}; +use tokio::sync::Barrier; fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { let session_id_len: usize = 32; @@ -994,3 +995,116 @@ fn auth_probe_eviction_offset_varies_with_input() { assert_eq!(a, b, "same input must yield deterministic offset"); assert_ne!(a, c, "different peer IPs should not collapse to one offset"); } + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn auth_probe_concurrent_failures_do_not_lose_fail_streak_updates() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let peer_ip: IpAddr = "198.51.100.90".parse().unwrap(); + let tasks = 128usize; + let barrier = Arc::new(Barrier::new(tasks)); + let mut handles = Vec::with_capacity(tasks); + + for _ in 0..tasks { + let barrier = barrier.clone(); + handles.push(tokio::spawn(async move { + barrier.wait().await; + auth_probe_record_failure(peer_ip, Instant::now()); + })); + } + + for handle in handles { + handle + .await + .expect("concurrent failure recording task must not panic"); + } + + let streak = auth_probe_fail_streak_for_testing(peer_ip) + .expect("tracked peer must exist after concurrent failure burst"); + assert_eq!( + streak as usize, + tasks, + "concurrent failures for one source must account every attempt" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn invalid_probe_noise_from_other_ips_does_not_break_valid_tls_handshake() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x31u8; 16]; + let config = Arc::new(test_config_with_secret_hex("31313131313131313131313131313131")); + let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + let victim_peer: SocketAddr = "198.51.100.91:44391".parse().unwrap(); + let valid = Arc::new(make_valid_tls_handshake(&secret, 0)); + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + let invalid = Arc::new(invalid); + + let mut noise_tasks = Vec::new(); + for idx in 0..96u16 { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let invalid = invalid.clone(); + noise_tasks.push(tokio::spawn(async move { + let octet = ((idx % 200) + 1) as u8; + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, octet)), 45000 + idx); + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + })); + } + + let victim_config = config.clone(); + let victim_replay_checker = replay_checker.clone(); + let victim_rng = rng.clone(); + let victim_valid = valid.clone(); + let victim_task = tokio::spawn(async move { + handle_tls_handshake( + &victim_valid, + tokio::io::empty(), + tokio::io::sink(), + victim_peer, + &victim_config, + &victim_replay_checker, + &victim_rng, + None, + ) + .await + }); + + for task in noise_tasks { + task.await.expect("noise task must not panic"); + } + + let victim_result = victim_task + .await + .expect("victim handshake task must not panic"); + assert!( + matches!(victim_result, HandshakeResult::Success(_)), + "invalid probe noise from other IPs must not block a valid victim handshake" + ); + assert_eq!( + auth_probe_fail_streak_for_testing(victim_peer.ip()), + None, + "successful victim handshake must not retain pre-auth failure streak" + ); +} diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 091094d..1acbdc1 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -7,7 +7,6 @@ use std::time::{Duration, Instant}; #[cfg(test)] use std::sync::Mutex; -use bytes::Bytes; use dashmap::DashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::{mpsc, oneshot, watch}; @@ -24,11 +23,11 @@ use crate::proxy::route_mode::{ cutover_stagger_delay, }; use crate::stats::Stats; -use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; enum C2MeCommand { - Data { payload: Bytes, flags: u32 }, + Data { payload: PooledBuffer, flags: u32 }, Close, } @@ -686,7 +685,7 @@ async fn read_client_payload( forensics: &RelayForensicsState, frame_counter: &mut u64, stats: &Stats, -) -> Result> +) -> Result> where R: AsyncRead + Unpin + Send + 'static, { @@ -807,8 +806,7 @@ where payload.truncate(secure_payload_len); } *frame_counter += 1; - let payload_bytes = Bytes::copy_from_slice(&payload[..]); - return Ok(Some((payload_bytes, quickack))); + return Ok(Some((payload, quickack))); } } diff --git a/src/proxy/middle_relay_security_tests.rs b/src/proxy/middle_relay_security_tests.rs index 511a853..a2a6c3e 100644 --- a/src/proxy/middle_relay_security_tests.rs +++ b/src/proxy/middle_relay_security_tests.rs @@ -1,7 +1,9 @@ use super::*; +use bytes::Bytes; use crate::crypto::AesCtr; +use crate::crypto::SecureRandom; use crate::stats::Stats; -use crate::stream::{BufferPool, CryptoReader}; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::AtomicU64; @@ -9,6 +11,21 @@ use tokio::io::AsyncWriteExt; use tokio::io::duplex; use tokio::time::{Duration as TokioDuration, timeout}; +fn make_pooled_payload(data: &[u8]) -> PooledBuffer { + let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); + let mut payload = pool.get(); + payload.resize(data.len(), 0); + payload[..data.len()].copy_from_slice(data); + payload +} + +fn make_pooled_payload_from(pool: &Arc, data: &[u8]) -> PooledBuffer { + let mut payload = pool.get(); + payload.resize(data.len(), 0); + payload[..data.len()].copy_from_slice(data); + payload +} + #[test] fn should_yield_sender_only_on_budget_with_backlog() { assert!(!should_yield_c2me_sender(0, true)); @@ -23,7 +40,7 @@ async fn enqueue_c2me_command_uses_try_send_fast_path() { enqueue_c2me_command( &tx, C2MeCommand::Data { - payload: Bytes::from_static(&[1, 2, 3]), + payload: make_pooled_payload(&[1, 2, 3]), flags: 0, }, ) @@ -47,7 +64,7 @@ async fn enqueue_c2me_command_uses_try_send_fast_path() { async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { let (tx, mut rx) = mpsc::channel::(1); tx.send(C2MeCommand::Data { - payload: Bytes::from_static(&[9]), + payload: make_pooled_payload(&[9]), flags: 9, }) .await @@ -58,7 +75,7 @@ async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { enqueue_c2me_command( &tx2, C2MeCommand::Data { - payload: Bytes::from_static(&[7, 7]), + payload: make_pooled_payload(&[7, 7]), flags: 7, }, ) @@ -84,6 +101,74 @@ async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { } } +#[tokio::test] +async fn enqueue_c2me_command_closed_channel_recycles_payload() { + let pool = Arc::new(BufferPool::with_config(64, 4)); + let payload = make_pooled_payload_from(&pool, &[1, 2, 3, 4]); + let (tx, rx) = mpsc::channel::(1); + drop(rx); + + let result = enqueue_c2me_command( + &tx, + C2MeCommand::Data { + payload, + flags: 0, + }, + ) + .await; + + assert!(result.is_err(), "closed queue must fail enqueue"); + drop(result); + assert!( + pool.stats().pooled >= 1, + "payload must return to pool when enqueue fails on closed channel" + ); +} + +#[tokio::test] +async fn enqueue_c2me_command_full_then_closed_recycles_waiting_payload() { + let pool = Arc::new(BufferPool::with_config(64, 4)); + let (tx, rx) = mpsc::channel::(1); + + tx.send(C2MeCommand::Data { + payload: make_pooled_payload_from(&pool, &[9]), + flags: 1, + }) + .await + .unwrap(); + + let tx2 = tx.clone(); + let pool2 = pool.clone(); + let blocked_send = tokio::spawn(async move { + enqueue_c2me_command( + &tx2, + C2MeCommand::Data { + payload: make_pooled_payload_from(&pool2, &[7, 7, 7]), + flags: 2, + }, + ) + .await + }); + + tokio::time::sleep(TokioDuration::from_millis(10)).await; + drop(rx); + + let result = timeout(TokioDuration::from_secs(1), blocked_send) + .await + .expect("blocked send task must finish") + .expect("blocked send task must not panic"); + + assert!( + result.is_err(), + "closing receiver while sender is blocked must fail enqueue" + ); + drop(result); + assert!( + pool.stats().pooled >= 2, + "both queued and blocked payloads must return to pool after channel close" + ); +} + #[test] fn desync_dedup_cache_is_bounded() { let _guard = desync_dedup_test_lock() @@ -150,6 +235,12 @@ fn make_crypto_reader(reader: tokio::io::DuplexStream) -> CryptoReader CryptoWriter { + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + fn encrypt_for_reader(plaintext: &[u8]) -> Vec { let key = [0u8; 32]; let iv = 0u128; @@ -476,3 +567,215 @@ async fn read_client_payload_returns_buffer_to_pool_after_emit() { "emitted payload buffer must be returned to pool to avoid pool drain" ); } + +#[tokio::test] +async fn read_client_payload_keeps_pool_buffer_checked_out_until_frame_drop() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let pool = Arc::new(BufferPool::with_config(64, 2)); + pool.preallocate(1); + assert_eq!(pool.stats().pooled, 1, "one pooled buffer must be available"); + + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let payload = [0x41u8, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48]; + let mut plaintext = Vec::with_capacity(4 + payload.len()); + plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes()); + plaintext.extend_from_slice(&payload); + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let (frame, quickack) = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + TokioDuration::from_secs(1), + &pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("payload read must succeed") + .expect("frame must be present"); + + assert!(!quickack); + assert_eq!(frame.as_ref(), &payload); + assert_eq!( + pool.stats().pooled, + 0, + "buffer must stay checked out while frame payload is alive" + ); + + drop(frame); + assert!( + pool.stats().pooled >= 1, + "buffer must return to pool only after frame drop" + ); +} + +#[tokio::test] +async fn enqueue_c2me_close_unblocks_after_queue_drain() { + let (tx, mut rx) = mpsc::channel::(1); + tx.send(C2MeCommand::Data { + payload: make_pooled_payload(&[0x41]), + flags: 0, + }) + .await + .unwrap(); + + let tx2 = tx.clone(); + let close_task = tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); + + tokio::time::sleep(TokioDuration::from_millis(10)).await; + + let first = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .unwrap() + .expect("first queued item must be present"); + assert!(matches!(first, C2MeCommand::Data { .. })); + + close_task.await.unwrap().expect("close enqueue must succeed after drain"); + + let second = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .unwrap() + .expect("close command must follow after queue drain"); + assert!(matches!(second, C2MeCommand::Close)); +} + +#[tokio::test] +async fn enqueue_c2me_close_full_then_receiver_drop_fails_cleanly() { + let (tx, rx) = mpsc::channel::(1); + tx.send(C2MeCommand::Data { + payload: make_pooled_payload(&[0x42]), + flags: 0, + }) + .await + .unwrap(); + + let tx2 = tx.clone(); + let close_task = tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); + + tokio::time::sleep(TokioDuration::from_millis(10)).await; + drop(rx); + + let result = timeout(TokioDuration::from_secs(1), close_task) + .await + .expect("close task must finish") + .expect("close task must not panic"); + assert!( + result.is_err(), + "close enqueue must fail cleanly when receiver is dropped under pressure" + ); +} + +#[tokio::test] +async fn process_me_writer_response_ack_obeys_flush_policy() { + let (writer_side, _reader_side) = duplex(1024); + let mut writer = make_crypto_writer(writer_side); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + let immediate = process_me_writer_response( + MeResponse::Ack(0x11223344), + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "user", + &bytes_me2c, + 77, + true, + false, + ) + .await + .expect("ack response must be processed"); + + assert!(matches!( + immediate, + MeWriterResponseOutcome::Continue { + frames: 1, + bytes: 4, + flush_immediately: true, + } + )); + + let delayed = process_me_writer_response( + MeResponse::Ack(0x55667788), + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "user", + &bytes_me2c, + 77, + false, + false, + ) + .await + .expect("ack response must be processed"); + + assert!(matches!( + delayed, + MeWriterResponseOutcome::Continue { + frames: 1, + bytes: 4, + flush_immediately: false, + } + )); +} + +#[tokio::test] +async fn process_me_writer_response_data_updates_byte_accounting() { + let (writer_side, _reader_side) = duplex(1024); + let mut writer = make_crypto_writer(writer_side); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + let payload = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9]; + let outcome = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload.clone()), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "user", + &bytes_me2c, + 88, + false, + false, + ) + .await + .expect("data response must be processed"); + + assert!(matches!( + outcome, + MeWriterResponseOutcome::Continue { + frames: 1, + bytes, + flush_immediately: false, + } if bytes == payload.len() + )); + assert_eq!( + bytes_me2c.load(std::sync::atomic::Ordering::Relaxed), + payload.len() as u64, + "ME->C byte accounting must increase by emitted payload size" + ); +}