diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 2a84353..d833019 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -1977,3 +1977,7 @@ mod middle_relay_tiny_frame_debt_concurrency_security_tests; #[cfg(test)] #[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"] mod middle_relay_tiny_frame_debt_proto_chunking_security_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_atomic_quota_invariant_tests.rs"] +mod middle_relay_atomic_quota_invariant_tests; diff --git a/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs b/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs new file mode 100644 index 0000000..7c176bc --- /dev/null +++ b/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs @@ -0,0 +1,189 @@ +use super::*; +use crate::crypto::AesCtr; +use bytes::Bytes; +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::AsyncWrite; + +struct CountedWriter { + write_calls: Arc, + fail_writes: bool, +} + +impl CountedWriter { + fn new(write_calls: Arc, fail_writes: bool) -> Self { + Self { + write_calls, + fail_writes, + } + } +} + +impl AsyncWrite for CountedWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + this.write_calls.fetch_add(1, Ordering::Relaxed); + if this.fail_writes { + Poll::Ready(Err(io::Error::new( + io::ErrorKind::BrokenPipe, + "forced write failure", + ))) + } else { + Poll::Ready(Ok(buf.len())) + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +fn make_crypto_writer(inner: CountedWriter) -> CryptoWriter { + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(inner, AesCtr::new(&key, iv), 8 * 1024) +} + +#[tokio::test] +async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() { + let stats = Stats::new(); + let user = "middle-me-writer-no-rollback-user"; + let user_stats = stats.get_or_create_user_stats_handle(user); + let write_calls = Arc::new(AtomicUsize::new(0)); + let mut writer = make_crypto_writer(CountedWriter::new(write_calls.clone(), true)); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + let payload = Bytes::from_static(&[0x11, 0x22, 0x33, 0x44, 0x55]); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: payload.clone(), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(user_stats.as_ref()), + Some(64), + 0, + &bytes_me2c, + 11, + true, + false, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Io(_))), + "write failure must propagate as I/O error" + ); + assert!( + write_calls.load(Ordering::Relaxed) > 0, + "writer must be attempted after successful quota reservation" + ); + assert_eq!( + stats.get_user_quota_used(user), + payload.len() as u64, + "reserved quota must not roll back on write failure" + ); + assert_eq!( + stats.get_quota_write_fail_bytes_total(), + payload.len() as u64, + "write-fail byte metric must include failed payload size" + ); + assert_eq!( + stats.get_quota_write_fail_events_total(), + 1, + "write-fail events metric must increment once" + ); + assert_eq!( + stats.get_user_total_octets(user), + 0, + "telemetry octets_to should not advance when write fails" + ); + assert_eq!( + bytes_me2c.load(Ordering::Relaxed), + 0, + "ME->C committed byte counter must not advance on write failure" + ); +} + +#[tokio::test] +async fn me_writer_pre_write_quota_reject_happens_before_writer_poll() { + let stats = Stats::new(); + let user = "middle-me-writer-precheck-user"; + let limit = 8u64; + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), limit); + + let write_calls = Arc::new(AtomicUsize::new(0)); + let mut writer = make_crypto_writer(CountedWriter::new(write_calls.clone(), false)); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAA, 0xBB, 0xCC]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(user_stats.as_ref()), + Some(limit), + 0, + &bytes_me2c, + 12, + true, + false, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::DataQuotaExceeded { .. })), + "pre-write quota rejection must return typed quota error" + ); + assert_eq!( + write_calls.load(Ordering::Relaxed), + 0, + "writer must not be polled when pre-write quota reservation fails" + ); + assert_eq!( + stats.get_me_d2c_quota_reject_pre_write_total(), + 1, + "pre-write quota reject metric must increment" + ); + assert_eq!( + stats.get_user_quota_used(user), + limit, + "failed pre-write reservation must keep previous quota usage unchanged" + ); + assert_eq!( + stats.get_quota_write_fail_bytes_total(), + 0, + "write-fail bytes metric must stay unchanged on pre-write reject" + ); + assert_eq!( + stats.get_quota_write_fail_events_total(), + 0, + "write-fail events metric must stay unchanged on pre-write reject" + ); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} diff --git a/src/proxy/tests/relay_atomic_quota_invariant_tests.rs b/src/proxy/tests/relay_atomic_quota_invariant_tests.rs new file mode 100644 index 0000000..1bb00a6 --- /dev/null +++ b/src/proxy/tests/relay_atomic_quota_invariant_tests.rs @@ -0,0 +1,243 @@ +use super::*; +use std::collections::VecDeque; +use std::io; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; +use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio::time::Instant; + +struct ScriptedWriter { + scripted_writes: Arc>>, + write_calls: Arc, +} + +impl ScriptedWriter { + fn new(script: &[usize], write_calls: Arc) -> Self { + Self { + scripted_writes: Arc::new(Mutex::new(script.iter().copied().collect())), + write_calls, + } + } +} + +impl AsyncWrite for ScriptedWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + this.write_calls.fetch_add(1, Ordering::Relaxed); + let planned = this + .scripted_writes + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) + .pop_front() + .unwrap_or(buf.len()); + Poll::Ready(Ok(planned.min(buf.len()))) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +fn make_stats_io_with_script( + user: &str, + quota_limit: u64, + precharged_quota: u64, + script: &[usize], +) -> ( + StatsIo, + Arc, + Arc, + Arc, +) { + let stats = Arc::new(Stats::new()); + if precharged_quota > 0 { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), precharged_quota); + } + + let write_calls = Arc::new(AtomicUsize::new(0)); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let io = StatsIo::new( + ScriptedWriter::new(script, write_calls.clone()), + Arc::new(SharedCounters::new()), + stats.clone(), + user.to_string(), + Some(quota_limit), + quota_exceeded.clone(), + Instant::now(), + ); + + (io, stats, write_calls, quota_exceeded) +} + +#[tokio::test] +async fn direct_partial_write_charges_only_committed_bytes_without_double_charge() { + let user = "direct-partial-charge-user"; + let (mut io, stats, write_calls, quota_exceeded) = + make_stats_io_with_script(user, 1_048_576, 0, &[8 * 1024, 8 * 1024, 48 * 1024]); + let payload = vec![0xAB; 64 * 1024]; + + let n1 = io + .write(&payload) + .await + .expect("first partial write must succeed"); + let n2 = io + .write(&payload) + .await + .expect("second partial write must succeed"); + let n3 = io.write(&payload).await.expect("tail write must succeed"); + + assert_eq!(n1, 8 * 1024); + assert_eq!(n2, 8 * 1024); + assert_eq!(n3, 48 * 1024); + assert_eq!(write_calls.load(Ordering::Relaxed), 3); + assert_eq!( + stats.get_user_quota_used(user), + (n1 + n2 + n3) as u64, + "quota accounting must follow committed bytes only" + ); + assert_eq!( + stats.get_user_total_octets(user), + (n1 + n2 + n3) as u64, + "telemetry octets should match committed bytes on successful writes" + ); + assert!( + !quota_exceeded.load(Ordering::Acquire), + "quota flag should stay false under large remaining budget" + ); +} + +#[tokio::test] +async fn direct_hybrid_branch_selection_matches_contract() { + let near_limit = 256 * 1024u64; + let near_remaining = 32 * 1024u64; + let (mut near_io, _stats, _calls, _flag) = make_stats_io_with_script( + "direct-near-limit-hard-check-user", + near_limit, + near_limit - near_remaining, + &[4 * 1024], + ); + let near_payload = vec![0x11; 4 * 1024]; + let near_written = near_io + .write(&near_payload) + .await + .expect("near-limit write must succeed"); + assert_eq!(near_written, 4 * 1024); + assert_eq!( + near_io.quota_bytes_since_check, 0, + "near-limit branch must go through immediate hard check" + ); + + let (mut far_small_io, _stats, _calls, _flag) = + make_stats_io_with_script("direct-far-small-amortized-user", 1_048_576, 0, &[4 * 1024]); + let far_small_payload = vec![0x22; 4 * 1024]; + let far_small_written = far_small_io + .write(&far_small_payload) + .await + .expect("small far-from-limit write must succeed"); + assert_eq!(far_small_written, 4 * 1024); + assert_eq!( + far_small_io.quota_bytes_since_check, + 4 * 1024, + "small far-from-limit write must go through amortized path" + ); + + let (mut far_large_io, _stats, _calls, _flag) = make_stats_io_with_script( + "direct-far-large-hard-check-user", + 1_048_576, + 0, + &[32 * 1024], + ); + let far_large_payload = vec![0x33; 32 * 1024]; + let far_large_written = far_large_io + .write(&far_large_payload) + .await + .expect("large write must succeed"); + assert_eq!(far_large_written, 32 * 1024); + assert_eq!( + far_large_io.quota_bytes_since_check, 0, + "large write must force immediate hard check even far from limit" + ); +} + +#[tokio::test] +async fn remaining_before_zero_rejects_without_calling_inner_writer() { + let user = "direct-zero-remaining-user"; + let limit = 8u64; + let (mut io, stats, write_calls, quota_exceeded) = + make_stats_io_with_script(user, limit, limit, &[1]); + + let err = io + .write(&[0x44]) + .await + .expect_err("write must fail when remaining quota is zero"); + + assert!( + is_quota_io_error(&err), + "zero-remaining gate must return typed quota I/O error" + ); + assert_eq!( + write_calls.load(Ordering::Relaxed), + 0, + "inner poll_write must not be called when remaining quota is zero" + ); + assert!( + quota_exceeded.load(Ordering::Acquire), + "zero-remaining gate must set exceeded flag" + ); + assert_eq!(stats.get_user_quota_used(user), limit); +} + +#[tokio::test] +async fn exceeded_flag_blocks_following_poll_before_inner_write() { + let user = "direct-exceeded-visibility-user"; + let (mut io, stats, write_calls, quota_exceeded) = + make_stats_io_with_script(user, 1, 0, &[1, 1]); + + let first = io + .write(&[0x55]) + .await + .expect("first byte should consume remaining quota"); + assert_eq!(first, 1); + assert!( + quota_exceeded.load(Ordering::Acquire), + "hard check should store quota_exceeded after boundary hit" + ); + + let second = io + .write(&[0x66]) + .await + .expect_err("next write must be rejected by early exceeded gate"); + assert!( + is_quota_io_error(&second), + "following write must fail with typed quota error" + ); + assert_eq!( + write_calls.load(Ordering::Relaxed), + 1, + "second write must be cut before touching inner writer" + ); + assert_eq!(stats.get_user_quota_used(user), 1); +} + +#[test] +fn adaptive_interval_clamp_matches_contract() { + assert_eq!(quota_adaptive_interval_bytes(0), 4 * 1024); + assert_eq!(quota_adaptive_interval_bytes(2 * 1024), 4 * 1024); + assert_eq!(quota_adaptive_interval_bytes(32 * 1024), 16 * 1024); + assert_eq!(quota_adaptive_interval_bytes(256 * 1024), 64 * 1024); + + assert!(should_immediate_quota_check(32 * 1024, 4 * 1024)); + assert!(should_immediate_quota_check(1_048_576, 32 * 1024)); + assert!(!should_immediate_quota_check(1_048_576, 4 * 1024)); +} diff --git a/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs index 5ee6522..e80690b 100644 --- a/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs +++ b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs @@ -29,6 +29,11 @@ async fn read_available(reader: &mut R, budget: total } +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn positive_quota_path_forwards_both_directions_within_limit() { let stats = Arc::new(Stats::new()); @@ -63,14 +68,14 @@ async fn positive_quota_path_forwards_both_directions_within_limit() { let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); assert!(relay_result.is_ok()); - assert!(stats.get_user_total_octets(user) <= 16); + assert!(stats.get_user_quota_used(user) <= 16); } #[tokio::test] async fn negative_preloaded_quota_forbids_any_forwarding() { let stats = Arc::new(Stats::new()); let user = "quota-extended-negative-user"; - stats.add_user_octets_from(user, 8); + preload_user_quota(stats.as_ref(), user, 8); let (mut client_peer, relay_client) = duplex(1024); let (relay_server, mut server_peer) = duplex(1024); @@ -98,7 +103,7 @@ async fn negative_preloaded_quota_forbids_any_forwarding() { let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); - assert!(stats.get_user_total_octets(user) <= 8); + assert!(stats.get_user_quota_used(user) <= 8); } #[tokio::test] @@ -189,7 +194,7 @@ async fn adversarial_blackhat_alternating_jitter_does_not_overshoot_quota() { let relay_result = timeout(Duration::from_secs(3), relay).await.unwrap().unwrap(); assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); assert!(total_forwarded <= quota as usize); - assert!(stats.get_user_total_octets(user) <= quota); + assert!(stats.get_user_quota_used(user) <= quota); } #[tokio::test] @@ -252,7 +257,7 @@ async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() { let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); assert!(total_forwarded <= quota as usize); - assert!(stats.get_user_total_octets(&user) <= quota); + assert!(stats.get_user_quota_used(&user) <= quota); } } @@ -327,6 +332,6 @@ async fn stress_parallel_relays_for_one_user_obey_global_quota() { delivered += task.await.unwrap(); } - assert!(stats.get_user_total_octets(&user) <= quota); + assert!(stats.get_user_quota_used(&user) <= quota); assert!(delivered <= quota as usize); } diff --git a/src/proxy/tests/relay_quota_model_adversarial_tests.rs b/src/proxy/tests/relay_quota_model_adversarial_tests.rs index 5714f48..73fd393 100644 --- a/src/proxy/tests/relay_quota_model_adversarial_tests.rs +++ b/src/proxy/tests/relay_quota_model_adversarial_tests.rs @@ -96,7 +96,7 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() "fuzz case {case}: delivered bytes exceed quota" ); assert!( - stats.get_user_total_octets(&user) <= quota, + stats.get_user_quota_used(&user) <= quota, "fuzz case {case}: accounted bytes exceed quota" ); } @@ -118,7 +118,7 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() assert_is_prefix(&recv_at_server, &sent_c2s, "C->S final"); assert_is_prefix(&recv_at_client, &sent_s2c, "S->C final"); assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize); - assert!(stats.get_user_total_octets(&user) <= quota); + assert!(stats.get_user_quota_used(&user) <= quota); } } @@ -209,7 +209,7 @@ async fn adversarial_dual_direction_cutoff_race_allows_at_most_one_forwarded_byt relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); - assert!(stats.get_user_total_octets(user) <= 1); + assert!(stats.get_user_quota_used(user) <= 1); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -305,7 +305,7 @@ async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_mode } assert!( - stats.get_user_total_octets(user) <= quota, + stats.get_user_quota_used(user) <= quota, "global per-user quota must never overshoot under concurrent multi-relay model load" ); assert!( diff --git a/src/proxy/tests/relay_quota_overflow_regression_tests.rs b/src/proxy/tests/relay_quota_overflow_regression_tests.rs index dfbab85..a59954e 100644 --- a/src/proxy/tests/relay_quota_overflow_regression_tests.rs +++ b/src/proxy/tests/relay_quota_overflow_regression_tests.rs @@ -19,13 +19,18 @@ async fn read_available(reader: &mut R, budget_ms: u64) -> total } +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_accounting() { let stats = Arc::new(Stats::new()); let user = "quota-overflow-regression-client-chunk"; // Leave only 1 byte remaining under quota. - stats.add_user_octets_from(user, 9); + preload_user_quota(stats.as_ref(), user, 9); let (mut client_peer, relay_client) = duplex(2048); let (relay_server, mut server_peer) = duplex(2048); @@ -68,7 +73,7 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_ Err(ProxyError::DataQuotaExceeded { .. }) )); assert!( - stats.get_user_total_octets(user) <= 10, + stats.get_user_quota_used(user) <= 10, "accounted bytes must never exceed quota after overflowing chunk" ); } @@ -79,7 +84,7 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of let user = "quota-overflow-regression-boundary"; // Leave exactly 4 bytes remaining. - stats.add_user_octets_from(user, 6); + preload_user_quota(stats.as_ref(), user, 6); let (mut client_peer, relay_client) = duplex(2048); let (relay_server, mut server_peer) = duplex(2048); @@ -131,7 +136,7 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); - assert!(stats.get_user_total_octets(user) <= 10); + assert!(stats.get_user_quota_used(user) <= 10); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -201,7 +206,7 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { "aggregate forwarded bytes across relays must stay within global user quota" ); assert!( - stats.get_user_total_octets(user) <= quota, + stats.get_user_quota_used(user) <= quota, "global accounted bytes must stay within quota under overflow stress" ); }