use crate::config::ProxyConfig; use rand::rngs::StdRng; use rand::SeedableRng; use std::io; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; use tokio::io::AsyncWrite; #[cfg(test)] mod tests { use super::*; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use std::task::{RawWaker, RawWakerVTable, Waker}; unsafe fn wake_counter_clone(data: *const ()) -> RawWaker { let arc = Arc::::from_raw(data.cast::()); let cloned = Arc::clone(&arc); let _ = Arc::into_raw(arc); RawWaker::new(Arc::into_raw(cloned).cast::<()>(), &WAKE_COUNTER_WAKER_VTABLE) } unsafe fn wake_counter_wake(data: *const ()) { let arc = Arc::::from_raw(data.cast::()); arc.fetch_add(1, Ordering::SeqCst); } unsafe fn wake_counter_wake_by_ref(data: *const ()) { let arc = Arc::::from_raw(data.cast::()); arc.fetch_add(1, Ordering::SeqCst); let _ = Arc::into_raw(arc); } unsafe fn wake_counter_drop(data: *const ()) { let _ = Arc::::from_raw(data.cast::()); } static WAKE_COUNTER_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new( wake_counter_clone, wake_counter_wake, wake_counter_wake_by_ref, wake_counter_drop, ); fn wake_counter_waker(counter: Arc) -> Waker { let raw = RawWaker::new( Arc::into_raw(counter).cast::<()>(), &WAKE_COUNTER_WAKER_VTABLE, ); // SAFETY: `raw` points to a valid `Arc` and uses a vtable // that preserves Arc reference-counting semantics. unsafe { Waker::from_raw(raw) } } #[test] fn pending_count_writer_write_pending_does_not_spurious_wake() { let counter = Arc::new(AtomicUsize::new(0)); let waker = wake_counter_waker(Arc::clone(&counter)); let mut cx = Context::from_waker(&waker); let mut writer = PendingCountWriter::new(RecordingWriter::new(), 1, 0); let poll = Pin::new(&mut writer).poll_write(&mut cx, b"x"); assert!(matches!(poll, Poll::Pending)); assert_eq!(counter.load(Ordering::SeqCst), 0); } #[test] fn pending_count_writer_flush_pending_does_not_spurious_wake() { let counter = Arc::new(AtomicUsize::new(0)); let waker = wake_counter_waker(Arc::clone(&counter)); let mut cx = Context::from_waker(&waker); let mut writer = PendingCountWriter::new(RecordingWriter::new(), 0, 1); let poll = Pin::new(&mut writer).poll_flush(&mut cx); assert!(matches!(poll, Poll::Pending)); assert_eq!(counter.load(Ordering::SeqCst), 0); } } // In-memory AsyncWrite that records both per-write and per-flush granularity. pub struct RecordingWriter { pub writes: Vec>, pub flushed: Vec>, current_record: Vec, } impl RecordingWriter { pub fn new() -> Self { Self { writes: Vec::new(), flushed: Vec::new(), current_record: Vec::new(), } } pub fn total_bytes(&self) -> usize { self.writes.iter().map(|w| w.len()).sum() } } impl Default for RecordingWriter { fn default() -> Self { Self::new() } } impl AsyncWrite for RecordingWriter { fn poll_write( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { let me = self.as_mut().get_mut(); me.writes.push(buf.to_vec()); me.current_record.extend_from_slice(buf); Poll::Ready(Ok(buf.len())) } fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { let me = self.as_mut().get_mut(); let record = std::mem::take(&mut me.current_record); if !record.is_empty() { me.flushed.push(record); } Poll::Ready(Ok(())) } fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } } // Returns Poll::Pending for the first N write/flush calls, then delegates. pub struct PendingCountWriter { pub inner: W, pub write_pending_remaining: usize, pub flush_pending_remaining: usize, } impl PendingCountWriter { pub fn new(inner: W, write_pending: usize, flush_pending: usize) -> Self { Self { inner, write_pending_remaining: write_pending, flush_pending_remaining: flush_pending, } } } impl AsyncWrite for PendingCountWriter { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { let me = self.as_mut().get_mut(); if me.write_pending_remaining > 0 { me.write_pending_remaining -= 1; return Poll::Pending; } Pin::new(&mut me.inner).poll_write(cx, buf) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let me = self.as_mut().get_mut(); if me.flush_pending_remaining > 0 { me.flush_pending_remaining -= 1; return Poll::Pending; } Pin::new(&mut me.inner).poll_flush(cx) } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_shutdown(cx) } } pub fn seeded_rng(seed: u64) -> StdRng { StdRng::seed_from_u64(seed) } pub fn tls_only_config() -> Arc { let mut cfg = ProxyConfig::default(); cfg.general.modes.tls = true; Arc::new(cfg) } pub fn handshake_test_config(secret_hex: &str) -> ProxyConfig { let mut cfg = ProxyConfig::default(); cfg.access.users.clear(); cfg.access .users .insert("test-user".to_string(), secret_hex.to_string()); cfg.access.ignore_time_skew = true; cfg.censorship.mask = true; cfg.censorship.mask_host = Some("127.0.0.1".to_string()); cfg.censorship.mask_port = 0; cfg }