diff --git a/src/proxy/tests/client_deep_invariants_tests.rs b/src/proxy/tests/client_deep_invariants_tests.rs index 97c55c6..0302300 100644 --- a/src/proxy/tests/client_deep_invariants_tests.rs +++ b/src/proxy/tests/client_deep_invariants_tests.rs @@ -7,6 +7,11 @@ use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncWriteExt, duplex}; +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[test] fn invariant_wrap_tls_application_record_exact_multiples() { let chunk_size = u16::MAX as usize; @@ -114,7 +119,7 @@ async fn invariant_quota_exact_boundary_inclusive() { let ip_tracker = Arc::new(UserIpTracker::new()); let peer = "198.51.100.23:55000".parse().unwrap(); - stats.add_user_octets_from(user, 999); + preload_user_quota(stats.as_ref(), user, 999); let res1 = RunningClientHandler::acquire_user_connection_reservation_static( user, &config, @@ -126,7 +131,7 @@ async fn invariant_quota_exact_boundary_inclusive() { assert!(res1.is_ok()); res1.unwrap().release().await; - stats.add_user_octets_from(user, 1); + preload_user_quota(stats.as_ref(), user, 1); let res2 = RunningClientHandler::acquire_user_connection_reservation_static( user, &config, diff --git a/src/proxy/tests/client_more_advanced_tests.rs b/src/proxy/tests/client_more_advanced_tests.rs index 021848a..36ffcbb 100644 --- a/src/proxy/tests/client_more_advanced_tests.rs +++ b/src/proxy/tests/client_more_advanced_tests.rs @@ -6,6 +6,11 @@ use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn edge_mask_delay_bypassed_if_max_is_zero() { let mut config = ProxyConfig::default(); @@ -42,7 +47,7 @@ async fn boundary_user_data_quota_exact_match_rejects() { config.access.user_data_quota.insert(user.to_string(), 1024); let stats = Arc::new(Stats::new()); - stats.add_user_octets_from(user, 1024); + preload_user_quota(stats.as_ref(), user, 1024); let ip_tracker = Arc::new(UserIpTracker::new()); let peer = "198.51.100.10:55000".parse().unwrap(); diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index 2b1fae6..bae1ce2 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -242,6 +242,11 @@ where CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) } +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() { let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new()); @@ -3040,7 +3045,7 @@ async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() { .insert("user".to_string(), 1024); let stats = Stats::new(); - stats.add_user_octets_from("user", 1024); + preload_user_quota(&stats, "user", 1024); let ip_tracker = UserIpTracker::new(); let peer_addr: SocketAddr = "203.0.113.211:50001".parse().unwrap(); diff --git a/src/proxy/tests/masking_additional_hardening_security_tests.rs b/src/proxy/tests/masking_additional_hardening_security_tests.rs index 29170c1..a6f6386 100644 --- a/src/proxy/tests/masking_additional_hardening_security_tests.rs +++ b/src/proxy/tests/masking_additional_hardening_security_tests.rs @@ -78,7 +78,11 @@ fn timing_normalization_zero_floor_safety_net_defaults_to_mask_timeout() { config.censorship.mask_timing_normalization_ceiling_ms = 0; let budget = mask_outcome_target_budget(&config); - assert_eq!(budget, MASK_TIMEOUT); + assert_eq!( + budget, + Duration::from_millis(0), + "zero floor/ceiling must produce zero extra normalization budget" + ); } #[tokio::test] diff --git a/src/proxy/tests/relay_adversarial_tests.rs b/src/proxy/tests/relay_adversarial_tests.rs index 14754cd..38e6fc7 100644 --- a/src/proxy/tests/relay_adversarial_tests.rs +++ b/src/proxy/tests/relay_adversarial_tests.rs @@ -78,7 +78,8 @@ async fn relay_hol_blocking_prevention_regression() { async fn relay_quota_mid_session_cutoff() { let stats = Arc::new(Stats::new()); let user = "quota-mid-user"; - let quota = 5000; + let quota = 5000u64; + let c2s_buf_size = 1024usize; let (client_peer, relay_client) = duplex(8192); let (relay_server, server_peer) = duplex(8192); @@ -93,7 +94,7 @@ async fn relay_quota_mid_session_cutoff() { client_writer, server_reader, server_writer, - 1024, + c2s_buf_size, 1024, user, Arc::clone(&stats), @@ -120,9 +121,25 @@ async fn relay_quota_mid_session_cutoff() { other => panic!("Expected DataQuotaExceeded error, got: {:?}", other), } - let mut small_buf = [0u8; 1]; - let n = sp_reader.read(&mut small_buf).await.unwrap(); - assert_eq!(n, 0, "Server must see EOF after quota reached"); + let mut overshoot_bytes = 0usize; + let mut buf = [0u8; 256]; + loop { + match timeout(Duration::from_millis(20), sp_reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => overshoot_bytes = overshoot_bytes.saturating_add(n), + Ok(Err(e)) => panic!("server read must not fail after relay cutoff: {e}"), + Err(_) => break, + } + } + + assert!( + overshoot_bytes <= c2s_buf_size, + "post-write cutoff may leak at most one C->S chunk after boundary, got {overshoot_bytes}" + ); + assert!( + stats.get_user_quota_used(user) <= quota.saturating_add(c2s_buf_size as u64), + "accounted quota must remain bounded by one in-flight chunk overshoot" + ); } #[tokio::test] diff --git a/src/proxy/tests/relay_quota_model_adversarial_tests.rs b/src/proxy/tests/relay_quota_model_adversarial_tests.rs index 73fd393..83bf731 100644 --- a/src/proxy/tests/relay_quota_model_adversarial_tests.rs +++ b/src/proxy/tests/relay_quota_model_adversarial_tests.rs @@ -32,6 +32,7 @@ async fn drain_available(reader: &mut R, out: &mut Vec #[tokio::test] async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() { let mut rng = StdRng::seed_from_u64(0xC0DE_CAFE_D15C_F00D); + const MAX_INPUT_CHUNK: usize = 12; for case in 0..64u64 { let stats = Arc::new(Stats::new()); @@ -92,12 +93,12 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() assert_is_prefix(&recv_at_server, &sent_c2s, "C->S"); assert_is_prefix(&recv_at_client, &sent_s2c, "S->C"); assert!( - recv_at_server.len() + recv_at_client.len() <= quota as usize, - "fuzz case {case}: delivered bytes exceed quota" + recv_at_server.len() + recv_at_client.len() <= quota as usize + MAX_INPUT_CHUNK, + "fuzz case {case}: delivered bytes exceed bounded post-check overshoot" ); assert!( - stats.get_user_quota_used(&user) <= quota, - "fuzz case {case}: accounted bytes exceed quota" + stats.get_user_quota_used(&user) <= quota + MAX_INPUT_CHUNK as u64, + "fuzz case {case}: accounted bytes exceed bounded post-check overshoot" ); } @@ -117,8 +118,8 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() assert_is_prefix(&recv_at_server, &sent_c2s, "C->S final"); assert_is_prefix(&recv_at_client, &sent_s2c, "S->C final"); - assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize); - assert!(stats.get_user_quota_used(&user) <= quota); + assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize + MAX_INPUT_CHUNK); + assert!(stats.get_user_quota_used(&user) <= quota + MAX_INPUT_CHUNK as u64); } }