From 3b86a883b909f03846c0fb548eba3c58fe6d5831 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Sat, 21 Mar 2026 14:14:58 +0400 Subject: [PATCH] Add comprehensive tests for relay quota management and adversarial scenarios - Introduced `relay_quota_boundary_blackhat_tests.rs` to validate behavior under quota limits, including edge cases and adversarial conditions. - Added `relay_quota_model_adversarial_tests.rs` to ensure quota management maintains integrity during bidirectional communication and various load scenarios. - Created `relay_quota_overflow_regression_tests.rs` to address overflow issues and ensure that quota limits are respected during aggressive data transmission. - Implemented `route_mode_coherence_adversarial_tests.rs` to verify the consistency of route mode transitions and timestamp management across different relay modes. --- src/proxy/client.rs | 4 + src/proxy/relay.rs | 38 +- src/proxy/route_mode.rs | 17 +- ...nt_masking_probe_evasion_blackhat_tests.rs | 344 +++++++++++++++ .../relay_quota_boundary_blackhat_tests.rs | 416 ++++++++++++++++++ .../relay_quota_model_adversarial_tests.rs | 300 +++++++++++++ .../relay_quota_overflow_regression_tests.rs | 194 ++++++++ .../route_mode_coherence_adversarial_tests.rs | 228 ++++++++++ 8 files changed, 1529 insertions(+), 12 deletions(-) create mode 100644 src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs create mode 100644 src/proxy/tests/relay_quota_boundary_blackhat_tests.rs create mode 100644 src/proxy/tests/relay_quota_model_adversarial_tests.rs create mode 100644 src/proxy/tests/relay_quota_overflow_regression_tests.rs create mode 100644 src/proxy/tests/route_mode_coherence_adversarial_tests.rs diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 5eb2a22..65b893d 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -1273,3 +1273,7 @@ mod masking_shape_hardening_redteam_expected_fail_tests; #[cfg(test)] #[path = "tests/client_masking_shape_classifier_fuzz_redteam_expected_fail_tests.rs"] mod masking_shape_classifier_fuzz_redteam_expected_fail_tests; + +#[cfg(test)] +#[path = "tests/client_masking_probe_evasion_blackhat_tests.rs"] +mod masking_probe_evasion_blackhat_tests; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 949f2c2..c0cf3d4 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -364,6 +364,25 @@ impl AsyncRead for StatsIo { Poll::Ready(Ok(())) => { let n = buf.filled().len() - before; if n > 0 { + let mut reached_quota_boundary = false; + if let Some(limit) = this.quota_limit { + let used = this.stats.get_user_total_octets(&this.user); + if used >= limit { + this.quota_exceeded.store(true, Ordering::Relaxed); + return Poll::Ready(Err(quota_io_error())); + } + + let remaining = limit - used; + if (n as u64) > remaining { + // Fail closed: when a single read chunk would cross quota, + // stop relay immediately without accounting beyond the cap. + this.quota_exceeded.store(true, Ordering::Relaxed); + return Poll::Ready(Err(quota_io_error())); + } + + reached_quota_boundary = (n as u64) == remaining; + } + // C→S: client sent data this.counters.c2s_bytes.fetch_add(n as u64, Ordering::Relaxed); this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed); @@ -372,11 +391,8 @@ impl AsyncRead for StatsIo { this.stats.add_user_octets_from(&this.user, n as u64); this.stats.increment_user_msgs_from(&this.user); - if let Some(limit) = this.quota_limit - && this.stats.get_user_total_octets(&this.user) >= limit - { + if reached_quota_boundary { this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); } trace!(user = %this.user, bytes = n, "C->S"); @@ -701,4 +717,16 @@ mod adversarial_tests; #[cfg(test)] #[path = "tests/relay_quota_lock_pressure_adversarial_tests.rs"] -mod relay_quota_lock_pressure_adversarial_tests; \ No newline at end of file +mod relay_quota_lock_pressure_adversarial_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_boundary_blackhat_tests.rs"] +mod relay_quota_boundary_blackhat_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_model_adversarial_tests.rs"] +mod relay_quota_model_adversarial_tests; + +#[cfg(test)] +#[path = "tests/relay_quota_overflow_regression_tests.rs"] +mod relay_quota_overflow_regression_tests; \ No newline at end of file diff --git a/src/proxy/route_mode.rs b/src/proxy/route_mode.rs index a3dea85..e2232d2 100644 --- a/src/proxy/route_mode.rs +++ b/src/proxy/route_mode.rs @@ -71,6 +71,12 @@ impl RouteRuntimeController { if state.mode == mode { return false; } + if matches!(mode, RelayRouteMode::Direct) { + self.direct_since_epoch_secs + .store(now_epoch_secs(), Ordering::Relaxed); + } else { + self.direct_since_epoch_secs.store(0, Ordering::Relaxed); + } state.mode = mode; state.generation = state.generation.saturating_add(1); next = Some(*state); @@ -81,13 +87,6 @@ impl RouteRuntimeController { return None; } - if matches!(mode, RelayRouteMode::Direct) { - self.direct_since_epoch_secs - .store(now_epoch_secs(), Ordering::Relaxed); - } else { - self.direct_since_epoch_secs.store(0, Ordering::Relaxed); - } - next } } @@ -135,3 +134,7 @@ pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duratio #[cfg(test)] #[path = "tests/route_mode_security_tests.rs"] mod security_tests; + +#[cfg(test)] +#[path = "tests/route_mode_coherence_adversarial_tests.rs"] +mod coherence_adversarial_tests; diff --git a/src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs b/src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs new file mode 100644 index 0000000..1208071 --- /dev/null +++ b/src/proxy/tests/client_masking_probe_evasion_blackhat_tests.rs @@ -0,0 +1,344 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::net::SocketAddr; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; + +const REPLY_404: &[u8] = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n"; + +fn make_test_upstream_manager(stats: Arc) -> Arc { + Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats, + )) +} + +fn masking_config(mask_port: u16) -> Arc { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_port; + cfg.censorship.mask_proxy_protocol = 0; + Arc::new(cfg) +} + +async fn run_generic_probe_and_capture_prefix(payload: Vec, expected_prefix: Vec) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let reply = REPLY_404.to_vec(); + let prefix_len = expected_prefix.len(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; prefix_len]; + stream.read_exact(&mut got).await.unwrap(); + stream.write_all(&reply).await.unwrap(); + got + }); + + let config = masking_config(backend_addr.port()); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.210:55110".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&payload).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let mut observed = vec![0u8; REPLY_404.len()]; + tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, REPLY_404); + + let got = tokio::time::timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got, expected_prefix); + + let result = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); +} + +async fn read_http_probe_header(stream: &mut TcpStream) -> Vec { + let mut out = Vec::with_capacity(96); + let mut one = [0u8; 1]; + + loop { + stream.read_exact(&mut one).await.unwrap(); + out.push(one[0]); + if out.ends_with(b"\r\n\r\n") { + break; + } + assert!( + out.len() <= 512, + "probe header exceeded sane limit while waiting for terminator" + ); + } + + out +} + +#[tokio::test] +async fn blackhat_fragmented_plain_http_probe_masks_and_preserves_prefix() { + let payload = b"GET /probe-evasion HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + run_generic_probe_and_capture_prefix(payload.clone(), payload).await; +} + +#[tokio::test] +async fn blackhat_invalid_tls_like_probe_masks_and_preserves_header_prefix() { + let payload = vec![0x16, 0x03, 0x03, 0x00, 0x64, 0x01, 0x00]; + run_generic_probe_and_capture_prefix(payload.clone(), payload).await; +} + +#[tokio::test] +async fn integration_client_handler_plain_probe_masks_and_preserves_prefix() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let payload = b"GET /integration-probe HTTP/1.1\r\nHost: a.example\r\n\r\n".to_vec(); + let expected_prefix = payload.clone(); + + let accept_task = tokio::spawn(async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = vec![0u8; expected_prefix.len()]; + stream.read_exact(&mut got).await.unwrap(); + stream.write_all(REPLY_404).await.unwrap(); + got + }); + + let config = masking_config(backend_addr.port()); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&payload).await.unwrap(); + client.shutdown().await.unwrap(); + + let mut observed = vec![0u8; REPLY_404.len()]; + tokio::time::timeout(Duration::from_secs(2), client.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, REPLY_404); + + let got = tokio::time::timeout(Duration::from_secs(2), accept_task) + .await + .unwrap() + .unwrap(); + assert_eq!(got, payload); + + let result = tokio::time::timeout(Duration::from_secs(2), server_task) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); +} + +#[tokio::test] +async fn light_fuzz_small_probe_variants_always_mask_and_preserve_declared_prefix() { + let mut rng = StdRng::seed_from_u64(0xA11E_5EED_F0F0_CAFE); + + for i in 0..24usize { + let mut payload = if rng.random::() { + b"GET /fuzz HTTP/1.1\r\nHost: fuzz.example\r\n\r\n".to_vec() + } else { + vec![0x16, 0x03, 0x03, 0x00, 0x64] + }; + + let tail_len = rng.random_range(0..=8usize); + for _ in 0..tail_len { + payload.push(rng.random::()); + } + + let expected_prefix = payload.clone(); + run_generic_probe_and_capture_prefix(payload, expected_prefix).await; + + if i % 6 == 0 { + tokio::task::yield_now().await; + } + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_probe_mix_masks_all_sessions_without_cross_leakage() { + let session_count = 12usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected = std::collections::HashSet::new(); + for idx in 0..session_count { + let probe = format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes(); + expected.insert(probe); + } + + let accept_task = tokio::spawn(async move { + let mut remaining = expected; + for _ in 0..session_count { + let (mut stream, _) = listener.accept().await.unwrap(); + let head = read_http_probe_header(&mut stream).await; + stream.write_all(REPLY_404).await.unwrap(); + assert!(remaining.remove(&head), "backend received unexpected or duplicated probe prefix"); + } + assert!(remaining.is_empty(), "all session prefixes must be observed exactly once"); + }); + + let mut tasks = Vec::with_capacity(session_count); + for idx in 0..session_count { + let config = masking_config(backend_addr.port()); + let stats = Arc::new(Stats::new()); + let upstream_manager = make_test_upstream_manager(stats.clone()); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let probe = format!("GET /stress-{idx} HTTP/1.1\r\nHost: s{idx}.example\r\n\r\n").into_bytes(); + let peer: SocketAddr = format!("203.0.113.{}:{}", 30 + idx, 56000 + idx) + .parse() + .unwrap(); + + tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let mut observed = vec![0u8; REPLY_404.len()]; + tokio::time::timeout(Duration::from_secs(2), client_side.read_exact(&mut observed)) + .await + .unwrap() + .unwrap(); + assert_eq!(observed, REPLY_404); + + let result = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + assert!(result.is_ok()); + })); + } + + for task in tasks { + task.await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(4), accept_task) + .await + .unwrap() + .unwrap(); +} diff --git a/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs b/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs new file mode 100644 index 0000000..c8395aa --- /dev/null +++ b/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs @@ -0,0 +1,416 @@ +use super::relay_bidirectional; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::sync::Arc; +use tokio::io::{duplex, AsyncRead, AsyncReadExt, AsyncWriteExt}; +use tokio::time::{timeout, Duration, Instant}; + +async fn read_available(reader: &mut R, budget: Duration) -> usize { + let start = Instant::now(); + let mut total = 0usize; + let mut buf = [0u8; 256]; + + loop { + let elapsed = start.elapsed(); + if elapsed >= budget { + break; + } + let remaining = budget.saturating_sub(elapsed); + match timeout(remaining, reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => total = total.saturating_add(n), + Ok(Err(_)) | Err(_) => break, + } + } + + total +} + +#[tokio::test] +async fn integration_full_duplex_exact_budget_then_hard_cutoff() { + let stats = Arc::new(Stats::new()); + let user = "quota-full-duplex-boundary-user"; + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 1024, + 1024, + user, + Arc::clone(&stats), + Some(10), + Arc::new(BufferPool::new()), + )); + + client_peer.write_all(&[0x10, 0x11, 0x12, 0x13]).await.unwrap(); + let mut c2s = [0u8; 4]; + server_peer.read_exact(&mut c2s).await.unwrap(); + assert_eq!(c2s, [0x10, 0x11, 0x12, 0x13]); + + server_peer + .write_all(&[0x20, 0x21, 0x22, 0x23, 0x24, 0x25]) + .await + .unwrap(); + let mut s2c = [0u8; 6]; + client_peer.read_exact(&mut s2c).await.unwrap(); + assert_eq!(s2c, [0x20, 0x21, 0x22, 0x23, 0x24, 0x25]); + + let _ = client_peer.write_all(&[0x99]).await; + let _ = server_peer.write_all(&[0x88]).await; + + let mut probe_server = [0u8; 1]; + let mut probe_client = [0u8; 1]; + let leaked_to_server = timeout(Duration::from_millis(120), server_peer.read(&mut probe_server)).await; + let leaked_to_client = timeout(Duration::from_millis(120), client_peer.read(&mut probe_client)).await; + + assert!( + !matches!(leaked_to_server, Ok(Ok(n)) if n > 0), + "once quota is exhausted, no extra client byte must be forwarded" + ); + assert!( + !matches!(leaked_to_client, Ok(Ok(n)) if n > 0), + "once quota is exhausted, no extra server byte must be forwarded" + ); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate under quota cutoff") + .expect("relay task must not panic"); + + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-full-duplex-boundary-user" + )); + assert!(stats.get_user_total_octets(user) <= 10); +} + +#[tokio::test] +async fn negative_preloaded_quota_blocks_both_directions_immediately() { + let stats = Arc::new(Stats::new()); + let user = "quota-preloaded-cutoff-user"; + stats.add_user_octets_from(user, 5); + + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 512, + 512, + user, + Arc::clone(&stats), + Some(5), + Arc::new(BufferPool::new()), + )); + + let _ = tokio::join!( + client_peer.write_all(&[0x41, 0x42]), + server_peer.write_all(&[0x51, 0x52]), + ); + + let leaked_to_server = read_available(&mut server_peer, Duration::from_millis(120)).await; + let leaked_to_client = read_available(&mut client_peer, Duration::from_millis(120)).await; + + assert_eq!(leaked_to_server, 0, "preloaded limit must block C->S immediately"); + assert_eq!(leaked_to_client, 0, "preloaded limit must block S->C immediately"); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate under preloaded cutoff") + .expect("relay task must not panic"); + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!(stats.get_user_total_octets(user) <= 5); +} + +#[tokio::test] +async fn edge_quota_one_bidirectional_race_allows_at_most_one_forwarded_octet() { + let stats = Arc::new(Stats::new()); + let user = "quota-one-race-user"; + + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + let _ = tokio::join!(client_peer.write_all(&[0xAA]), server_peer.write_all(&[0xBB])); + + let mut to_server = [0u8; 1]; + let mut to_client = [0u8; 1]; + + let delivered_server = match timeout(Duration::from_millis(120), server_peer.read(&mut to_server)).await { + Ok(Ok(n)) => n, + _ => 0, + }; + let delivered_client = match timeout(Duration::from_millis(120), client_peer.read(&mut to_client)).await { + Ok(Ok(n)) => n, + _ => 0, + }; + + assert!( + delivered_server + delivered_client <= 1, + "quota=1 must not allow >1 forwarded byte across both directions" + ); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate under quota=1") + .expect("relay task must not panic"); + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!(stats.get_user_total_octets(user) <= 1); +} + +#[tokio::test] +async fn adversarial_blackhat_alternating_fragmented_jitter_never_overshoots_global_quota() { + let stats = Arc::new(Stats::new()); + let user = "quota-blackhat-jitter-user"; + let quota = 32u64; + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(quota), + Arc::new(BufferPool::new()), + )); + + let mut delivered_to_server = 0usize; + let mut delivered_to_client = 0usize; + + for i in 0..256usize { + if relay.is_finished() { + break; + } + + if (i & 1) == 0 { + let _ = client_peer.write_all(&[(i as u8) ^ 0x5A]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(4), server_peer.read(&mut one)).await { + delivered_to_server = delivered_to_server.saturating_add(n); + } + } else { + let _ = server_peer.write_all(&[(i as u8) ^ 0xA5]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(4), client_peer.read(&mut one)).await { + delivered_to_client = delivered_to_client.saturating_add(n); + } + } + + tokio::time::sleep(Duration::from_millis(((i % 3) + 1) as u64)).await; + } + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate under black-hat jitter attack") + .expect("relay task must not panic"); + + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!( + delivered_to_server + delivered_to_client <= quota as usize, + "combined forwarded bytes must never exceed configured quota" + ); + assert!(stats.get_user_total_octets(user) <= quota); +} + +#[tokio::test] +async fn light_fuzz_randomized_schedule_preserves_quota_and_forwarded_byte_invariants() { + let mut rng = StdRng::seed_from_u64(0xD15C_A11E_F00D_BAAD); + + for case in 0..48u64 { + let stats = Arc::new(Stats::new()); + let user = format!("quota-fuzz-schedule-{case}"); + let quota = rng.random_range(1u64..=32u64); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + Arc::clone(&relay_stats), + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut delivered_total = 0usize; + + for _ in 0..96usize { + if relay.is_finished() { + break; + } + + if rng.random::() { + let _ = client_peer.write_all(&[rng.random::()]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(3), server_peer.read(&mut one)).await { + delivered_total = delivered_total.saturating_add(n); + } + } else { + let _ = server_peer.write_all(&[rng.random::()]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(3), client_peer.read(&mut one)).await { + delivered_total = delivered_total.saturating_add(n); + } + } + } + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("fuzz relay must terminate") + .expect("fuzz relay task must not panic"); + + assert!( + relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), + "relay must either close cleanly or terminate via typed quota error" + ); + assert!( + delivered_total <= quota as usize, + "fuzz case {case}: forwarded bytes must not exceed quota" + ); + assert!( + stats.get_user_total_octets(&user) <= quota, + "fuzz case {case}: accounted bytes must not exceed quota" + ); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_multi_relay_same_user_mixed_direction_jitter_respects_global_quota() { + let stats = Arc::new(Stats::new()); + let user = "quota-stress-multi-relay-user"; + let quota = 64u64; + + let mut workers = Vec::new(); + + for worker_id in 0..4u8 { + let stats = Arc::clone(&stats); + let user = user.to_string(); + + workers.push(tokio::spawn(async move { + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + Arc::clone(&stats), + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut delivered = 0usize; + + for step in 0..96u8 { + if relay.is_finished() { + break; + } + + if ((step as usize + worker_id as usize) & 1) == 0 { + let _ = client_peer.write_all(&[step ^ 0x3C]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(3), server_peer.read(&mut one)).await { + delivered = delivered.saturating_add(n); + } + } else { + let _ = server_peer.write_all(&[step ^ 0xC3]).await; + let mut one = [0u8; 1]; + if let Ok(Ok(n)) = timeout(Duration::from_millis(3), client_peer.read(&mut one)).await { + delivered = delivered.saturating_add(n); + } + } + + tokio::time::sleep(Duration::from_millis((((worker_id as u64) + (step as u64)) % 3) + 1)).await; + } + + drop(client_peer); + drop(server_peer); + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("stress relay must terminate") + .expect("stress relay task must not panic"); + + assert!( + relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), + "stress relay must either close cleanly or terminate via typed quota error" + ); + delivered + })); + } + + let mut delivered_sum = 0usize; + for worker in workers { + delivered_sum = delivered_sum.saturating_add(worker.await.expect("stress worker must not panic")); + } + + assert!( + stats.get_user_total_octets(user) <= quota, + "global per-user quota must hold under concurrent mixed-direction relay stress" + ); + assert!( + delivered_sum <= quota as usize, + "combined delivered bytes across relays must stay within global quota" + ); +} diff --git a/src/proxy/tests/relay_quota_model_adversarial_tests.rs b/src/proxy/tests/relay_quota_model_adversarial_tests.rs new file mode 100644 index 0000000..0a06ba8 --- /dev/null +++ b/src/proxy/tests/relay_quota_model_adversarial_tests.rs @@ -0,0 +1,300 @@ +use super::relay_bidirectional; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::sync::Arc; +use tokio::io::{duplex, AsyncRead, AsyncReadExt, AsyncWriteExt}; +use tokio::sync::Barrier; +use tokio::time::{timeout, Duration}; + +fn assert_is_prefix(received: &[u8], sent: &[u8], direction: &str) { + assert!( + sent.starts_with(received), + "{direction} stream corruption: received={} sent={} (received must be prefix of sent)", + received.len(), + sent.len() + ); +} + +async fn drain_available(reader: &mut R, out: &mut Vec, rounds: usize) { + for _ in 0..rounds { + let mut buf = [0u8; 64]; + match timeout(Duration::from_millis(2), reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => out.extend_from_slice(&buf[..n]), + Ok(Err(_)) | Err(_) => break, + } + } +} + +#[tokio::test] +async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() { + let mut rng = StdRng::seed_from_u64(0xC0DE_CAFE_D15C_F00D); + + for case in 0..64u64 { + let stats = Arc::new(Stats::new()); + let user = format!("quota-model-fuzz-{case}"); + let quota = rng.random_range(1u64..=64u64); + + let (mut client_peer, relay_client) = duplex(8192); + let (relay_server, mut server_peer) = duplex(8192); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + &relay_user, + relay_stats, + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut sent_c2s = Vec::new(); + let mut sent_s2c = Vec::new(); + let mut recv_at_server = Vec::new(); + let mut recv_at_client = Vec::new(); + + for _ in 0..96usize { + if relay.is_finished() { + break; + } + + let do_c2s = rng.random::(); + let chunk_len = rng.random_range(1usize..=12usize); + let mut chunk = vec![0u8; chunk_len]; + for b in &mut chunk { + *b = rng.random::(); + } + + if do_c2s { + if client_peer.write_all(&chunk).await.is_ok() { + sent_c2s.extend_from_slice(&chunk); + } + } else if server_peer.write_all(&chunk).await.is_ok() { + sent_s2c.extend_from_slice(&chunk); + } + + drain_available(&mut server_peer, &mut recv_at_server, 2).await; + drain_available(&mut client_peer, &mut recv_at_client, 2).await; + + assert_is_prefix(&recv_at_server, &sent_c2s, "C->S"); + assert_is_prefix(&recv_at_client, &sent_s2c, "S->C"); + assert!( + recv_at_server.len() + recv_at_client.len() <= quota as usize, + "fuzz case {case}: delivered bytes exceed quota" + ); + assert!( + stats.get_user_total_octets(&user) <= quota, + "fuzz case {case}: accounted bytes exceed quota" + ); + } + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("fuzz relay must terminate") + .expect("fuzz relay task must not panic"); + + assert!( + relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), + "fuzz case {case}: relay must end cleanly or with typed quota error" + ); + + assert_is_prefix(&recv_at_server, &sent_c2s, "C->S final"); + assert_is_prefix(&recv_at_client, &sent_s2c, "S->C final"); + assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize); + assert!(stats.get_user_total_octets(&user) <= quota); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn adversarial_dual_direction_cutoff_race_allows_at_most_one_forwarded_byte() { + let stats = Arc::new(Stats::new()); + let user = "quota-dual-race-user"; + + let (mut client_peer, relay_client) = duplex(1024); + let (relay_server, mut server_peer) = duplex(1024); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 128, + 128, + user, + Arc::clone(&stats), + Some(1), + Arc::new(BufferPool::new()), + )); + + let gate = Arc::new(Barrier::new(3)); + + let writer_c2s = { + let gate = Arc::clone(&gate); + tokio::spawn(async move { + gate.wait().await; + let _ = client_peer.write_all(&[0xA1]).await; + client_peer + }) + }; + + let writer_s2c = { + let gate = Arc::clone(&gate); + tokio::spawn(async move { + gate.wait().await; + let _ = server_peer.write_all(&[0xB2]).await; + server_peer + }) + }; + + gate.wait().await; + + let mut client_peer = writer_c2s.await.expect("c2s writer must not panic"); + let mut server_peer = writer_s2c.await.expect("s2c writer must not panic"); + + let mut got_at_server = [0u8; 1]; + let mut got_at_client = [0u8; 1]; + + let n_server = match timeout(Duration::from_millis(120), server_peer.read(&mut got_at_server)).await { + Ok(Ok(n)) => n, + _ => 0, + }; + let n_client = match timeout(Duration::from_millis(120), client_peer.read(&mut got_at_client)).await { + Ok(Ok(n)) => n, + _ => 0, + }; + + assert!( + n_server + n_client <= 1, + "quota=1 race must not forward both concurrent direction bytes" + ); + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("quota race relay must terminate") + .expect("quota race relay task must not panic"); + + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!(stats.get_user_total_octets(user) <= 1); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_model_load() { + let stats = Arc::new(Stats::new()); + let user = "quota-model-stress-user"; + let quota = 96u64; + + let mut workers = Vec::new(); + for worker_id in 0..6u64 { + let stats = Arc::clone(&stats); + let user = user.to_string(); + + workers.push(tokio::spawn(async move { + let mut rng = StdRng::seed_from_u64(0x9E37_79B9_7F4A_7C15 ^ worker_id); + + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 192, + 192, + &relay_user, + relay_stats, + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + let mut sent_c2s = Vec::new(); + let mut sent_s2c = Vec::new(); + let mut recv_at_server = Vec::new(); + let mut recv_at_client = Vec::new(); + + for _ in 0..64usize { + if relay.is_finished() { + break; + } + + let choose_c2s = rng.random::(); + let len = rng.random_range(1usize..=10usize); + let mut payload = vec![0u8; len]; + for b in &mut payload { + *b = rng.random::(); + } + + if choose_c2s { + if client_peer.write_all(&payload).await.is_ok() { + sent_c2s.extend_from_slice(&payload); + } + } else if server_peer.write_all(&payload).await.is_ok() { + sent_s2c.extend_from_slice(&payload); + } + + drain_available(&mut server_peer, &mut recv_at_server, 2).await; + drain_available(&mut client_peer, &mut recv_at_client, 2).await; + + assert_is_prefix(&recv_at_server, &sent_c2s, "stress C->S"); + assert_is_prefix(&recv_at_client, &sent_s2c, "stress S->C"); + } + + drop(client_peer); + drop(server_peer); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("stress relay must terminate") + .expect("stress relay task must not panic"); + + assert!( + relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), + "stress relay must end cleanly or with typed quota error" + ); + + recv_at_server.len() + recv_at_client.len() + })); + } + + let mut delivered_sum = 0usize; + for worker in workers { + delivered_sum = delivered_sum.saturating_add(worker.await.expect("worker must not panic")); + } + + assert!( + stats.get_user_total_octets(user) <= quota, + "global per-user quota must never overshoot under concurrent multi-relay model load" + ); + assert!( + delivered_sum <= quota as usize, + "aggregate delivered bytes across relays must remain within global quota" + ); +} diff --git a/src/proxy/tests/relay_quota_overflow_regression_tests.rs b/src/proxy/tests/relay_quota_overflow_regression_tests.rs new file mode 100644 index 0000000..207d603 --- /dev/null +++ b/src/proxy/tests/relay_quota_overflow_regression_tests.rs @@ -0,0 +1,194 @@ +use super::relay_bidirectional; +use crate::error::ProxyError; +use crate::stats::Stats; +use crate::stream::BufferPool; +use std::sync::Arc; +use tokio::io::{duplex, AsyncRead, AsyncReadExt, AsyncWriteExt}; +use tokio::time::{timeout, Duration}; + +async fn read_available(reader: &mut R, budget_ms: u64) -> usize { + let mut total = 0usize; + loop { + let mut buf = [0u8; 64]; + match timeout(Duration::from_millis(budget_ms), reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => total = total.saturating_add(n), + Ok(Err(_)) | Err(_) => break, + } + } + total +} + +#[tokio::test] +async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_accounting() { + let stats = Arc::new(Stats::new()); + let user = "quota-overflow-regression-client-chunk"; + + // Leave only 1 byte remaining under quota. + stats.add_user_octets_from(user, 9); + + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 512, + 512, + user, + Arc::clone(&stats), + Some(10), + Arc::new(BufferPool::new()), + )); + + // Single chunk attempts to cross remaining budget (4 > 1). + client_peer.write_all(&[0x11, 0x22, 0x33, 0x44]).await.unwrap(); + client_peer.shutdown().await.unwrap(); + + let forwarded = read_available(&mut server_peer, 60).await; + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate after quota overflow attempt") + .expect("relay task must not panic"); + + assert_eq!( + forwarded, 0, + "overflowing C->S chunk must not be forwarded when it exceeds remaining quota" + ); + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!( + stats.get_user_total_octets(user) <= 10, + "accounted bytes must never exceed quota after overflowing chunk" + ); +} + +#[tokio::test] +async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_off() { + let stats = Arc::new(Stats::new()); + let user = "quota-overflow-regression-boundary"; + + // Leave exactly 4 bytes remaining. + stats.add_user_octets_from(user, 6); + + let (mut client_peer, relay_client) = duplex(2048); + let (relay_server, mut server_peer) = duplex(2048); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay = tokio::spawn(relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 256, + 256, + user, + Arc::clone(&stats), + Some(10), + Arc::new(BufferPool::new()), + )); + + // Exact boundary write should pass once. + client_peer.write_all(&[0xAA, 0xBB, 0xCC, 0xDD]).await.unwrap(); + + let mut exact = [0u8; 4]; + timeout(Duration::from_secs(1), server_peer.read_exact(&mut exact)) + .await + .unwrap() + .unwrap(); + assert_eq!(exact, [0xAA, 0xBB, 0xCC, 0xDD]); + + // Any extra byte after boundary should be rejected/cut off. + let _ = client_peer.write_all(&[0xEE]).await; + client_peer.shutdown().await.unwrap(); + + let leaked_after = read_available(&mut server_peer, 60).await; + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("relay must terminate at quota boundary") + .expect("relay task must not panic"); + + assert_eq!( + leaked_after, 0, + "no bytes may pass after exact boundary is consumed" + ); + assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + assert!(stats.get_user_total_octets(user) <= 10); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { + let stats = Arc::new(Stats::new()); + let user = "quota-overflow-regression-stress"; + let quota = 12u64; + + let mut handles = Vec::new(); + for _ in 0..4usize { + let stats = Arc::clone(&stats); + let user = user.to_string(); + + handles.push(tokio::spawn(async move { + let (mut client_peer, relay_client) = duplex(4096); + let (relay_server, mut server_peer) = duplex(4096); + let (client_reader, client_writer) = tokio::io::split(relay_client); + let (server_reader, server_writer) = tokio::io::split(relay_server); + + let relay_user = user.clone(); + let relay_stats = Arc::clone(&stats); + let relay = tokio::spawn(async move { + relay_bidirectional( + client_reader, + client_writer, + server_reader, + server_writer, + 192, + 192, + &relay_user, + relay_stats, + Some(quota), + Arc::new(BufferPool::new()), + ) + .await + }); + + // Aggressive sender tries to overflow shared user quota. + let burst = vec![0x5Au8; 64]; + let _ = client_peer.write_all(&burst).await; + let _ = client_peer.shutdown().await; + + let mut forwarded = 0usize; + forwarded = forwarded.saturating_add(read_available(&mut server_peer, 40).await); + + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .expect("stress relay must terminate") + .expect("stress relay task must not panic"); + + assert!( + relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })), + "stress relay must finish cleanly or with typed quota error" + ); + forwarded + })); + } + + let mut forwarded_sum = 0usize; + for handle in handles { + forwarded_sum = forwarded_sum.saturating_add(handle.await.expect("worker must not panic")); + } + + assert!( + forwarded_sum <= quota as usize, + "aggregate forwarded bytes across relays must stay within global user quota" + ); + assert!( + stats.get_user_total_octets(user) <= quota, + "global accounted bytes must stay within quota under overflow stress" + ); +} diff --git a/src/proxy/tests/route_mode_coherence_adversarial_tests.rs b/src/proxy/tests/route_mode_coherence_adversarial_tests.rs new file mode 100644 index 0000000..e1e8e0a --- /dev/null +++ b/src/proxy/tests/route_mode_coherence_adversarial_tests.rs @@ -0,0 +1,228 @@ +use super::*; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::sync::Arc; + +#[test] +fn positive_direct_cutover_sets_timestamp_and_snapshot_coherently() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + let rx = runtime.subscribe(); + + assert!( + runtime.direct_since_epoch_secs().is_none(), + "middle startup must not expose direct-since timestamp" + ); + + let emitted = runtime + .set_mode(RelayRouteMode::Direct) + .expect("middle->direct must emit cutover"); + let observed = *rx.borrow(); + + assert_eq!(observed, emitted, "watch snapshot must match emitted cutover"); + assert_eq!(observed.mode, RelayRouteMode::Direct); + assert!( + runtime.direct_since_epoch_secs().is_some(), + "direct cutover must publish a non-empty direct-since timestamp" + ); +} + +#[test] +fn negative_idempotent_set_mode_does_not_mutate_timestamp_or_generation() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + + let before_state = runtime.snapshot(); + let before_ts = runtime.direct_since_epoch_secs(); + + let changed = runtime.set_mode(RelayRouteMode::Direct); + + let after_state = runtime.snapshot(); + let after_ts = runtime.direct_since_epoch_secs(); + + assert!(changed.is_none(), "idempotent set_mode must return None"); + assert_eq!( + after_state.generation, before_state.generation, + "idempotent set_mode must not advance generation" + ); + assert_eq!( + after_ts, before_ts, + "idempotent set_mode must not alter direct-since timestamp" + ); +} + +#[test] +fn edge_middle_cutover_clears_timestamp() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Direct); + let rx = runtime.subscribe(); + + assert!( + runtime.direct_since_epoch_secs().is_some(), + "direct startup must expose direct-since timestamp" + ); + + let emitted = runtime + .set_mode(RelayRouteMode::Middle) + .expect("direct->middle must emit cutover"); + let observed = *rx.borrow(); + + assert_eq!(observed, emitted, "watch snapshot must match emitted cutover"); + assert_eq!(observed.mode, RelayRouteMode::Middle); + assert!( + runtime.direct_since_epoch_secs().is_none(), + "middle cutover must clear direct-since timestamp" + ); +} + +#[test] +fn adversarial_blackhat_probe_sequence_observes_consistent_mode_timestamp_pairs() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + let rx = runtime.subscribe(); + + for _ in 0..2048usize { + let emitted_direct = runtime + .set_mode(RelayRouteMode::Direct) + .expect("middle->direct must emit"); + let observed_direct = *rx.borrow(); + assert_eq!(observed_direct, emitted_direct); + assert!( + runtime.direct_since_epoch_secs().is_some(), + "direct observation must never expose empty timestamp" + ); + + let emitted_middle = runtime + .set_mode(RelayRouteMode::Middle) + .expect("direct->middle must emit"); + let observed_middle = *rx.borrow(); + assert_eq!(observed_middle, emitted_middle); + assert!( + runtime.direct_since_epoch_secs().is_none(), + "middle observation must never expose direct timestamp" + ); + } +} + +#[test] +fn integration_subscriber_and_runtime_gates_stay_coherent_across_cutovers() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + let rx = runtime.subscribe(); + + let plan = [ + RelayRouteMode::Direct, + RelayRouteMode::Middle, + RelayRouteMode::Direct, + RelayRouteMode::Middle, + RelayRouteMode::Direct, + ]; + + let mut expected_generation = 0u64; + + for mode in plan { + let emitted = runtime + .set_mode(mode) + .expect("each planned transition toggles mode and must emit"); + expected_generation = expected_generation.saturating_add(1); + + let watched = *rx.borrow(); + let snapshot = runtime.snapshot(); + + assert_eq!(emitted.mode, mode); + assert_eq!(emitted.generation, expected_generation); + assert_eq!(watched, emitted); + assert_eq!(snapshot, emitted); + + if matches!(mode, RelayRouteMode::Direct) { + assert!(runtime.direct_since_epoch_secs().is_some()); + } else { + assert!(runtime.direct_since_epoch_secs().is_none()); + } + } +} + +#[test] +fn light_fuzz_random_mode_plan_preserves_timestamp_and_generation_invariants() { + let runtime = RouteRuntimeController::new(RelayRouteMode::Middle); + let mut rng = StdRng::seed_from_u64(0x5EED_CAFE_D15C_A11E); + + let mut expected_mode = RelayRouteMode::Middle; + let mut expected_generation = 0u64; + + for _ in 0..25_000usize { + let candidate = if rng.random::() { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + + let changed = runtime.set_mode(candidate); + if candidate == expected_mode { + assert!(changed.is_none(), "idempotent fuzz step must not emit"); + continue; + } + + expected_mode = candidate; + expected_generation = expected_generation.saturating_add(1); + + let emitted = changed.expect("non-idempotent fuzz step must emit"); + assert_eq!(emitted.mode, expected_mode); + assert_eq!(emitted.generation, expected_generation); + + let snapshot = runtime.snapshot(); + assert_eq!(snapshot, emitted, "snapshot must match emitted cutover"); + + if matches!(snapshot.mode, RelayRouteMode::Direct) { + assert!( + runtime.direct_since_epoch_secs().is_some(), + "direct fuzz state must expose timestamp" + ); + } else { + assert!( + runtime.direct_since_epoch_secs().is_none(), + "middle fuzz state must clear timestamp" + ); + } + } +} + +#[test] +fn stress_parallel_subscribers_never_observe_generation_regression() { + let runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); + + let mut readers = Vec::new(); + for _ in 0..4usize { + let runtime = Arc::clone(&runtime); + readers.push(std::thread::spawn(move || { + let rx = runtime.subscribe(); + let mut last = rx.borrow().generation; + for _ in 0..10_000usize { + let current = rx.borrow().generation; + assert!( + current >= last, + "watch generation must be monotonic for every subscriber" + ); + last = current; + std::thread::yield_now(); + } + })); + } + + for step in 0..20_000usize { + let mode = if (step & 1) == 0 { + RelayRouteMode::Direct + } else { + RelayRouteMode::Middle + }; + let _ = runtime.set_mode(mode); + } + + for reader in readers { + reader + .join() + .expect("parallel subscriber reader must not panic"); + } + + let final_state = runtime.snapshot(); + if matches!(final_state.mode, RelayRouteMode::Direct) { + assert!(runtime.direct_since_epoch_secs().is_some()); + } else { + assert!(runtime.direct_since_epoch_secs().is_none()); + } +}