diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 7e73eb8..fa29529 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -5,6 +5,7 @@ use crate::network::dns_overrides::resolve_socket_addr; use crate::protocol::tls; use crate::stats::beobachten::BeobachtenStore; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; +use crate::transport::socket::configure_tcp_socket; #[cfg(unix)] use nix::ifaddrs::getifaddrs; use rand::rngs::StdRng; @@ -36,6 +37,8 @@ const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200); #[cfg(test)] const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100); const MASK_BUFFER_SIZE: usize = 8192; +const MASK_BUFFER_GROW_AFTER_BYTES: usize = 256 * 1024; +const MASK_BUFFER_MAX_SIZE: usize = 64 * 1024; #[cfg(unix)] #[cfg(not(test))] const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(300); @@ -53,6 +56,27 @@ struct MaskTcpTarget<'a> { port: u16, } +fn mask_copy_read_len(total: usize, byte_cap: usize) -> usize { + // Keep short scanner probes on the small baseline buffer and grow only + // after the session has proven to be sustained masking relay traffic. + let active_buffer_size = if total >= MASK_BUFFER_GROW_AFTER_BYTES { + MASK_BUFFER_MAX_SIZE + } else { + MASK_BUFFER_SIZE + }; + + if byte_cap == 0 { + return active_buffer_size; + } + + let remaining_budget = byte_cap.saturating_sub(total); + if remaining_budget == 0 { + return 0; + } + + remaining_budget.min(active_buffer_size) +} + async fn copy_with_idle_timeout( reader: &mut R, writer: &mut W, @@ -64,21 +88,18 @@ where R: AsyncRead + Unpin, W: AsyncWrite + Unpin, { - let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]); + let mut buf = vec![0u8; MASK_BUFFER_SIZE]; let mut total = 0usize; let mut ended_by_eof = false; - let unlimited = byte_cap == 0; loop { - let read_len = if unlimited { - MASK_BUFFER_SIZE - } else { - let remaining_budget = byte_cap.saturating_sub(total); - if remaining_budget == 0 { - break; - } - remaining_budget.min(MASK_BUFFER_SIZE) - }; + let read_len = mask_copy_read_len(total, byte_cap); + if read_len == 0 { + break; + } + if buf.len() < read_len { + buf.resize(read_len, 0); + } let read_res = timeout(idle_timeout, reader.read(&mut buf[..read_len])).await; let n = match read_res { Ok(Ok(n)) => n, @@ -877,6 +898,12 @@ fn build_mask_proxy_header( } } +fn configure_mask_backend_socket(stream: &TcpStream) { + if let Err(e) = configure_tcp_socket(stream, false, Duration::from_secs(0)) { + debug!(error = %e, "Failed to configure mask backend socket"); + } +} + /// Handle a bad client by forwarding to mask host pub async fn handle_bad_client( reader: R, @@ -1047,6 +1074,7 @@ pub async fn handle_bad_client( let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await; match connect_result { Ok(Ok(stream)) => { + configure_mask_backend_socket(&stream); let proxy_header = build_mask_proxy_header(config.censorship.mask_proxy_protocol, peer, local_addr); @@ -1190,20 +1218,17 @@ async fn consume_client_data( idle_timeout: Duration, ) { // Keep drain path fail-closed under slow-loris stalls. - let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]); + let mut buf = vec![0u8; MASK_BUFFER_SIZE]; let mut total = 0usize; - let unlimited = byte_cap == 0; loop { - let read_len = if unlimited { - MASK_BUFFER_SIZE - } else { - let remaining_budget = byte_cap.saturating_sub(total); - if remaining_budget == 0 { - break; - } - remaining_budget.min(MASK_BUFFER_SIZE) - }; + let read_len = mask_copy_read_len(total, byte_cap); + if read_len == 0 { + break; + } + if buf.len() < read_len { + buf.resize(read_len, 0); + } let n = match timeout(idle_timeout, reader.read(&mut buf[..read_len])).await { Ok(Ok(n)) => n, Ok(Err(_)) | Err(_) => break, @@ -1214,7 +1239,7 @@ async fn consume_client_data( } total = total.saturating_add(n); - if !unlimited && total >= byte_cap { + if byte_cap != 0 && total >= byte_cap { break; } } @@ -1332,6 +1357,10 @@ mod masking_interface_cache_concurrency_security_tests; #[path = "tests/masking_production_cap_regression_security_tests.rs"] mod masking_production_cap_regression_security_tests; +#[cfg(test)] +#[path = "tests/masking_relay_manual_perf_tests.rs"] +mod masking_relay_manual_perf_tests; + #[cfg(test)] #[path = "tests/masking_extended_attack_surface_security_tests.rs"] mod masking_extended_attack_surface_security_tests; diff --git a/src/proxy/tests/masking_relay_manual_perf_tests.rs b/src/proxy/tests/masking_relay_manual_perf_tests.rs new file mode 100644 index 0000000..f10bd8a --- /dev/null +++ b/src/proxy/tests/masking_relay_manual_perf_tests.rs @@ -0,0 +1,111 @@ +use super::*; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio::time::{Duration, Instant}; + +const PERF_TOTAL_BYTES: usize = 64 * 1024 * 1024; + +struct PatternReader { + remaining: usize, + chunk: usize, + read_calls: AtomicUsize, +} + +impl PatternReader { + fn new(total: usize, chunk: usize) -> Self { + Self { + remaining: total, + chunk, + read_calls: AtomicUsize::new(0), + } + } + + fn read_calls(&self) -> usize { + self.read_calls.load(Ordering::Relaxed) + } +} + +impl AsyncRead for PatternReader { + fn poll_read( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + self.read_calls.fetch_add(1, Ordering::Relaxed); + if self.remaining == 0 { + return Poll::Ready(Ok(())); + } + + let take = self.remaining.min(self.chunk).min(buf.remaining()); + if take == 0 { + return Poll::Ready(Ok(())); + } + + static PATTERN: [u8; MASK_BUFFER_MAX_SIZE] = [0xA5; MASK_BUFFER_MAX_SIZE]; + buf.put_slice(&PATTERN[..take]); + self.remaining -= take; + Poll::Ready(Ok(())) + } +} + +#[derive(Default)] +struct CountingWriter { + written: usize, +} + +impl AsyncWrite for CountingWriter { + fn poll_write( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.written = self.written.saturating_add(buf.len()); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +#[ignore = "manual benchmark: throughput-sensitive and host-dependent"] +async fn masking_copy_with_idle_timeout_manual_throughput() { + let mut reader = PatternReader::new(PERF_TOTAL_BYTES, MASK_BUFFER_MAX_SIZE); + let mut writer = CountingWriter::default(); + let started = Instant::now(); + + let outcome = copy_with_idle_timeout( + &mut reader, + &mut writer, + PERF_TOTAL_BYTES, + true, + Duration::from_secs(30), + ) + .await; + + let elapsed = started.elapsed(); + let mb = PERF_TOTAL_BYTES as f64 / (1024.0 * 1024.0); + let mbps = mb / elapsed.as_secs_f64(); + + assert_eq!(outcome.total, PERF_TOTAL_BYTES); + assert_eq!(writer.written, PERF_TOTAL_BYTES); + assert!( + !outcome.ended_by_eof, + "manual throughput run should terminate at byte cap" + ); + + eprintln!( + "masking manual throughput: bytes={} elapsed_ms={} mib_per_sec={:.2} read_calls={}", + PERF_TOTAL_BYTES, + elapsed.as_millis(), + mbps, + reader.read_calls() + ); +}