Harden ME Writer Cancellation paths

This commit is contained in:
Alexey
2026-05-10 14:09:10 +03:00
parent beed6b4679
commit 900b574fb8
4 changed files with 359 additions and 58 deletions

View File

@@ -1367,7 +1367,7 @@ where
} else { } else {
None None
}; };
let _ = writer.flush().await; let _ = flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await;
let flush_duration_us = flush_started_at.map(|started| { let flush_duration_us = flush_started_at.map(|started| {
started started
.elapsed() .elapsed()
@@ -1430,7 +1430,8 @@ where
} else { } else {
None None
}; };
let _ = writer.flush().await; let _ =
flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await;
let flush_duration_us = flush_started_at.map(|started| { let flush_duration_us = flush_started_at.map(|started| {
started started
.elapsed() .elapsed()
@@ -1498,7 +1499,11 @@ where
} else { } else {
None None
}; };
let _ = writer.flush().await; let _ = flush_client_or_cancel(
&mut writer,
&flow_cancel_me_writer,
)
.await;
let flush_duration_us = flush_started_at.map(|started| { let flush_duration_us = flush_started_at.map(|started| {
started started
.elapsed() .elapsed()
@@ -1565,7 +1570,11 @@ where
} else { } else {
None None
}; };
let _ = writer.flush().await; let _ = flush_client_or_cancel(
&mut writer,
&flow_cancel_me_writer,
)
.await;
let flush_duration_us = flush_started_at.map(|started| { let flush_duration_us = flush_started_at.map(|started| {
started started
.elapsed() .elapsed()
@@ -1611,7 +1620,7 @@ where
} else { } else {
None None
}; };
writer.flush().await.map_err(ProxyError::Io)?; flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?;
let flush_duration_us = flush_started_at.map(|started| { let flush_duration_us = flush_started_at.map(|started| {
started started
.elapsed() .elapsed()
@@ -2512,8 +2521,16 @@ where
.await?; .await?;
let write_mode = let write_mode =
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) match write_client_payload(
.await client_writer,
proto_tag,
flags,
&data,
rng,
frame_buf,
cancel,
)
.await
{ {
Ok(mode) => mode, Ok(mode) => mode,
Err(err) => { Err(err) => {
@@ -2556,7 +2573,7 @@ where
None, None,
) )
.await?; .await?;
write_client_ack(client_writer, proto_tag, confirm).await?; write_client_ack(client_writer, proto_tag, confirm, cancel).await?;
stats.increment_me_d2c_ack_frames_total(); stats.increment_me_d2c_ack_frames_total();
Ok(MeWriterResponseOutcome::Continue { Ok(MeWriterResponseOutcome::Continue {
@@ -2608,6 +2625,7 @@ async fn write_client_payload<W>(
data: &[u8], data: &[u8],
rng: &SecureRandom, rng: &SecureRandom,
frame_buf: &mut Vec<u8>, frame_buf: &mut Vec<u8>,
cancel: &CancellationToken,
) -> Result<MeD2cWriteMode> ) -> Result<MeD2cWriteMode>
where where
W: AsyncWrite + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static,
@@ -2635,21 +2653,12 @@ where
frame_buf.reserve(wire_len); frame_buf.reserve(wire_len);
frame_buf.push(first); frame_buf.push(first);
frame_buf.extend_from_slice(data); frame_buf.extend_from_slice(data);
client_writer write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
.write_all(frame_buf.as_slice())
.await
.map_err(ProxyError::Io)?;
MeD2cWriteMode::Coalesced MeD2cWriteMode::Coalesced
} else { } else {
let header = [first]; let header = [first];
client_writer write_all_client_or_cancel(client_writer, &header, cancel).await?;
.write_all(&header) write_all_client_or_cancel(client_writer, data, cancel).await?;
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
MeD2cWriteMode::Split MeD2cWriteMode::Split
} }
} else if len_words < (1 << 24) { } else if len_words < (1 << 24) {
@@ -2664,21 +2673,12 @@ where
frame_buf.reserve(wire_len); frame_buf.reserve(wire_len);
frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]); frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]);
frame_buf.extend_from_slice(data); frame_buf.extend_from_slice(data);
client_writer write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
.write_all(frame_buf.as_slice())
.await
.map_err(ProxyError::Io)?;
MeD2cWriteMode::Coalesced MeD2cWriteMode::Coalesced
} else { } else {
let header = [first, lw[0], lw[1], lw[2]]; let header = [first, lw[0], lw[1], lw[2]];
client_writer write_all_client_or_cancel(client_writer, &header, cancel).await?;
.write_all(&header) write_all_client_or_cancel(client_writer, data, cancel).await?;
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
MeD2cWriteMode::Split MeD2cWriteMode::Split
} }
} else { } else {
@@ -2713,21 +2713,12 @@ where
frame_buf.resize(start + padding_len, 0); frame_buf.resize(start + padding_len, 0);
rng.fill(&mut frame_buf[start..]); rng.fill(&mut frame_buf[start..]);
} }
client_writer write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
.write_all(frame_buf.as_slice())
.await
.map_err(ProxyError::Io)?;
MeD2cWriteMode::Coalesced MeD2cWriteMode::Coalesced
} else { } else {
let header = len_val.to_le_bytes(); let header = len_val.to_le_bytes();
client_writer write_all_client_or_cancel(client_writer, &header, cancel).await?;
.write_all(&header) write_all_client_or_cancel(client_writer, data, cancel).await?;
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
if padding_len > 0 { if padding_len > 0 {
frame_buf.clear(); frame_buf.clear();
if frame_buf.capacity() < padding_len { if frame_buf.capacity() < padding_len {
@@ -2735,10 +2726,7 @@ where
} }
frame_buf.resize(padding_len, 0); frame_buf.resize(padding_len, 0);
rng.fill(frame_buf.as_mut_slice()); rng.fill(frame_buf.as_mut_slice());
client_writer write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
.write_all(frame_buf.as_slice())
.await
.map_err(ProxyError::Io)?;
} }
MeD2cWriteMode::Split MeD2cWriteMode::Split
} }
@@ -2752,6 +2740,7 @@ async fn write_client_ack<W>(
client_writer: &mut CryptoWriter<W>, client_writer: &mut CryptoWriter<W>,
proto_tag: ProtoTag, proto_tag: ProtoTag,
confirm: u32, confirm: u32,
cancel: &CancellationToken,
) -> Result<()> ) -> Result<()>
where where
W: AsyncWrite + Unpin + Send + 'static, W: AsyncWrite + Unpin + Send + 'static,
@@ -2761,10 +2750,34 @@ where
} else { } else {
confirm.to_le_bytes() confirm.to_le_bytes()
}; };
client_writer write_all_client_or_cancel(client_writer, &bytes, cancel).await
.write_all(&bytes) }
.await
.map_err(ProxyError::Io) async fn write_all_client_or_cancel<W>(
client_writer: &mut CryptoWriter<W>,
bytes: &[u8],
cancel: &CancellationToken,
) -> Result<()>
where
W: AsyncWrite + Unpin + Send + 'static,
{
tokio::select! {
result = client_writer.write_all(bytes) => result.map_err(ProxyError::Io),
_ = cancel.cancelled() => Err(ProxyError::Proxy("ME client writer cancelled".into())),
}
}
async fn flush_client_or_cancel<W>(
client_writer: &mut CryptoWriter<W>,
cancel: &CancellationToken,
) -> Result<()>
where
W: AsyncWrite + Unpin + Send + 'static,
{
tokio::select! {
result = client_writer.flush() => result.map_err(ProxyError::Io),
_ = cancel.cancelled() => Err(ProxyError::Proxy("ME client writer cancelled".into())),
}
} }
#[cfg(test)] #[cfg(test)]

