diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index c82c9fe..0f54245 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -381,7 +381,7 @@ fn validate_tls_handshake_at_time_with_boot_cap( let mut msg = handshake.to_vec(); msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); - let mut first_match: Option = None; + let mut first_match: Option<(&String, u32)> = None; for (user, secret) in secrets { let computed = sha256_hmac(secret, &msg); @@ -421,16 +421,16 @@ fn validate_tls_handshake_at_time_with_boot_cap( } if first_match.is_none() { - first_match = Some(TlsValidation { - user: user.clone(), - session_id: session_id.clone(), - digest, - timestamp, - }); + first_match = Some((user, timestamp)); } } - first_match + first_match.map(|(user, timestamp)| TlsValidation { + user: user.clone(), + session_id, + digest, + timestamp, + }) } fn curve25519_prime() -> BigUint { diff --git a/src/protocol/tls_security_tests.rs b/src/protocol/tls_security_tests.rs index c25a517..98d7319 100644 --- a/src/protocol/tls_security_tests.rs +++ b/src/protocol/tls_security_tests.rs @@ -9,12 +9,19 @@ use crate::crypto::sha256_hmac; /// [TLS_DIGEST_POS..+32] : digest = HMAC XOR [0..0 || timestamp_le] /// [TLS_DIGEST_POS+32] : session_id_len = 32 /// [TLS_DIGEST_POS+33..+65] : session_id filler (0x42) -fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { - let session_id_len: usize = 32; +fn make_valid_tls_handshake_with_session_id( + secret: &[u8], + timestamp: u32, + session_id: &[u8], +) -> Vec { + let session_id_len = session_id.len(); + assert!(session_id_len <= u8::MAX as usize); let len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 + session_id_len; let mut handshake = vec![0x42u8; len]; handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = session_id_len as u8; + let sid_start = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + handshake[sid_start..sid_start + session_id_len].copy_from_slice(session_id); // Zero the digest slot before computing HMAC (mirrors what validate does). handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0); @@ -34,6 +41,10 @@ fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { handshake } +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + make_valid_tls_handshake_with_session_id(secret, timestamp, &[0x42; 32]) +} + // ------------------------------------------------------------------ // Happy-path sanity // ------------------------------------------------------------------ @@ -311,6 +322,20 @@ fn too_short_handshake_rejected_without_panic() { assert!(validate_tls_handshake(&[], &secrets, true).is_none()); } +#[test] +fn all_prefix_lengths_below_minimum_rejected_without_panic() { + let min_len = TLS_DIGEST_POS + TLS_DIGEST_LEN + 1; + let secrets = vec![("u".to_string(), b"s".to_vec())]; + + for len in 0..min_len { + let h = vec![0u8; len]; + assert!( + validate_tls_handshake(&h, &secrets, true).is_none(), + "prefix length {len} below minimum must be rejected" + ); + } +} + #[test] fn claimed_session_id_overflows_buffer_rejected() { let session_id_len: usize = 32; @@ -332,6 +357,30 @@ fn max_session_id_len_255_does_not_panic() { assert!(validate_tls_handshake(&h, &secrets, true).is_none()); } +#[test] +fn one_byte_session_id_validates_and_is_preserved() { + let secret = b"sid_len_1_test"; + let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &[0xAB]); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake(&handshake, &secrets, true) + .expect("one-byte session_id handshake must validate"); + assert_eq!(result.session_id, vec![0xAB]); +} + +#[test] +fn max_session_id_len_255_with_valid_digest_is_accepted() { + let secret = b"sid_len_255_test"; + let session_id = vec![0xCCu8; 255]; + let handshake = make_valid_tls_handshake_with_session_id(secret, 0, &session_id); + let secrets = vec![("u".to_string(), secret.to_vec())]; + + let result = validate_tls_handshake(&handshake, &secrets, true) + .expect("session_id_len=255 with valid digest must validate"); + assert_eq!(result.session_id.len(), 255); + assert_eq!(result.session_id, session_id); +} + // ------------------------------------------------------------------ // Adversarial digest values // ------------------------------------------------------------------ @@ -867,6 +916,23 @@ fn test_parse_tls_record_header() { assert_eq!(result.1, 16384); } +#[test] +fn parse_tls_record_header_rejects_invalid_versions() { + let invalid = [ + [0x16, 0x03, 0x00, 0x00, 0x10], + [0x16, 0x02, 0x00, 0x00, 0x10], + [0x16, 0x03, 0x02, 0x00, 0x10], + [0x16, 0x04, 0x00, 0x00, 0x10], + ]; + for header in invalid { + assert!( + parse_tls_record_header(&header).is_none(), + "invalid TLS record version {:?} must be rejected", + [header[1], header[2]] + ); + } +} + #[test] fn test_gen_fake_x25519_key() { let rng = crate::crypto::SecureRandom::new(); @@ -1168,6 +1234,47 @@ fn extract_sni_rejects_when_extension_block_is_truncated() { assert!(extract_sni_from_client_hello(&ch).is_none()); } +#[test] +fn extract_sni_rejects_session_id_len_overflow() { + let mut ch = build_client_hello_with_exts(Vec::new(), "example.com"); + let sid_len_pos = 5 + 4 + 2 + 32; + ch[sid_len_pos] = 255; + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_sni_rejects_cipher_suites_len_overflow() { + let mut ch = build_client_hello_with_exts(Vec::new(), "example.com"); + let sid_len_pos = 5 + 4 + 2 + 32; + let cipher_len_pos = sid_len_pos + 1 + ch[sid_len_pos] as usize; + ch[cipher_len_pos] = 0xFF; + ch[cipher_len_pos + 1] = 0xFF; + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_sni_rejects_compression_methods_len_overflow() { + let mut ch = build_client_hello_with_exts(Vec::new(), "example.com"); + let sid_len_pos = 5 + 4 + 2 + 32; + let cipher_len_pos = sid_len_pos + 1 + ch[sid_len_pos] as usize; + let cipher_len = u16::from_be_bytes([ch[cipher_len_pos], ch[cipher_len_pos + 1]]) as usize; + let comp_len_pos = cipher_len_pos + 2 + cipher_len; + ch[comp_len_pos] = 0xFF; + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_alpn_returns_empty_on_session_id_len_overflow() { + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&3u16.to_be_bytes()); + alpn_data.push(2); + alpn_data.extend_from_slice(b"h2"); + let mut ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test"); + let sid_len_pos = 5 + 4 + 2 + 32; + ch[sid_len_pos] = 255; + assert!(extract_alpn_from_client_hello(&ch).is_empty()); +} + #[test] fn extract_alpn_rejects_when_extension_block_is_truncated() { let mut ext_blob = Vec::new(); diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index ef98144..dbd50d5 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -4,11 +4,14 @@ use std::net::SocketAddr; use std::collections::HashSet; -use std::net::IpAddr; +use std::net::{IpAddr, Ipv6Addr}; use std::sync::Arc; use std::sync::{Mutex, OnceLock}; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; use std::time::{Duration, Instant}; use dashmap::DashMap; +use dashmap::mapref::entry::Entry; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, warn, trace}; use zeroize::Zeroize; @@ -57,6 +60,16 @@ fn auth_probe_state_map() -> &'static DashMap { AUTH_PROBE_STATE.get_or_init(DashMap::new) } +fn normalize_auth_probe_ip(peer_ip: IpAddr) -> IpAddr { + match peer_ip { + IpAddr::V4(ip) => IpAddr::V4(ip), + IpAddr::V6(ip) => { + let [a, b, c, d, _, _, _, _] = ip.segments(); + IpAddr::V6(Ipv6Addr::new(a, b, c, d, 0, 0, 0, 0)) + } + } +} + fn auth_probe_backoff(fail_streak: u32) -> Duration { if fail_streak < AUTH_PROBE_BACKOFF_START_FAILS { return Duration::ZERO; @@ -74,7 +87,15 @@ fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool { now.duration_since(state.last_seen) > retention } +fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize { + let mut hasher = DefaultHasher::new(); + peer_ip.hash(&mut hasher); + now.hash(&mut hasher); + hasher.finish() as usize +} + fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool { + let peer_ip = normalize_auth_probe_ip(peer_ip); let state = auth_probe_state_map(); let Some(entry) = state.get(&peer_ip) else { return false; @@ -88,6 +109,7 @@ fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool { } fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) { + let peer_ip = normalize_auth_probe_ip(peer_ip); let state = auth_probe_state_map(); auth_probe_record_failure_with_state(state, peer_ip, now); } @@ -97,24 +119,35 @@ fn auth_probe_record_failure_with_state( peer_ip: IpAddr, now: Instant, ) { - if let Some(mut entry) = state.get_mut(&peer_ip) { - if auth_probe_state_expired(&entry, now) { - *entry = AuthProbeState { - fail_streak: 1, - blocked_until: now + auth_probe_backoff(1), - last_seen: now, - }; + let make_new_state = || AuthProbeState { + fail_streak: 1, + blocked_until: now + auth_probe_backoff(1), + last_seen: now, + }; + + let update_existing = |entry: &mut AuthProbeState| { + if auth_probe_state_expired(entry, now) { + *entry = make_new_state(); + } else { + entry.fail_streak = entry.fail_streak.saturating_add(1); + entry.last_seen = now; + entry.blocked_until = now + auth_probe_backoff(entry.fail_streak); + } + }; + + match state.entry(peer_ip) { + Entry::Occupied(mut entry) => { + update_existing(entry.get_mut()); return; } - entry.fail_streak = entry.fail_streak.saturating_add(1); - entry.last_seen = now; - entry.blocked_until = now + auth_probe_backoff(entry.fail_streak); - return; - }; + Entry::Vacant(_) => {} + } if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { let mut stale_keys = Vec::new(); + let mut eviction_candidates = Vec::new(); for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) { + eviction_candidates.push(*entry.key()); if auth_probe_state_expired(entry.value(), now) { stale_keys.push(*entry.key()); } @@ -123,23 +156,27 @@ fn auth_probe_record_failure_with_state( state.remove(&stale_key); } if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES { - return; + if eviction_candidates.is_empty() { + return; + } + let idx = auth_probe_eviction_offset(peer_ip, now) % eviction_candidates.len(); + let evict_key = eviction_candidates[idx]; + state.remove(&evict_key); } } - state.insert(peer_ip, AuthProbeState { - fail_streak: 0, - blocked_until: now, - last_seen: now, - }); - - if let Some(mut entry) = state.get_mut(&peer_ip) { - entry.fail_streak = 1; - entry.blocked_until = now + auth_probe_backoff(1); + match state.entry(peer_ip) { + Entry::Occupied(mut entry) => { + update_existing(entry.get_mut()); + } + Entry::Vacant(entry) => { + entry.insert(make_new_state()); + } } } fn auth_probe_record_success(peer_ip: IpAddr) { + let peer_ip = normalize_auth_probe_ip(peer_ip); let state = auth_probe_state_map(); state.remove(&peer_ip); } @@ -153,6 +190,7 @@ fn clear_auth_probe_state_for_testing() { #[cfg(test)] fn auth_probe_fail_streak_for_testing(peer_ip: IpAddr) -> Option { + let peer_ip = normalize_auth_probe_ip(peer_ip); let state = AUTH_PROBE_STATE.get()?; state.get(&peer_ip).map(|entry| entry.fail_streak) } @@ -177,6 +215,12 @@ fn clear_warned_secrets_for_testing() { } } +#[cfg(test)] +fn warned_secrets_test_lock() -> &'static Mutex<()> { + static TEST_LOCK: OnceLock> = OnceLock::new(); + TEST_LOCK.get_or_init(|| Mutex::new(())) +} + fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option) { let key = (name.to_string(), reason.to_string()); let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new())); diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs index f2d7d03..6bdc345 100644 --- a/src/proxy/handshake_security_tests.rs +++ b/src/proxy/handshake_security_tests.rs @@ -4,6 +4,7 @@ use dashmap::DashMap; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; use std::time::{Duration, Instant}; +use tokio::sync::Barrier; fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { let session_id_len: usize = 32; @@ -84,7 +85,6 @@ fn make_valid_tls_client_hello_with_alpn( } fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { - clear_auth_probe_state_for_testing(); let mut cfg = ProxyConfig::default(); cfg.access.users.clear(); cfg.access @@ -369,6 +369,9 @@ async fn invalid_tls_probe_does_not_pollute_replay_cache() { #[tokio::test] async fn empty_decoded_secret_is_rejected() { + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); clear_warned_secrets_for_testing(); let config = test_config_with_secret_hex(""); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); @@ -393,6 +396,9 @@ async fn empty_decoded_secret_is_rejected() { #[tokio::test] async fn wrong_length_decoded_secret_is_rejected() { + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); clear_warned_secrets_for_testing(); let config = test_config_with_secret_hex("aa"); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); @@ -443,6 +449,12 @@ async fn invalid_mtproto_probe_does_not_pollute_replay_cache() { #[tokio::test] async fn mixed_secret_lengths_keep_valid_user_authenticating() { + let _probe_guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); clear_warned_secrets_for_testing(); clear_auth_probe_state_for_testing(); let good_secret = [0x22u8; 16]; @@ -708,6 +720,9 @@ fn mode_policy_matrix_is_stable_for_all_tag_transport_mode_combinations() { #[test] fn invalid_secret_warning_keys_do_not_collide_on_colon_boundaries() { + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); clear_warned_secrets_for_testing(); warn_invalid_secret_once("a:b", "c", ACCESS_SECRET_BYTES, Some(1)); @@ -755,8 +770,9 @@ async fn repeated_invalid_tls_probes_trigger_pre_auth_throttle() { } assert!( - auth_probe_is_throttled_for_testing(peer.ip()), - "invalid probe burst must activate per-IP pre-auth throttle" + auth_probe_fail_streak_for_testing(peer.ip()) + .is_some_and(|streak| streak >= AUTH_PROBE_BACKOFF_START_FAILS), + "invalid probe burst must grow pre-auth failure streak to backoff threshold" ); } @@ -855,7 +871,7 @@ fn auth_probe_capacity_prunes_stale_entries_for_new_ips() { } #[test] -fn auth_probe_capacity_stays_fail_closed_when_map_is_fresh_and_full() { +fn auth_probe_capacity_forces_bounded_eviction_when_map_is_fresh_and_full() { let state = DashMap::new(); let now = Instant::now(); @@ -880,12 +896,215 @@ fn auth_probe_capacity_stays_fail_closed_when_map_is_fresh_and_full() { auth_probe_record_failure_with_state(&state, newcomer, now); assert!( - state.get(&newcomer).is_none(), - "when all entries are fresh and full, new probes must not be admitted" + state.get(&newcomer).is_some(), + "when all entries are fresh and full, one bounded eviction must admit a new probe source" ); assert_eq!( state.len(), AUTH_PROBE_TRACK_MAX_ENTRIES, - "auth probe map must stay at the configured cap" + "auth probe map must stay at the configured cap after forced eviction" + ); +} + +#[test] +fn auth_probe_ipv6_is_bucketed_by_prefix_64() { + let state = DashMap::new(); + let now = Instant::now(); + + let ip_a = IpAddr::V6("2001:db8:abcd:1234:1:2:3:4".parse().unwrap()); + let ip_b = IpAddr::V6("2001:db8:abcd:1234:ffff:eeee:dddd:cccc".parse().unwrap()); + + auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_a), now); + auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_b), now); + + let normalized = normalize_auth_probe_ip(ip_a); + assert_eq!( + state.len(), + 1, + "IPv6 sources in the same /64 must share one pre-auth throttle bucket" + ); + assert_eq!( + state.get(&normalized).map(|entry| entry.fail_streak), + Some(2), + "failures from the same /64 must accumulate in one throttle state" + ); +} + +#[test] +fn auth_probe_ipv6_different_prefixes_use_distinct_buckets() { + let state = DashMap::new(); + let now = Instant::now(); + + let ip_a = IpAddr::V6("2001:db8:1111:2222:1:2:3:4".parse().unwrap()); + let ip_b = IpAddr::V6("2001:db8:1111:3333:1:2:3:4".parse().unwrap()); + + auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_a), now); + auth_probe_record_failure_with_state(&state, normalize_auth_probe_ip(ip_b), now); + + assert_eq!( + state.len(), + 2, + "different IPv6 /64 prefixes must not share throttle buckets" + ); + assert_eq!( + state.get(&normalize_auth_probe_ip(ip_a)).map(|entry| entry.fail_streak), + Some(1) + ); + assert_eq!( + state.get(&normalize_auth_probe_ip(ip_b)).map(|entry| entry.fail_streak), + Some(1) + ); +} + +#[test] +fn auth_probe_success_clears_whole_ipv6_prefix_bucket() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let now = Instant::now(); + let ip_fail = IpAddr::V6("2001:db8:aaaa:bbbb:1:2:3:4".parse().unwrap()); + let ip_success = IpAddr::V6("2001:db8:aaaa:bbbb:ffff:eeee:dddd:cccc".parse().unwrap()); + + auth_probe_record_failure(ip_fail, now); + assert_eq!( + auth_probe_fail_streak_for_testing(ip_fail), + Some(1), + "precondition: normalized prefix bucket must exist" + ); + + auth_probe_record_success(ip_success); + assert_eq!( + auth_probe_fail_streak_for_testing(ip_fail), + None, + "success from the same /64 must clear the shared bucket" + ); +} + +#[test] +fn auth_probe_eviction_offset_varies_with_input() { + let now = Instant::now(); + let ip1 = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 10)); + let ip2 = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 11)); + + let a = auth_probe_eviction_offset(ip1, now); + let b = auth_probe_eviction_offset(ip1, now); + let c = auth_probe_eviction_offset(ip2, now); + + assert_eq!(a, b, "same input must yield deterministic offset"); + assert_ne!(a, c, "different peer IPs should not collapse to one offset"); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn auth_probe_concurrent_failures_do_not_lose_fail_streak_updates() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let peer_ip: IpAddr = "198.51.100.90".parse().unwrap(); + let tasks = 128usize; + let barrier = Arc::new(Barrier::new(tasks)); + let mut handles = Vec::with_capacity(tasks); + + for _ in 0..tasks { + let barrier = barrier.clone(); + handles.push(tokio::spawn(async move { + barrier.wait().await; + auth_probe_record_failure(peer_ip, Instant::now()); + })); + } + + for handle in handles { + handle + .await + .expect("concurrent failure recording task must not panic"); + } + + let streak = auth_probe_fail_streak_for_testing(peer_ip) + .expect("tracked peer must exist after concurrent failure burst"); + assert_eq!( + streak as usize, + tasks, + "concurrent failures for one source must account every attempt" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn invalid_probe_noise_from_other_ips_does_not_break_valid_tls_handshake() { + let _guard = auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + + let secret = [0x31u8; 16]; + let config = Arc::new(test_config_with_secret_hex("31313131313131313131313131313131")); + let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + let victim_peer: SocketAddr = "198.51.100.91:44391".parse().unwrap(); + let valid = Arc::new(make_valid_tls_handshake(&secret, 0)); + + let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32]; + invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32; + let invalid = Arc::new(invalid); + + let mut noise_tasks = Vec::new(); + for idx in 0..96u16 { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let invalid = invalid.clone(); + noise_tasks.push(tokio::spawn(async move { + let octet = ((idx % 200) + 1) as u8; + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, octet)), 45000 + idx); + let result = handle_tls_handshake( + &invalid, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(result, HandshakeResult::BadClient { .. })); + })); + } + + let victim_config = config.clone(); + let victim_replay_checker = replay_checker.clone(); + let victim_rng = rng.clone(); + let victim_valid = valid.clone(); + let victim_task = tokio::spawn(async move { + handle_tls_handshake( + &victim_valid, + tokio::io::empty(), + tokio::io::sink(), + victim_peer, + &victim_config, + &victim_replay_checker, + &victim_rng, + None, + ) + .await + }); + + for task in noise_tasks { + task.await.expect("noise task must not panic"); + } + + let victim_result = victim_task + .await + .expect("victim handshake task must not panic"); + assert!( + matches!(victim_result, HandshakeResult::Success(_)), + "invalid probe noise from other IPs must not block a valid victim handshake" + ); + assert_eq!( + auth_probe_fail_streak_for_testing(victim_peer.ip()), + None, + "successful victim handshake must not retain pre-auth failure streak" ); } diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index e347d73..9a23c5b 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -223,10 +223,10 @@ async fn relay_to_mask( initial_data: &[u8], ) where - R: AsyncRead + Unpin + Send, - W: AsyncWrite + Unpin + Send, - MR: AsyncRead + Unpin + Send, - MW: AsyncWrite + Unpin + Send, + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, + MR: AsyncRead + Unpin + Send + 'static, + MW: AsyncWrite + Unpin + Send + 'static, { // Send initial data to mask host if mask_write.write_all(initial_data).await.is_err() { @@ -236,39 +236,16 @@ where return; } - let mut client_buf = vec![0u8; MASK_BUFFER_SIZE]; - let mut mask_buf = vec![0u8; MASK_BUFFER_SIZE]; - - loop { - tokio::select! { - client_read = reader.read(&mut client_buf) => { - match client_read { - Ok(0) | Err(_) => { - let _ = mask_write.shutdown().await; - break; - } - Ok(n) => { - if mask_write.write_all(&client_buf[..n]).await.is_err() { - break; - } - } - } - } - mask_read_res = mask_read.read(&mut mask_buf) => { - match mask_read_res { - Ok(0) | Err(_) => { - let _ = writer.shutdown().await; - break; - } - Ok(n) => { - if writer.write_all(&mask_buf[..n]).await.is_err() { - break; - } - } - } - } + let _ = tokio::join!( + async { + let _ = tokio::io::copy(&mut reader, &mut mask_write).await; + let _ = mask_write.shutdown().await; + }, + async { + let _ = tokio::io::copy(&mut mask_read, &mut writer).await; + let _ = writer.shutdown().await; } - } + ); } /// Just consume all data from client without responding diff --git a/src/proxy/masking_security_tests.rs b/src/proxy/masking_security_tests.rs index 52e9f69..25b6a76 100644 --- a/src/proxy/masking_security_tests.rs +++ b/src/proxy/masking_security_tests.rs @@ -1,12 +1,14 @@ use super::*; use crate::config::ProxyConfig; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; use tokio::net::TcpListener; #[cfg(unix)] use tokio::net::UnixListener; -use tokio::time::{timeout, Duration}; +use tokio::time::{sleep, timeout, Duration}; #[tokio::test] async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() { @@ -542,9 +544,188 @@ impl tokio::io::AsyncWrite for PendingWriter { } } +struct DropTrackedPendingReader { + dropped: Arc, +} + +impl tokio::io::AsyncRead for DropTrackedPendingReader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Poll::Pending + } +} + +impl Drop for DropTrackedPendingReader { + fn drop(&mut self) { + self.dropped.store(true, Ordering::SeqCst); + } +} + +struct DropTrackedPendingWriter { + dropped: Arc, +} + +impl tokio::io::AsyncWrite for DropTrackedPendingWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +impl Drop for DropTrackedPendingWriter { + fn drop(&mut self) { + self.dropped.store(true, Ordering::SeqCst); + } +} + #[tokio::test] async fn proxy_header_write_timeout_returns_false() { let mut writer = PendingWriter; let ok = write_proxy_header_with_timeout(&mut writer, b"PROXY UNKNOWN\r\n").await; assert!(!ok, "Proxy header writes that never complete must time out"); } + +#[tokio::test] +async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stalls() { + let (mut client_feed_writer, client_feed_reader) = duplex(64); + let (mut client_visible_reader, client_visible_writer) = duplex(64); + let (mut backend_feed_writer, backend_feed_reader) = duplex(64); + + // Make client->mask direction immediately active so the c2m path blocks on PendingWriter. + client_feed_writer.write_all(b"X").await.unwrap(); + + let relay = tokio::spawn(async move { + relay_to_mask( + client_feed_reader, + client_visible_writer, + backend_feed_reader, + PendingWriter, + b"", + ) + .await; + }); + + // Allow relay tasks to start, then emulate mask backend response. + sleep(Duration::from_millis(20)).await; + backend_feed_writer.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap(); + backend_feed_writer.shutdown().await.unwrap(); + + let mut observed = vec![0u8; 19]; + timeout(Duration::from_secs(1), client_visible_reader.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, b"HTTP/1.1 200 OK\r\n\r\n"); + + relay.abort(); + let _ = relay.await; +} + +#[tokio::test] +async fn relay_to_mask_preserves_backend_response_after_client_half_close() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let request = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let backend_task = tokio::spawn({ + let request = request.clone(); + let response = response.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut observed_req = vec![0u8; request.len()]; + stream.read_exact(&mut observed_req).await.unwrap(); + assert_eq!(observed_req, request); + stream.write_all(&response).await.unwrap(); + stream.shutdown().await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.77:55001".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (mut client_write, client_read) = duplex(1024); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + let beobachten = BeobachtenStore::new(); + + let fallback_task = tokio::spawn(async move { + handle_bad_client( + client_read, + client_visible_writer, + &request, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + client_write.shutdown().await.unwrap(); + + let mut observed_resp = vec![0u8; response.len()]; + timeout(Duration::from_secs(1), client_visible_reader.read_exact(&mut observed_resp)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed_resp, response); + + timeout(Duration::from_secs(1), fallback_task).await.unwrap().unwrap(); + timeout(Duration::from_secs(1), backend_task).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() { + let reader_dropped = Arc::new(AtomicBool::new(false)); + let writer_dropped = Arc::new(AtomicBool::new(false)); + let mask_reader_dropped = Arc::new(AtomicBool::new(false)); + let mask_writer_dropped = Arc::new(AtomicBool::new(false)); + + let reader = DropTrackedPendingReader { + dropped: reader_dropped.clone(), + }; + let writer = DropTrackedPendingWriter { + dropped: writer_dropped.clone(), + }; + let mask_read = DropTrackedPendingReader { + dropped: mask_reader_dropped.clone(), + }; + let mask_write = DropTrackedPendingWriter { + dropped: mask_writer_dropped.clone(), + }; + + let timed = timeout( + Duration::from_millis(40), + relay_to_mask(reader, writer, mask_read, mask_write, b""), + ) + .await; + + assert!(timed.is_err(), "stalled relay must be bounded by timeout"); + + assert!(reader_dropped.load(Ordering::SeqCst)); + assert!(writer_dropped.load(Ordering::SeqCst)); + assert!(mask_reader_dropped.load(Ordering::SeqCst)); + assert!(mask_writer_dropped.load(Ordering::SeqCst)); +} diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index ba01c74..1acbdc1 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -7,7 +7,6 @@ use std::time::{Duration, Instant}; #[cfg(test)] use std::sync::Mutex; -use bytes::{Bytes, BytesMut}; use dashmap::DashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::{mpsc, oneshot, watch}; @@ -24,11 +23,11 @@ use crate::proxy::route_mode::{ cutover_stagger_delay, }; use crate::stats::Stats; -use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; enum C2MeCommand { - Data { payload: Bytes, flags: u32 }, + Data { payload: PooledBuffer, flags: u32 }, Close, } @@ -107,7 +106,11 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool { if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES { let mut stale_keys = Vec::new(); + let mut eviction_candidate = None; for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) { + if eviction_candidate.is_none() { + eviction_candidate = Some(*entry.key()); + } if now.duration_since(*entry.value()) >= DESYNC_DEDUP_WINDOW { stale_keys.push(*entry.key()); } @@ -116,6 +119,11 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool { dedup.remove(&stale_key); } if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES { + let Some(evict_key) = eviction_candidate else { + return false; + }; + dedup.remove(&evict_key); + dedup.insert(key, now); return false; } } @@ -677,7 +685,7 @@ async fn read_client_payload( forensics: &RelayForensicsState, frame_counter: &mut u64, stats: &Stats, -) -> Result> +) -> Result> where R: AsyncRead + Unpin + Send + 'static, { @@ -784,25 +792,21 @@ where len }; - let chunk_cap = buffer_pool.buffer_size().max(1024); - let mut payload = BytesMut::with_capacity(len.min(chunk_cap)); - let mut remaining = len; - while remaining > 0 { - let chunk_len = remaining.min(chunk_cap); - let mut chunk = buffer_pool.get(); - chunk.resize(chunk_len, 0); - read_exact_with_timeout(client_reader, &mut chunk[..chunk_len], frame_read_timeout) - .await?; - payload.extend_from_slice(&chunk[..chunk_len]); - remaining -= chunk_len; + let mut payload = buffer_pool.get(); + payload.clear(); + let current_cap = payload.capacity(); + if current_cap < len { + payload.reserve(len - current_cap); } + payload.resize(len, 0); + read_exact_with_timeout(client_reader, &mut payload[..len], frame_read_timeout).await?; // Secure Intermediate: strip validated trailing padding bytes. if proto_tag == ProtoTag::Secure { payload.truncate(secure_payload_len); } *frame_counter += 1; - return Ok(Some((payload.freeze(), quickack))); + return Ok(Some((payload, quickack))); } } diff --git a/src/proxy/middle_relay_security_tests.rs b/src/proxy/middle_relay_security_tests.rs index a2f89f8..a2a6c3e 100644 --- a/src/proxy/middle_relay_security_tests.rs +++ b/src/proxy/middle_relay_security_tests.rs @@ -1,7 +1,9 @@ use super::*; +use bytes::Bytes; use crate::crypto::AesCtr; +use crate::crypto::SecureRandom; use crate::stats::Stats; -use crate::stream::{BufferPool, CryptoReader}; +use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::AtomicU64; @@ -9,6 +11,21 @@ use tokio::io::AsyncWriteExt; use tokio::io::duplex; use tokio::time::{Duration as TokioDuration, timeout}; +fn make_pooled_payload(data: &[u8]) -> PooledBuffer { + let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); + let mut payload = pool.get(); + payload.resize(data.len(), 0); + payload[..data.len()].copy_from_slice(data); + payload +} + +fn make_pooled_payload_from(pool: &Arc, data: &[u8]) -> PooledBuffer { + let mut payload = pool.get(); + payload.resize(data.len(), 0); + payload[..data.len()].copy_from_slice(data); + payload +} + #[test] fn should_yield_sender_only_on_budget_with_backlog() { assert!(!should_yield_c2me_sender(0, true)); @@ -23,7 +40,7 @@ async fn enqueue_c2me_command_uses_try_send_fast_path() { enqueue_c2me_command( &tx, C2MeCommand::Data { - payload: Bytes::from_static(&[1, 2, 3]), + payload: make_pooled_payload(&[1, 2, 3]), flags: 0, }, ) @@ -47,7 +64,7 @@ async fn enqueue_c2me_command_uses_try_send_fast_path() { async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { let (tx, mut rx) = mpsc::channel::(1); tx.send(C2MeCommand::Data { - payload: Bytes::from_static(&[9]), + payload: make_pooled_payload(&[9]), flags: 9, }) .await @@ -58,7 +75,7 @@ async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { enqueue_c2me_command( &tx2, C2MeCommand::Data { - payload: Bytes::from_static(&[7, 7]), + payload: make_pooled_payload(&[7, 7]), flags: 7, }, ) @@ -84,6 +101,74 @@ async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { } } +#[tokio::test] +async fn enqueue_c2me_command_closed_channel_recycles_payload() { + let pool = Arc::new(BufferPool::with_config(64, 4)); + let payload = make_pooled_payload_from(&pool, &[1, 2, 3, 4]); + let (tx, rx) = mpsc::channel::(1); + drop(rx); + + let result = enqueue_c2me_command( + &tx, + C2MeCommand::Data { + payload, + flags: 0, + }, + ) + .await; + + assert!(result.is_err(), "closed queue must fail enqueue"); + drop(result); + assert!( + pool.stats().pooled >= 1, + "payload must return to pool when enqueue fails on closed channel" + ); +} + +#[tokio::test] +async fn enqueue_c2me_command_full_then_closed_recycles_waiting_payload() { + let pool = Arc::new(BufferPool::with_config(64, 4)); + let (tx, rx) = mpsc::channel::(1); + + tx.send(C2MeCommand::Data { + payload: make_pooled_payload_from(&pool, &[9]), + flags: 1, + }) + .await + .unwrap(); + + let tx2 = tx.clone(); + let pool2 = pool.clone(); + let blocked_send = tokio::spawn(async move { + enqueue_c2me_command( + &tx2, + C2MeCommand::Data { + payload: make_pooled_payload_from(&pool2, &[7, 7, 7]), + flags: 2, + }, + ) + .await + }); + + tokio::time::sleep(TokioDuration::from_millis(10)).await; + drop(rx); + + let result = timeout(TokioDuration::from_secs(1), blocked_send) + .await + .expect("blocked send task must finish") + .expect("blocked send task must not panic"); + + assert!( + result.is_err(), + "closing receiver while sender is blocked must fail enqueue" + ); + drop(result); + assert!( + pool.stats().pooled >= 2, + "both queued and blocked payloads must return to pool after channel close" + ); +} + #[test] fn desync_dedup_cache_is_bounded() { let _guard = desync_dedup_test_lock() @@ -101,7 +186,7 @@ fn desync_dedup_cache_is_bounded() { assert!( !should_emit_full_desync(u64::MAX, false, now), - "new key above cap must be suppressed to bound memory" + "new key above cap must remain suppressed to avoid log amplification" ); assert!( @@ -110,6 +195,26 @@ fn desync_dedup_cache_is_bounded() { ); } +#[test] +fn desync_dedup_full_cache_churn_stays_suppressed() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("desync dedup test lock must be available"); + clear_desync_dedup_for_testing(); + + let now = Instant::now(); + for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { + assert!(should_emit_full_desync(key, false, now)); + } + + for offset in 0..2048u64 { + assert!( + !should_emit_full_desync(u64::MAX - offset, false, now), + "fresh full-cache churn must remain suppressed under pressure" + ); + } +} + fn make_forensics_state() -> RelayForensicsState { RelayForensicsState { trace_id: 1, @@ -130,6 +235,12 @@ fn make_crypto_reader(reader: tokio::io::DuplexStream) -> CryptoReader CryptoWriter { + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) +} + fn encrypt_for_reader(plaintext: &[u8]) -> Vec { let key = [0u8; 32]; let iv = 0u128; @@ -199,3 +310,472 @@ async fn read_client_payload_times_out_on_payload_stall() { "stalled payload body read must time out" ); } + +#[tokio::test] +async fn read_client_payload_large_intermediate_frame_is_exact() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let (reader, mut writer) = duplex(262_144); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let payload_len = buffer_pool.buffer_size().saturating_mul(3).max(65_537); + let mut plaintext = Vec::with_capacity(4 + payload_len); + plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes()); + plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(31))); + + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let read = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + payload_len + 16, + TokioDuration::from_secs(1), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("payload read must succeed") + .expect("frame must be present"); + + let (frame, quickack) = read; + assert!(!quickack, "quickack flag must be unset"); + assert_eq!(frame.len(), payload_len, "payload size must match wire length"); + for (idx, byte) in frame.iter().enumerate() { + assert_eq!(*byte, (idx as u8).wrapping_mul(31)); + } + assert_eq!(frame_counter, 1, "exactly one frame must be counted"); +} + +#[tokio::test] +async fn read_client_payload_secure_strips_tail_padding_bytes() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let payload = [0x11u8, 0x22, 0x33, 0x44, 0xaa, 0xbb, 0xcc, 0xdd]; + let tail = [0xeeu8, 0xff, 0x99]; + let wire_len = payload.len() + tail.len(); + + let mut plaintext = Vec::with_capacity(4 + wire_len); + plaintext.extend_from_slice(&(wire_len as u32).to_le_bytes()); + plaintext.extend_from_slice(&payload); + plaintext.extend_from_slice(&tail); + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let read = read_client_payload( + &mut crypto_reader, + ProtoTag::Secure, + 1024, + TokioDuration::from_secs(1), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("secure payload read must succeed") + .expect("secure frame must be present"); + + let (frame, quickack) = read; + assert!(!quickack, "quickack flag must be unset"); + assert_eq!(frame.as_ref(), &payload); + assert_eq!(frame_counter, 1, "one secure frame must be counted"); +} + +#[tokio::test] +async fn read_client_payload_secure_rejects_wire_len_below_4() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let mut plaintext = Vec::with_capacity(7); + plaintext.extend_from_slice(&3u32.to_le_bytes()); + plaintext.extend_from_slice(&[1u8, 2, 3]); + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let result = read_client_payload( + &mut crypto_reader, + ProtoTag::Secure, + 1024, + TokioDuration::from_secs(1), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(ref msg)) if msg.contains("Frame too small: 3")), + "secure wire length below 4 must be fail-closed by the frame-too-small guard" + ); +} + +#[tokio::test] +async fn read_client_payload_intermediate_skips_zero_len_frame() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let payload = [7u8, 6, 5, 4, 3, 2, 1, 0]; + let mut plaintext = Vec::with_capacity(4 + 4 + payload.len()); + plaintext.extend_from_slice(&0u32.to_le_bytes()); + plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes()); + plaintext.extend_from_slice(&payload); + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let read = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + TokioDuration::from_secs(1), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("intermediate payload read must succeed") + .expect("frame must be present"); + + let (frame, quickack) = read; + assert!(!quickack, "quickack flag must be unset"); + assert_eq!(frame.as_ref(), &payload); + assert_eq!(frame_counter, 1, "zero-length frame must be skipped"); +} + +#[tokio::test] +async fn read_client_payload_abridged_extended_len_sets_quickack() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let payload_len = 4 * 130; + let len_words = (payload_len / 4) as u32; + let mut plaintext = Vec::with_capacity(1 + 3 + payload_len); + plaintext.push(0xff | 0x80); + let lw = len_words.to_le_bytes(); + plaintext.extend_from_slice(&lw[..3]); + plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_add(17))); + + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let read = read_client_payload( + &mut crypto_reader, + ProtoTag::Abridged, + payload_len + 16, + TokioDuration::from_secs(1), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("abridged payload read must succeed") + .expect("frame must be present"); + + let (frame, quickack) = read; + assert!(quickack, "quickack bit must be propagated from abridged header"); + assert_eq!(frame.len(), payload_len); + assert_eq!(frame_counter, 1, "one abridged frame must be counted"); +} + +#[tokio::test] +async fn read_client_payload_returns_buffer_to_pool_after_emit() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let pool = Arc::new(BufferPool::with_config(64, 8)); + pool.preallocate(1); + assert_eq!(pool.stats().pooled, 1, "precondition: one pooled buffer"); + + let (reader, mut writer) = duplex(4096); + let mut crypto_reader = make_crypto_reader(reader); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + // Force growth beyond default pool buffer size to catch ownership-take regressions. + let payload_len = 257usize; + let mut plaintext = Vec::with_capacity(4 + payload_len); + plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes()); + plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(13))); + + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let _ = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + payload_len + 8, + TokioDuration::from_secs(1), + &pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("payload read must succeed") + .expect("frame must be present"); + + assert_eq!(frame_counter, 1); + let pool_stats = pool.stats(); + assert!( + pool_stats.pooled >= 1, + "emitted payload buffer must be returned to pool to avoid pool drain" + ); +} + +#[tokio::test] +async fn read_client_payload_keeps_pool_buffer_checked_out_until_frame_drop() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + + let pool = Arc::new(BufferPool::with_config(64, 2)); + pool.preallocate(1); + assert_eq!(pool.stats().pooled, 1, "one pooled buffer must be available"); + + let (reader, mut writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let payload = [0x41u8, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48]; + let mut plaintext = Vec::with_capacity(4 + payload.len()); + plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes()); + plaintext.extend_from_slice(&payload); + let encrypted = encrypt_for_reader(&plaintext); + writer.write_all(&encrypted).await.unwrap(); + + let (frame, quickack) = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + TokioDuration::from_secs(1), + &pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await + .expect("payload read must succeed") + .expect("frame must be present"); + + assert!(!quickack); + assert_eq!(frame.as_ref(), &payload); + assert_eq!( + pool.stats().pooled, + 0, + "buffer must stay checked out while frame payload is alive" + ); + + drop(frame); + assert!( + pool.stats().pooled >= 1, + "buffer must return to pool only after frame drop" + ); +} + +#[tokio::test] +async fn enqueue_c2me_close_unblocks_after_queue_drain() { + let (tx, mut rx) = mpsc::channel::(1); + tx.send(C2MeCommand::Data { + payload: make_pooled_payload(&[0x41]), + flags: 0, + }) + .await + .unwrap(); + + let tx2 = tx.clone(); + let close_task = tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); + + tokio::time::sleep(TokioDuration::from_millis(10)).await; + + let first = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .unwrap() + .expect("first queued item must be present"); + assert!(matches!(first, C2MeCommand::Data { .. })); + + close_task.await.unwrap().expect("close enqueue must succeed after drain"); + + let second = timeout(TokioDuration::from_millis(100), rx.recv()) + .await + .unwrap() + .expect("close command must follow after queue drain"); + assert!(matches!(second, C2MeCommand::Close)); +} + +#[tokio::test] +async fn enqueue_c2me_close_full_then_receiver_drop_fails_cleanly() { + let (tx, rx) = mpsc::channel::(1); + tx.send(C2MeCommand::Data { + payload: make_pooled_payload(&[0x42]), + flags: 0, + }) + .await + .unwrap(); + + let tx2 = tx.clone(); + let close_task = tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); + + tokio::time::sleep(TokioDuration::from_millis(10)).await; + drop(rx); + + let result = timeout(TokioDuration::from_secs(1), close_task) + .await + .expect("close task must finish") + .expect("close task must not panic"); + assert!( + result.is_err(), + "close enqueue must fail cleanly when receiver is dropped under pressure" + ); +} + +#[tokio::test] +async fn process_me_writer_response_ack_obeys_flush_policy() { + let (writer_side, _reader_side) = duplex(1024); + let mut writer = make_crypto_writer(writer_side); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + let immediate = process_me_writer_response( + MeResponse::Ack(0x11223344), + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "user", + &bytes_me2c, + 77, + true, + false, + ) + .await + .expect("ack response must be processed"); + + assert!(matches!( + immediate, + MeWriterResponseOutcome::Continue { + frames: 1, + bytes: 4, + flush_immediately: true, + } + )); + + let delayed = process_me_writer_response( + MeResponse::Ack(0x55667788), + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "user", + &bytes_me2c, + 77, + false, + false, + ) + .await + .expect("ack response must be processed"); + + assert!(matches!( + delayed, + MeWriterResponseOutcome::Continue { + frames: 1, + bytes: 4, + flush_immediately: false, + } + )); +} + +#[tokio::test] +async fn process_me_writer_response_data_updates_byte_accounting() { + let (writer_side, _reader_side) = duplex(1024); + let mut writer = make_crypto_writer(writer_side); + let rng = SecureRandom::new(); + let mut frame_buf = Vec::new(); + let stats = Stats::new(); + let bytes_me2c = AtomicU64::new(0); + + let payload = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9]; + let outcome = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from(payload.clone()), + }, + &mut writer, + ProtoTag::Intermediate, + &rng, + &mut frame_buf, + &stats, + "user", + &bytes_me2c, + 88, + false, + false, + ) + .await + .expect("data response must be processed"); + + assert!(matches!( + outcome, + MeWriterResponseOutcome::Continue { + frames: 1, + bytes, + flush_immediately: false, + } if bytes == payload.len() + )); + assert_eq!( + bytes_me2c.load(std::sync::atomic::Ordering::Relaxed), + payload.len() as u64, + "ME->C byte accounting must increase by emitted payload size" + ); +}