From 900b574fb874b26b60412b3c7ad68a0b2429da1a Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Sun, 10 May 2026 14:09:10 +0300 Subject: [PATCH] Harden ME Writer Cancellation paths --- src/proxy/middle_relay.rs | 117 ++++++----- .../tests/direct_relay_security_tests.rs | 38 +++- ...ddle_relay_atomic_quota_invariant_tests.rs | 76 +++++++ .../relay_atomic_quota_invariant_tests.rs | 186 +++++++++++++++++- 4 files changed, 359 insertions(+), 58 deletions(-) diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 425cff4..6b0329c 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -1367,7 +1367,7 @@ where } else { None }; - let _ = writer.flush().await; + let _ = flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await; let flush_duration_us = flush_started_at.map(|started| { started .elapsed() @@ -1430,7 +1430,8 @@ where } else { None }; - let _ = writer.flush().await; + let _ = + flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await; let flush_duration_us = flush_started_at.map(|started| { started .elapsed() @@ -1498,7 +1499,11 @@ where } else { None }; - let _ = writer.flush().await; + let _ = flush_client_or_cancel( + &mut writer, + &flow_cancel_me_writer, + ) + .await; let flush_duration_us = flush_started_at.map(|started| { started .elapsed() @@ -1565,7 +1570,11 @@ where } else { None }; - let _ = writer.flush().await; + let _ = flush_client_or_cancel( + &mut writer, + &flow_cancel_me_writer, + ) + .await; let flush_duration_us = flush_started_at.map(|started| { started .elapsed() @@ -1611,7 +1620,7 @@ where } else { None }; - writer.flush().await.map_err(ProxyError::Io)?; + flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?; let flush_duration_us = flush_started_at.map(|started| { started .elapsed() @@ -2512,8 +2521,16 @@ where .await?; let write_mode = - match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) - .await + match write_client_payload( + client_writer, + proto_tag, + flags, + &data, + rng, + frame_buf, + cancel, + ) + .await { Ok(mode) => mode, Err(err) => { @@ -2556,7 +2573,7 @@ where None, ) .await?; - write_client_ack(client_writer, proto_tag, confirm).await?; + write_client_ack(client_writer, proto_tag, confirm, cancel).await?; stats.increment_me_d2c_ack_frames_total(); Ok(MeWriterResponseOutcome::Continue { @@ -2608,6 +2625,7 @@ async fn write_client_payload( data: &[u8], rng: &SecureRandom, frame_buf: &mut Vec, + cancel: &CancellationToken, ) -> Result where W: AsyncWrite + Unpin + Send + 'static, @@ -2635,21 +2653,12 @@ where frame_buf.reserve(wire_len); frame_buf.push(first); frame_buf.extend_from_slice(data); - client_writer - .write_all(frame_buf.as_slice()) - .await - .map_err(ProxyError::Io)?; + write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?; MeD2cWriteMode::Coalesced } else { let header = [first]; - client_writer - .write_all(&header) - .await - .map_err(ProxyError::Io)?; - client_writer - .write_all(data) - .await - .map_err(ProxyError::Io)?; + write_all_client_or_cancel(client_writer, &header, cancel).await?; + write_all_client_or_cancel(client_writer, data, cancel).await?; MeD2cWriteMode::Split } } else if len_words < (1 << 24) { @@ -2664,21 +2673,12 @@ where frame_buf.reserve(wire_len); frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]); frame_buf.extend_from_slice(data); - client_writer - .write_all(frame_buf.as_slice()) - .await - .map_err(ProxyError::Io)?; + write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?; MeD2cWriteMode::Coalesced } else { let header = [first, lw[0], lw[1], lw[2]]; - client_writer - .write_all(&header) - .await - .map_err(ProxyError::Io)?; - client_writer - .write_all(data) - .await - .map_err(ProxyError::Io)?; + write_all_client_or_cancel(client_writer, &header, cancel).await?; + write_all_client_or_cancel(client_writer, data, cancel).await?; MeD2cWriteMode::Split } } else { @@ -2713,21 +2713,12 @@ where frame_buf.resize(start + padding_len, 0); rng.fill(&mut frame_buf[start..]); } - client_writer - .write_all(frame_buf.as_slice()) - .await - .map_err(ProxyError::Io)?; + write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?; MeD2cWriteMode::Coalesced } else { let header = len_val.to_le_bytes(); - client_writer - .write_all(&header) - .await - .map_err(ProxyError::Io)?; - client_writer - .write_all(data) - .await - .map_err(ProxyError::Io)?; + write_all_client_or_cancel(client_writer, &header, cancel).await?; + write_all_client_or_cancel(client_writer, data, cancel).await?; if padding_len > 0 { frame_buf.clear(); if frame_buf.capacity() < padding_len { @@ -2735,10 +2726,7 @@ where } frame_buf.resize(padding_len, 0); rng.fill(frame_buf.as_mut_slice()); - client_writer - .write_all(frame_buf.as_slice()) - .await - .map_err(ProxyError::Io)?; + write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?; } MeD2cWriteMode::Split } @@ -2752,6 +2740,7 @@ async fn write_client_ack( client_writer: &mut CryptoWriter, proto_tag: ProtoTag, confirm: u32, + cancel: &CancellationToken, ) -> Result<()> where W: AsyncWrite + Unpin + Send + 'static, @@ -2761,10 +2750,34 @@ where } else { confirm.to_le_bytes() }; - client_writer - .write_all(&bytes) - .await - .map_err(ProxyError::Io) + write_all_client_or_cancel(client_writer, &bytes, cancel).await +} + +async fn write_all_client_or_cancel( + client_writer: &mut CryptoWriter, + bytes: &[u8], + cancel: &CancellationToken, +) -> Result<()> +where + W: AsyncWrite + Unpin + Send + 'static, +{ + tokio::select! { + result = client_writer.write_all(bytes) => result.map_err(ProxyError::Io), + _ = cancel.cancelled() => Err(ProxyError::Proxy("ME client writer cancelled".into())), + } +} + +async fn flush_client_or_cancel( + client_writer: &mut CryptoWriter, + cancel: &CancellationToken, +) -> Result<()> +where + W: AsyncWrite + Unpin + Send + 'static, +{ + tokio::select! { + result = client_writer.flush() => result.map_err(ProxyError::Io), + _ = cancel.cancelled() => Err(ProxyError::Proxy("ME client writer cancelled".into())), + } } #[cfg(test)] diff --git a/src/proxy/tests/direct_relay_security_tests.rs b/src/proxy/tests/direct_relay_security_tests.rs index f7ffd0d..73b6bff 100644 --- a/src/proxy/tests/direct_relay_security_tests.rs +++ b/src/proxy/tests/direct_relay_security_tests.rs @@ -637,6 +637,22 @@ fn unknown_dc_log_path_revalidation_rejects_parent_swapped_to_symlink() { "telemt-unknown-dc-parent-swap-{}", std::process::id() )); + if let Ok(meta) = fs::symlink_metadata(&parent) { + if meta.file_type().is_symlink() || meta.is_file() { + fs::remove_file(&parent).expect("stale parent-swap path must be removable"); + } else { + fs::remove_dir_all(&parent).expect("stale parent-swap directory must be removable"); + } + } + let moved = parent.with_extension("bak"); + if let Ok(meta) = fs::symlink_metadata(&moved) { + if meta.file_type().is_symlink() || meta.is_file() { + fs::remove_file(&moved).expect("stale parent-swap backup path must be removable"); + } else { + fs::remove_dir_all(&moved) + .expect("stale parent-swap backup directory must be removable"); + } + } fs::create_dir_all(&parent).expect("parent-swap test parent must be creatable"); let rel_candidate = format!( @@ -646,8 +662,6 @@ fn unknown_dc_log_path_revalidation_rejects_parent_swapped_to_symlink() { let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) .expect("candidate must sanitize before parent swap"); - let moved = parent.with_extension("bak"); - let _ = fs::remove_dir_all(&moved); fs::rename(&parent, &moved).expect("parent must be movable for swap simulation"); symlink("/tmp", &parent).expect("symlink replacement for parent must be creatable"); @@ -720,6 +734,24 @@ fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() { "telemt-unknown-dc-parent-swap-openat-{}", std::process::id() )); + if let Ok(meta) = fs::symlink_metadata(&base) { + if meta.file_type().is_symlink() || meta.is_file() { + fs::remove_file(&base).expect("stale parent-swap-openat path must be removable"); + } else { + fs::remove_dir_all(&base) + .expect("stale parent-swap-openat directory must be removable"); + } + } + let moved = base.with_extension("bak"); + if let Ok(meta) = fs::symlink_metadata(&moved) { + if meta.file_type().is_symlink() || meta.is_file() { + fs::remove_file(&moved) + .expect("stale parent-swap-openat backup path must be removable"); + } else { + fs::remove_dir_all(&moved) + .expect("stale parent-swap-openat backup directory must be removable"); + } + } fs::create_dir_all(&base).expect("parent-swap-openat base must be creatable"); let rel_candidate = format!( @@ -743,8 +775,6 @@ fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() { let outside_target = outside_parent.join("unknown-dc.log"); let _ = fs::remove_file(&outside_target); - let moved = base.with_extension("bak"); - let _ = fs::remove_dir_all(&moved); fs::rename(&base, &moved).expect("base parent must be movable for swap simulation"); symlink(&outside_parent, &base).expect("base parent symlink replacement must be creatable"); diff --git a/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs b/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs index 18bd583..e1b3511 100644 --- a/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs +++ b/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs @@ -13,6 +13,8 @@ struct CountedWriter { fail_writes: bool, } +struct StalledWriter; + impl CountedWriter { fn new(write_calls: Arc, fail_writes: bool) -> Self { Self { @@ -49,12 +51,36 @@ impl AsyncWrite for CountedWriter { } } +impl AsyncWrite for StalledWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Pending + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Pending + } +} + fn make_crypto_writer(inner: CountedWriter) -> CryptoWriter { let key = [0u8; 32]; let iv = 0u128; CryptoWriter::new(inner, AesCtr::new(&key, iv), 8 * 1024) } +fn make_stalled_crypto_writer() -> CryptoWriter { + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(StalledWriter, AesCtr::new(&key, iv), 8 * 1024) +} + #[tokio::test] async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() { let stats = Stats::new(); @@ -189,3 +215,53 @@ async fn me_writer_pre_write_quota_reject_happens_before_writer_poll() { ); assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); } + +#[tokio::test] +async fn me_writer_data_write_obeys_flow_cancellation() { + let stats = Stats::new(); + let user = "middle-me-writer-cancel-user"; + let mut writer = make_stalled_crypto_writer(); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + let cancel = CancellationToken::new(); + cancel.cancel(); + + let result = process_me_writer_response_with_traffic_lease( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0x31, 0x32, 0x33, 0x34]), + route_permit: None, + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + None, + None, + 0, + None, + &cancel, + &bytes_me2c, + 13, + true, + false, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Proxy(ref message)) if message == "ME client writer cancelled"), + "cancelled middle writer must return a bounded cancellation error" + ); + assert_eq!( + bytes_me2c.load(Ordering::Relaxed), + 0, + "cancelled write must not advance committed ME->C bytes" + ); + assert_eq!( + stats.get_user_total_octets(user), + 0, + "cancelled write must not advance user output telemetry" + ); +} diff --git a/src/proxy/tests/relay_atomic_quota_invariant_tests.rs b/src/proxy/tests/relay_atomic_quota_invariant_tests.rs index 1bb00a6..237d244 100644 --- a/src/proxy/tests/relay_atomic_quota_invariant_tests.rs +++ b/src/proxy/tests/relay_atomic_quota_invariant_tests.rs @@ -4,10 +4,67 @@ use std::io; use std::pin::Pin; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{Arc, Mutex}; -use std::task::{Context, Poll}; -use tokio::io::{AsyncWrite, AsyncWriteExt}; +use std::task::{Context, Poll, Wake}; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; use tokio::time::Instant; +enum ReadStep { + Data(Vec), + Pending, + Eof, + Error, +} + +struct ScriptedReader { + scripted_reads: Arc>>, + read_calls: Arc, +} + +impl ScriptedReader { + fn new(script: Vec, read_calls: Arc) -> Self { + Self { + scripted_reads: Arc::new(Mutex::new(script.into())), + read_calls, + } + } +} + +impl AsyncRead for ScriptedReader { + fn poll_read( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = self.get_mut(); + this.read_calls.fetch_add(1, Ordering::Relaxed); + let step = this + .scripted_reads + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) + .pop_front() + .unwrap_or(ReadStep::Eof); + match step { + ReadStep::Data(data) => { + let n = data.len().min(buf.remaining()); + buf.put_slice(&data[..n]); + Poll::Ready(Ok(())) + } + ReadStep::Pending => Poll::Pending, + ReadStep::Eof => Poll::Ready(Ok(())), + ReadStep::Error => Poll::Ready(Err(io::Error::new( + io::ErrorKind::BrokenPipe, + "forced read failure", + ))), + } + } +} + +struct NoopWake; + +impl Wake for NoopWake { + fn wake(self: Arc) {} +} + struct ScriptedWriter { scripted_writes: Arc>>, write_calls: Arc, @@ -80,6 +137,131 @@ fn make_stats_io_with_script( (io, stats, write_calls, quota_exceeded) } +fn make_stats_io_with_read_script( + user: &str, + quota_limit: u64, + precharged_quota: u64, + script: Vec, +) -> ( + StatsIo, + Arc, + Arc, + Arc, +) { + let stats = Arc::new(Stats::new()); + if precharged_quota > 0 { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), precharged_quota); + } + + let read_calls = Arc::new(AtomicUsize::new(0)); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let io = StatsIo::new( + ScriptedReader::new(script, read_calls.clone()), + Arc::new(SharedCounters::new()), + stats.clone(), + user.to_string(), + Some(quota_limit), + quota_exceeded.clone(), + Instant::now(), + ); + + (io, stats, read_calls, quota_exceeded) +} + +fn poll_read_once( + io: &mut StatsIo, + storage: &mut [u8], +) -> Poll> { + let waker = Arc::new(NoopWake).into(); + let mut cx = Context::from_waker(&waker); + let mut read_buf = ReadBuf::new(storage); + let before = read_buf.filled().len(); + match Pin::new(io).poll_read(&mut cx, &mut read_buf) { + Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len() - before)), + Poll::Ready(Err(error)) => Poll::Ready(Err(error)), + Poll::Pending => Poll::Pending, + } +} + +#[test] +fn direct_c2s_quota_refunds_unused_on_short_read() { + let user = "direct-c2s-short-read-refund-user"; + let (mut io, stats, read_calls, quota_exceeded) = make_stats_io_with_read_script( + user, + 64, + 0, + vec![ReadStep::Data(vec![0x11; 5])], + ); + let mut storage = [0u8; 16]; + + let n = match poll_read_once(&mut io, &mut storage) { + Poll::Ready(Ok(n)) => n, + other => panic!("short read must complete, got {other:?}"), + }; + + assert_eq!(n, 5); + assert_eq!(read_calls.load(Ordering::Relaxed), 1); + assert_eq!(stats.get_user_quota_used(user), 5); + assert_eq!(stats.get_quota_refund_bytes_total(), 11); + assert!(!quota_exceeded.load(Ordering::Acquire)); +} + +#[test] +fn direct_c2s_quota_refunds_full_reservation_on_pending() { + let user = "direct-c2s-pending-refund-user"; + let (mut io, stats, read_calls, quota_exceeded) = + make_stats_io_with_read_script(user, 64, 0, vec![ReadStep::Pending]); + let mut storage = [0u8; 16]; + + assert!(matches!( + poll_read_once(&mut io, &mut storage), + Poll::Pending + )); + assert_eq!(read_calls.load(Ordering::Relaxed), 1); + assert_eq!(stats.get_user_quota_used(user), 0); + assert_eq!(stats.get_quota_refund_bytes_total(), 16); + assert!(!quota_exceeded.load(Ordering::Acquire)); +} + +#[test] +fn direct_c2s_quota_refunds_full_reservation_on_eof() { + let user = "direct-c2s-eof-refund-user"; + let (mut io, stats, read_calls, quota_exceeded) = + make_stats_io_with_read_script(user, 64, 0, vec![ReadStep::Eof]); + let mut storage = [0u8; 16]; + + let n = match poll_read_once(&mut io, &mut storage) { + Poll::Ready(Ok(n)) => n, + other => panic!("EOF read must complete with zero bytes, got {other:?}"), + }; + + assert_eq!(n, 0); + assert_eq!(read_calls.load(Ordering::Relaxed), 1); + assert_eq!(stats.get_user_quota_used(user), 0); + assert_eq!(stats.get_quota_refund_bytes_total(), 16); + assert!(!quota_exceeded.load(Ordering::Acquire)); +} + +#[test] +fn direct_c2s_quota_refunds_full_reservation_on_error() { + let user = "direct-c2s-error-refund-user"; + let (mut io, stats, read_calls, quota_exceeded) = + make_stats_io_with_read_script(user, 64, 0, vec![ReadStep::Error]); + let mut storage = [0u8; 16]; + + let error = match poll_read_once(&mut io, &mut storage) { + Poll::Ready(Err(error)) => error, + other => panic!("error read must return error, got {other:?}"), + }; + + assert_eq!(error.kind(), io::ErrorKind::BrokenPipe); + assert_eq!(read_calls.load(Ordering::Relaxed), 1); + assert_eq!(stats.get_user_quota_used(user), 0); + assert_eq!(stats.get_quota_refund_bytes_total(), 16); + assert!(!quota_exceeded.load(Ordering::Acquire)); +} + #[tokio::test] async fn direct_partial_write_charges_only_committed_bytes_without_double_charge() { let user = "direct-partial-charge-user";