View File

@@ -637,6 +637,22 @@ fn unknown_dc_log_path_revalidation_rejects_parent_swapped_to_symlink() {
"telemt-unknown-dc-parent-swap-{}", "telemt-unknown-dc-parent-swap-{}",
std::process::id() std::process::id()
)); ));
if let Ok(meta) = fs::symlink_metadata(&parent) {
if meta.file_type().is_symlink() || meta.is_file() {
fs::remove_file(&parent).expect("stale parent-swap path must be removable");
} else {
fs::remove_dir_all(&parent).expect("stale parent-swap directory must be removable");
}
}
let moved = parent.with_extension("bak");
if let Ok(meta) = fs::symlink_metadata(&moved) {
if meta.file_type().is_symlink() || meta.is_file() {
fs::remove_file(&moved).expect("stale parent-swap backup path must be removable");
} else {
fs::remove_dir_all(&moved)
.expect("stale parent-swap backup directory must be removable");
}
}
fs::create_dir_all(&parent).expect("parent-swap test parent must be creatable"); fs::create_dir_all(&parent).expect("parent-swap test parent must be creatable");
let rel_candidate = format!( let rel_candidate = format!(
@@ -646,8 +662,6 @@ fn unknown_dc_log_path_revalidation_rejects_parent_swapped_to_symlink() {
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate) let sanitized = sanitize_unknown_dc_log_path(&rel_candidate)
.expect("candidate must sanitize before parent swap"); .expect("candidate must sanitize before parent swap");
let moved = parent.with_extension("bak");
let _ = fs::remove_dir_all(&moved);
fs::rename(&parent, &moved).expect("parent must be movable for swap simulation"); fs::rename(&parent, &moved).expect("parent must be movable for swap simulation");
symlink("/tmp", &parent).expect("symlink replacement for parent must be creatable"); symlink("/tmp", &parent).expect("symlink replacement for parent must be creatable");
@@ -720,6 +734,24 @@ fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() {
"telemt-unknown-dc-parent-swap-openat-{}", "telemt-unknown-dc-parent-swap-openat-{}",
std::process::id() std::process::id()
)); ));
if let Ok(meta) = fs::symlink_metadata(&base) {
if meta.file_type().is_symlink() || meta.is_file() {
fs::remove_file(&base).expect("stale parent-swap-openat path must be removable");
} else {
fs::remove_dir_all(&base)
.expect("stale parent-swap-openat directory must be removable");
}
}
let moved = base.with_extension("bak");
if let Ok(meta) = fs::symlink_metadata(&moved) {
if meta.file_type().is_symlink() || meta.is_file() {
fs::remove_file(&moved)
.expect("stale parent-swap-openat backup path must be removable");
} else {
fs::remove_dir_all(&moved)
.expect("stale parent-swap-openat backup directory must be removable");
}
}
fs::create_dir_all(&base).expect("parent-swap-openat base must be creatable"); fs::create_dir_all(&base).expect("parent-swap-openat base must be creatable");
let rel_candidate = format!( let rel_candidate = format!(
@@ -743,8 +775,6 @@ fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() {
let outside_target = outside_parent.join("unknown-dc.log"); let outside_target = outside_parent.join("unknown-dc.log");
let _ = fs::remove_file(&outside_target); let _ = fs::remove_file(&outside_target);
let moved = base.with_extension("bak");
let _ = fs::remove_dir_all(&moved);
fs::rename(&base, &moved).expect("base parent must be movable for swap simulation"); fs::rename(&base, &moved).expect("base parent must be movable for swap simulation");
symlink(&outside_parent, &base).expect("base parent symlink replacement must be creatable"); symlink(&outside_parent, &base).expect("base parent symlink replacement must be creatable");

View File

@@ -13,6 +13,8 @@ struct CountedWriter {
fail_writes: bool, fail_writes: bool,
} }
struct StalledWriter;
impl CountedWriter { impl CountedWriter {
fn new(write_calls: Arc<AtomicUsize>, fail_writes: bool) -> Self { fn new(write_calls: Arc<AtomicUsize>, fail_writes: bool) -> Self {
Self { Self {
@@ -49,12 +51,36 @@ impl AsyncWrite for CountedWriter {
} }
} }
impl AsyncWrite for StalledWriter {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<io::Result<usize>> {
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Pending
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Pending
}
}
fn make_crypto_writer(inner: CountedWriter) -> CryptoWriter<CountedWriter> { fn make_crypto_writer(inner: CountedWriter) -> CryptoWriter<CountedWriter> {
let key = [0u8; 32]; let key = [0u8; 32];
let iv = 0u128; let iv = 0u128;
CryptoWriter::new(inner, AesCtr::new(&key, iv), 8 * 1024) CryptoWriter::new(inner, AesCtr::new(&key, iv), 8 * 1024)
} }
fn make_stalled_crypto_writer() -> CryptoWriter<StalledWriter> {
let key = [0u8; 32];
let iv = 0u128;
CryptoWriter::new(StalledWriter, AesCtr::new(&key, iv), 8 * 1024)
}
#[tokio::test] #[tokio::test]
async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() { async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() {
let stats = Stats::new(); let stats = Stats::new();
@@ -189,3 +215,53 @@ async fn me_writer_pre_write_quota_reject_happens_before_writer_poll() {
); );
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
} }
#[tokio::test]
async fn me_writer_data_write_obeys_flow_cancellation() {
let stats = Stats::new();
let user = "middle-me-writer-cancel-user";
let mut writer = make_stalled_crypto_writer();
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let cancel = CancellationToken::new();
cancel.cancel();
let result = process_me_writer_response_with_traffic_lease(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0x31, 0x32, 0x33, 0x34]),
route_permit: None,
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
user,
None,
None,
0,
None,
&cancel,
&bytes_me2c,
13,
true,
false,
)
.await;
assert!(
matches!(result, Err(ProxyError::Proxy(ref message)) if message == "ME client writer cancelled"),
"cancelled middle writer must return a bounded cancellation error"
);
assert_eq!(
bytes_me2c.load(Ordering::Relaxed),
0,
"cancelled write must not advance committed ME->C bytes"
);
assert_eq!(
stats.get_user_total_octets(user),
0,
"cancelled write must not advance user output telemetry"
);
}

View File

@@ -4,10 +4,67 @@ use std::io;
use std::pin::Pin; use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::task::{Context, Poll}; use std::task::{Context, Poll, Wake};
use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use tokio::time::Instant; use tokio::time::Instant;
enum ReadStep {
Data(Vec<u8>),
Pending,
Eof,
Error,
}
struct ScriptedReader {
scripted_reads: Arc<Mutex<VecDeque<ReadStep>>>,
read_calls: Arc<AtomicUsize>,
}
impl ScriptedReader {
fn new(script: Vec<ReadStep>, read_calls: Arc<AtomicUsize>) -> Self {
Self {
scripted_reads: Arc::new(Mutex::new(script.into())),
read_calls,
}
}
}
impl AsyncRead for ScriptedReader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
this.read_calls.fetch_add(1, Ordering::Relaxed);
let step = this
.scripted_reads
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.pop_front()
.unwrap_or(ReadStep::Eof);
match step {
ReadStep::Data(data) => {
let n = data.len().min(buf.remaining());
buf.put_slice(&data[..n]);
Poll::Ready(Ok(()))
}
ReadStep::Pending => Poll::Pending,
ReadStep::Eof => Poll::Ready(Ok(())),
ReadStep::Error => Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"forced read failure",
))),
}
}
}
struct NoopWake;
impl Wake for NoopWake {
fn wake(self: Arc<Self>) {}
}
struct ScriptedWriter { struct ScriptedWriter {
scripted_writes: Arc<Mutex<VecDeque<usize>>>, scripted_writes: Arc<Mutex<VecDeque<usize>>>,
write_calls: Arc<AtomicUsize>, write_calls: Arc<AtomicUsize>,
@@ -80,6 +137,131 @@ fn make_stats_io_with_script(
(io, stats, write_calls, quota_exceeded) (io, stats, write_calls, quota_exceeded)
} }
fn make_stats_io_with_read_script(
user: &str,
quota_limit: u64,
precharged_quota: u64,
script: Vec<ReadStep>,
) -> (
StatsIo<ScriptedReader>,
Arc<Stats>,
Arc<AtomicUsize>,
Arc<AtomicBool>,
) {
let stats = Arc::new(Stats::new());
if precharged_quota > 0 {
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), precharged_quota);
}
let read_calls = Arc::new(AtomicUsize::new(0));
let quota_exceeded = Arc::new(AtomicBool::new(false));
let io = StatsIo::new(
ScriptedReader::new(script, read_calls.clone()),
Arc::new(SharedCounters::new()),
stats.clone(),
user.to_string(),
Some(quota_limit),
quota_exceeded.clone(),
Instant::now(),
);
(io, stats, read_calls, quota_exceeded)
}
fn poll_read_once<R: AsyncRead + Unpin>(
io: &mut StatsIo<R>,
storage: &mut [u8],
) -> Poll<io::Result<usize>> {
let waker = Arc::new(NoopWake).into();
let mut cx = Context::from_waker(&waker);
let mut read_buf = ReadBuf::new(storage);
let before = read_buf.filled().len();
match Pin::new(io).poll_read(&mut cx, &mut read_buf) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len() - before)),
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
Poll::Pending => Poll::Pending,
}
}
#[test]
fn direct_c2s_quota_refunds_unused_on_short_read() {
let user = "direct-c2s-short-read-refund-user";
let (mut io, stats, read_calls, quota_exceeded) = make_stats_io_with_read_script(
user,
64,
0,
vec![ReadStep::Data(vec![0x11; 5])],
);
let mut storage = [0u8; 16];
let n = match poll_read_once(&mut io, &mut storage) {
Poll::Ready(Ok(n)) => n,
other => panic!("short read must complete, got {other:?}"),
};
assert_eq!(n, 5);
assert_eq!(read_calls.load(Ordering::Relaxed), 1);
assert_eq!(stats.get_user_quota_used(user), 5);
assert_eq!(stats.get_quota_refund_bytes_total(), 11);
assert!(!quota_exceeded.load(Ordering::Acquire));
}
#[test]
fn direct_c2s_quota_refunds_full_reservation_on_pending() {
let user = "direct-c2s-pending-refund-user";
let (mut io, stats, read_calls, quota_exceeded) =
make_stats_io_with_read_script(user, 64, 0, vec![ReadStep::Pending]);
let mut storage = [0u8; 16];
assert!(matches!(
poll_read_once(&mut io, &mut storage),
Poll::Pending
));
assert_eq!(read_calls.load(Ordering::Relaxed), 1);
assert_eq!(stats.get_user_quota_used(user), 0);
assert_eq!(stats.get_quota_refund_bytes_total(), 16);
assert!(!quota_exceeded.load(Ordering::Acquire));
}
#[test]
fn direct_c2s_quota_refunds_full_reservation_on_eof() {
let user = "direct-c2s-eof-refund-user";
let (mut io, stats, read_calls, quota_exceeded) =
make_stats_io_with_read_script(user, 64, 0, vec![ReadStep::Eof]);
let mut storage = [0u8; 16];
let n = match poll_read_once(&mut io, &mut storage) {
Poll::Ready(Ok(n)) => n,
other => panic!("EOF read must complete with zero bytes, got {other:?}"),
};
assert_eq!(n, 0);
assert_eq!(read_calls.load(Ordering::Relaxed), 1);
assert_eq!(stats.get_user_quota_used(user), 0);
assert_eq!(stats.get_quota_refund_bytes_total(), 16);
assert!(!quota_exceeded.load(Ordering::Acquire));
}
#[test]
fn direct_c2s_quota_refunds_full_reservation_on_error() {
let user = "direct-c2s-error-refund-user";
let (mut io, stats, read_calls, quota_exceeded) =
make_stats_io_with_read_script(user, 64, 0, vec![ReadStep::Error]);
let mut storage = [0u8; 16];
let error = match poll_read_once(&mut io, &mut storage) {
Poll::Ready(Err(error)) => error,
other => panic!("error read must return error, got {other:?}"),
};
assert_eq!(error.kind(), io::ErrorKind::BrokenPipe);
assert_eq!(read_calls.load(Ordering::Relaxed), 1);
assert_eq!(stats.get_user_quota_used(user), 0);
assert_eq!(stats.get_quota_refund_bytes_total(), 16);
assert!(!quota_exceeded.load(Ordering::Acquire));
}
#[tokio::test] #[tokio::test]
async fn direct_partial_write_charges_only_committed_bytes_without_double_charge() { async fn direct_partial_write_charges_only_committed_bytes_without_double_charge() {
let user = "direct-partial-charge-user"; let user = "direct-partial-charge-user";