mirror of
https://github.com/telemt/telemt.git
synced 2026-05-13 15:21:44 +03:00
Harden ME Writer Cancellation paths
This commit is contained in:
@@ -1367,7 +1367,7 @@ where
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let _ = writer.flush().await;
|
let _ = flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await;
|
||||||
let flush_duration_us = flush_started_at.map(|started| {
|
let flush_duration_us = flush_started_at.map(|started| {
|
||||||
started
|
started
|
||||||
.elapsed()
|
.elapsed()
|
||||||
@@ -1430,7 +1430,8 @@ where
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let _ = writer.flush().await;
|
let _ =
|
||||||
|
flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await;
|
||||||
let flush_duration_us = flush_started_at.map(|started| {
|
let flush_duration_us = flush_started_at.map(|started| {
|
||||||
started
|
started
|
||||||
.elapsed()
|
.elapsed()
|
||||||
@@ -1498,7 +1499,11 @@ where
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let _ = writer.flush().await;
|
let _ = flush_client_or_cancel(
|
||||||
|
&mut writer,
|
||||||
|
&flow_cancel_me_writer,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
let flush_duration_us = flush_started_at.map(|started| {
|
let flush_duration_us = flush_started_at.map(|started| {
|
||||||
started
|
started
|
||||||
.elapsed()
|
.elapsed()
|
||||||
@@ -1565,7 +1570,11 @@ where
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let _ = writer.flush().await;
|
let _ = flush_client_or_cancel(
|
||||||
|
&mut writer,
|
||||||
|
&flow_cancel_me_writer,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
let flush_duration_us = flush_started_at.map(|started| {
|
let flush_duration_us = flush_started_at.map(|started| {
|
||||||
started
|
started
|
||||||
.elapsed()
|
.elapsed()
|
||||||
@@ -1611,7 +1620,7 @@ where
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
writer.flush().await.map_err(ProxyError::Io)?;
|
flush_client_or_cancel(&mut writer, &flow_cancel_me_writer).await?;
|
||||||
let flush_duration_us = flush_started_at.map(|started| {
|
let flush_duration_us = flush_started_at.map(|started| {
|
||||||
started
|
started
|
||||||
.elapsed()
|
.elapsed()
|
||||||
@@ -2512,7 +2521,15 @@ where
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let write_mode =
|
let write_mode =
|
||||||
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
|
match write_client_payload(
|
||||||
|
client_writer,
|
||||||
|
proto_tag,
|
||||||
|
flags,
|
||||||
|
&data,
|
||||||
|
rng,
|
||||||
|
frame_buf,
|
||||||
|
cancel,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(mode) => mode,
|
Ok(mode) => mode,
|
||||||
@@ -2556,7 +2573,7 @@ where
|
|||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
write_client_ack(client_writer, proto_tag, confirm).await?;
|
write_client_ack(client_writer, proto_tag, confirm, cancel).await?;
|
||||||
stats.increment_me_d2c_ack_frames_total();
|
stats.increment_me_d2c_ack_frames_total();
|
||||||
|
|
||||||
Ok(MeWriterResponseOutcome::Continue {
|
Ok(MeWriterResponseOutcome::Continue {
|
||||||
@@ -2608,6 +2625,7 @@ async fn write_client_payload<W>(
|
|||||||
data: &[u8],
|
data: &[u8],
|
||||||
rng: &SecureRandom,
|
rng: &SecureRandom,
|
||||||
frame_buf: &mut Vec<u8>,
|
frame_buf: &mut Vec<u8>,
|
||||||
|
cancel: &CancellationToken,
|
||||||
) -> Result<MeD2cWriteMode>
|
) -> Result<MeD2cWriteMode>
|
||||||
where
|
where
|
||||||
W: AsyncWrite + Unpin + Send + 'static,
|
W: AsyncWrite + Unpin + Send + 'static,
|
||||||
@@ -2635,21 +2653,12 @@ where
|
|||||||
frame_buf.reserve(wire_len);
|
frame_buf.reserve(wire_len);
|
||||||
frame_buf.push(first);
|
frame_buf.push(first);
|
||||||
frame_buf.extend_from_slice(data);
|
frame_buf.extend_from_slice(data);
|
||||||
client_writer
|
write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
|
||||||
.write_all(frame_buf.as_slice())
|
|
||||||
.await
|
|
||||||
.map_err(ProxyError::Io)?;
|
|
||||||
MeD2cWriteMode::Coalesced
|
MeD2cWriteMode::Coalesced
|
||||||
} else {
|
} else {
|
||||||
let header = [first];
|
let header = [first];
|
||||||
client_writer
|
write_all_client_or_cancel(client_writer, &header, cancel).await?;
|
||||||
.write_all(&header)
|
write_all_client_or_cancel(client_writer, data, cancel).await?;
|
||||||
.await
|
|
||||||
.map_err(ProxyError::Io)?;
|
|
||||||
client_writer
|
|
||||||
.write_all(data)
|
|
||||||
.await
|
|
||||||
.map_err(ProxyError::Io)?;
|
|
||||||
MeD2cWriteMode::Split
|
MeD2cWriteMode::Split
|
||||||
}
|
}
|
||||||
} else if len_words < (1 << 24) {
|
} else if len_words < (1 << 24) {
|
||||||
@@ -2664,21 +2673,12 @@ where
|
|||||||
frame_buf.reserve(wire_len);
|
frame_buf.reserve(wire_len);
|
||||||
frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]);
|
frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]);
|
||||||
frame_buf.extend_from_slice(data);
|
frame_buf.extend_from_slice(data);
|
||||||
client_writer
|
write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
|
||||||
.write_all(frame_buf.as_slice())
|
|
||||||
.await
|
|
||||||
.map_err(ProxyError::Io)?;
|
|
||||||
MeD2cWriteMode::Coalesced
|
MeD2cWriteMode::Coalesced
|
||||||
} else {
|
} else {
|
||||||
let header = [first, lw[0], lw[1], lw[2]];
|
let header = [first, lw[0], lw[1], lw[2]];
|
||||||
client_writer
|
write_all_client_or_cancel(client_writer, &header, cancel).await?;
|
||||||
.write_all(&header)
|
write_all_client_or_cancel(client_writer, data, cancel).await?;
|
||||||
.await
|
|
||||||
.map_err(ProxyError::Io)?;
|
|
||||||
client_writer
|
|
||||||
.write_all(data)
|
|
||||||
.await
|
|
||||||
.map_err(ProxyError::Io)?;
|
|
||||||
MeD2cWriteMode::Split
|
MeD2cWriteMode::Split
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -2713,21 +2713,12 @@ where
|
|||||||
frame_buf.resize(start + padding_len, 0);
|
frame_buf.resize(start + padding_len, 0);
|
||||||
rng.fill(&mut frame_buf[start..]);
|
rng.fill(&mut frame_buf[start..]);
|
||||||
}
|
}
|
||||||
client_writer
|
write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
|
||||||
.write_all(frame_buf.as_slice())
|
|
||||||
.await
|
|
||||||
.map_err(ProxyError::Io)?;
|
|
||||||
MeD2cWriteMode::Coalesced
|
MeD2cWriteMode::Coalesced
|
||||||
} else {
|
} else {
|
||||||
let header = len_val.to_le_bytes();
|
let header = len_val.to_le_bytes();
|
||||||
client_writer
|
write_all_client_or_cancel(client_writer, &header, cancel).await?;
|
||||||
.write_all(&header)
|
write_all_client_or_cancel(client_writer, data, cancel).await?;
|
||||||
.await
|
|
||||||
.map_err(ProxyError::Io)?;
|
|
||||||
client_writer
|
|
||||||
.write_all(data)
|
|
||||||
.await
|
|
||||||
.map_err(ProxyError::Io)?;
|
|
||||||
if padding_len > 0 {
|
if padding_len > 0 {
|
||||||
frame_buf.clear();
|
frame_buf.clear();
|
||||||
if frame_buf.capacity() < padding_len {
|
if frame_buf.capacity() < padding_len {
|
||||||
@@ -2735,10 +2726,7 @@ where
|
|||||||
}
|
}
|
||||||
frame_buf.resize(padding_len, 0);
|
frame_buf.resize(padding_len, 0);
|
||||||
rng.fill(frame_buf.as_mut_slice());
|
rng.fill(frame_buf.as_mut_slice());
|
||||||
client_writer
|
write_all_client_or_cancel(client_writer, frame_buf.as_slice(), cancel).await?;
|
||||||
.write_all(frame_buf.as_slice())
|
|
||||||
.await
|
|
||||||
.map_err(ProxyError::Io)?;
|
|
||||||
}
|
}
|
||||||
MeD2cWriteMode::Split
|
MeD2cWriteMode::Split
|
||||||
}
|
}
|
||||||
@@ -2752,6 +2740,7 @@ async fn write_client_ack<W>(
|
|||||||
client_writer: &mut CryptoWriter<W>,
|
client_writer: &mut CryptoWriter<W>,
|
||||||
proto_tag: ProtoTag,
|
proto_tag: ProtoTag,
|
||||||
confirm: u32,
|
confirm: u32,
|
||||||
|
cancel: &CancellationToken,
|
||||||
) -> Result<()>
|
) -> Result<()>
|
||||||
where
|
where
|
||||||
W: AsyncWrite + Unpin + Send + 'static,
|
W: AsyncWrite + Unpin + Send + 'static,
|
||||||
@@ -2761,10 +2750,34 @@ where
|
|||||||
} else {
|
} else {
|
||||||
confirm.to_le_bytes()
|
confirm.to_le_bytes()
|
||||||
};
|
};
|
||||||
client_writer
|
write_all_client_or_cancel(client_writer, &bytes, cancel).await
|
||||||
.write_all(&bytes)
|
}
|
||||||
.await
|
|
||||||
.map_err(ProxyError::Io)
|
async fn write_all_client_or_cancel<W>(
|
||||||
|
client_writer: &mut CryptoWriter<W>,
|
||||||
|
bytes: &[u8],
|
||||||
|
cancel: &CancellationToken,
|
||||||
|
) -> Result<()>
|
||||||
|
where
|
||||||
|
W: AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
tokio::select! {
|
||||||
|
result = client_writer.write_all(bytes) => result.map_err(ProxyError::Io),
|
||||||
|
_ = cancel.cancelled() => Err(ProxyError::Proxy("ME client writer cancelled".into())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn flush_client_or_cancel<W>(
|
||||||
|
client_writer: &mut CryptoWriter<W>,
|
||||||
|
cancel: &CancellationToken,
|
||||||
|
) -> Result<()>
|
||||||
|
where
|
||||||
|
W: AsyncWrite + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
tokio::select! {
|
||||||
|
result = client_writer.flush() => result.map_err(ProxyError::Io),
|
||||||
|
_ = cancel.cancelled() => Err(ProxyError::Proxy("ME client writer cancelled".into())),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -637,6 +637,22 @@ fn unknown_dc_log_path_revalidation_rejects_parent_swapped_to_symlink() {
|
|||||||
"telemt-unknown-dc-parent-swap-{}",
|
"telemt-unknown-dc-parent-swap-{}",
|
||||||
std::process::id()
|
std::process::id()
|
||||||
));
|
));
|
||||||
|
if let Ok(meta) = fs::symlink_metadata(&parent) {
|
||||||
|
if meta.file_type().is_symlink() || meta.is_file() {
|
||||||
|
fs::remove_file(&parent).expect("stale parent-swap path must be removable");
|
||||||
|
} else {
|
||||||
|
fs::remove_dir_all(&parent).expect("stale parent-swap directory must be removable");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let moved = parent.with_extension("bak");
|
||||||
|
if let Ok(meta) = fs::symlink_metadata(&moved) {
|
||||||
|
if meta.file_type().is_symlink() || meta.is_file() {
|
||||||
|
fs::remove_file(&moved).expect("stale parent-swap backup path must be removable");
|
||||||
|
} else {
|
||||||
|
fs::remove_dir_all(&moved)
|
||||||
|
.expect("stale parent-swap backup directory must be removable");
|
||||||
|
}
|
||||||
|
}
|
||||||
fs::create_dir_all(&parent).expect("parent-swap test parent must be creatable");
|
fs::create_dir_all(&parent).expect("parent-swap test parent must be creatable");
|
||||||
|
|
||||||
let rel_candidate = format!(
|
let rel_candidate = format!(
|
||||||
@@ -646,8 +662,6 @@ fn unknown_dc_log_path_revalidation_rejects_parent_swapped_to_symlink() {
|
|||||||
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate)
|
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate)
|
||||||
.expect("candidate must sanitize before parent swap");
|
.expect("candidate must sanitize before parent swap");
|
||||||
|
|
||||||
let moved = parent.with_extension("bak");
|
|
||||||
let _ = fs::remove_dir_all(&moved);
|
|
||||||
fs::rename(&parent, &moved).expect("parent must be movable for swap simulation");
|
fs::rename(&parent, &moved).expect("parent must be movable for swap simulation");
|
||||||
symlink("/tmp", &parent).expect("symlink replacement for parent must be creatable");
|
symlink("/tmp", &parent).expect("symlink replacement for parent must be creatable");
|
||||||
|
|
||||||
@@ -720,6 +734,24 @@ fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() {
|
|||||||
"telemt-unknown-dc-parent-swap-openat-{}",
|
"telemt-unknown-dc-parent-swap-openat-{}",
|
||||||
std::process::id()
|
std::process::id()
|
||||||
));
|
));
|
||||||
|
if let Ok(meta) = fs::symlink_metadata(&base) {
|
||||||
|
if meta.file_type().is_symlink() || meta.is_file() {
|
||||||
|
fs::remove_file(&base).expect("stale parent-swap-openat path must be removable");
|
||||||
|
} else {
|
||||||
|
fs::remove_dir_all(&base)
|
||||||
|
.expect("stale parent-swap-openat directory must be removable");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let moved = base.with_extension("bak");
|
||||||
|
if let Ok(meta) = fs::symlink_metadata(&moved) {
|
||||||
|
if meta.file_type().is_symlink() || meta.is_file() {
|
||||||
|
fs::remove_file(&moved)
|
||||||
|
.expect("stale parent-swap-openat backup path must be removable");
|
||||||
|
} else {
|
||||||
|
fs::remove_dir_all(&moved)
|
||||||
|
.expect("stale parent-swap-openat backup directory must be removable");
|
||||||
|
}
|
||||||
|
}
|
||||||
fs::create_dir_all(&base).expect("parent-swap-openat base must be creatable");
|
fs::create_dir_all(&base).expect("parent-swap-openat base must be creatable");
|
||||||
|
|
||||||
let rel_candidate = format!(
|
let rel_candidate = format!(
|
||||||
@@ -743,8 +775,6 @@ fn adversarial_parent_swap_after_check_is_blocked_by_anchored_open() {
|
|||||||
let outside_target = outside_parent.join("unknown-dc.log");
|
let outside_target = outside_parent.join("unknown-dc.log");
|
||||||
let _ = fs::remove_file(&outside_target);
|
let _ = fs::remove_file(&outside_target);
|
||||||
|
|
||||||
let moved = base.with_extension("bak");
|
|
||||||
let _ = fs::remove_dir_all(&moved);
|
|
||||||
fs::rename(&base, &moved).expect("base parent must be movable for swap simulation");
|
fs::rename(&base, &moved).expect("base parent must be movable for swap simulation");
|
||||||
symlink(&outside_parent, &base).expect("base parent symlink replacement must be creatable");
|
symlink(&outside_parent, &base).expect("base parent symlink replacement must be creatable");
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ struct CountedWriter {
|
|||||||
fail_writes: bool,
|
fail_writes: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct StalledWriter;
|
||||||
|
|
||||||
impl CountedWriter {
|
impl CountedWriter {
|
||||||
fn new(write_calls: Arc<AtomicUsize>, fail_writes: bool) -> Self {
|
fn new(write_calls: Arc<AtomicUsize>, fail_writes: bool) -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -49,12 +51,36 @@ impl AsyncWrite for CountedWriter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for StalledWriter {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
_buf: &[u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn make_crypto_writer(inner: CountedWriter) -> CryptoWriter<CountedWriter> {
|
fn make_crypto_writer(inner: CountedWriter) -> CryptoWriter<CountedWriter> {
|
||||||
let key = [0u8; 32];
|
let key = [0u8; 32];
|
||||||
let iv = 0u128;
|
let iv = 0u128;
|
||||||
CryptoWriter::new(inner, AesCtr::new(&key, iv), 8 * 1024)
|
CryptoWriter::new(inner, AesCtr::new(&key, iv), 8 * 1024)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn make_stalled_crypto_writer() -> CryptoWriter<StalledWriter> {
|
||||||
|
let key = [0u8; 32];
|
||||||
|
let iv = 0u128;
|
||||||
|
CryptoWriter::new(StalledWriter, AesCtr::new(&key, iv), 8 * 1024)
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() {
|
async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() {
|
||||||
let stats = Stats::new();
|
let stats = Stats::new();
|
||||||
@@ -189,3 +215,53 @@ async fn me_writer_pre_write_quota_reject_happens_before_writer_poll() {
|
|||||||
);
|
);
|
||||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
|
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn me_writer_data_write_obeys_flow_cancellation() {
|
||||||
|
let stats = Stats::new();
|
||||||
|
let user = "middle-me-writer-cancel-user";
|
||||||
|
let mut writer = make_stalled_crypto_writer();
|
||||||
|
let mut frame_buf = Vec::new();
|
||||||
|
let bytes_me2c = AtomicU64::new(0);
|
||||||
|
let cancel = CancellationToken::new();
|
||||||
|
cancel.cancel();
|
||||||
|
|
||||||
|
let result = process_me_writer_response_with_traffic_lease(
|
||||||
|
MeResponse::Data {
|
||||||
|
flags: 0,
|
||||||
|
data: Bytes::from_static(&[0x31, 0x32, 0x33, 0x34]),
|
||||||
|
route_permit: None,
|
||||||
|
},
|
||||||
|
&mut writer,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
&SecureRandom::new(),
|
||||||
|
&mut frame_buf,
|
||||||
|
&stats,
|
||||||
|
user,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
0,
|
||||||
|
None,
|
||||||
|
&cancel,
|
||||||
|
&bytes_me2c,
|
||||||
|
13,
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
matches!(result, Err(ProxyError::Proxy(ref message)) if message == "ME client writer cancelled"),
|
||||||
|
"cancelled middle writer must return a bounded cancellation error"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
bytes_me2c.load(Ordering::Relaxed),
|
||||||
|
0,
|
||||||
|
"cancelled write must not advance committed ME->C bytes"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
stats.get_user_total_octets(user),
|
||||||
|
0,
|
||||||
|
"cancelled write must not advance user output telemetry"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,10 +4,67 @@ use std::io;
|
|||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
use std::task::{Context, Poll};
|
use std::task::{Context, Poll, Wake};
|
||||||
use tokio::io::{AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
|
|
||||||
|
enum ReadStep {
|
||||||
|
Data(Vec<u8>),
|
||||||
|
Pending,
|
||||||
|
Eof,
|
||||||
|
Error,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ScriptedReader {
|
||||||
|
scripted_reads: Arc<Mutex<VecDeque<ReadStep>>>,
|
||||||
|
read_calls: Arc<AtomicUsize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ScriptedReader {
|
||||||
|
fn new(script: Vec<ReadStep>, read_calls: Arc<AtomicUsize>) -> Self {
|
||||||
|
Self {
|
||||||
|
scripted_reads: Arc::new(Mutex::new(script.into())),
|
||||||
|
read_calls,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for ScriptedReader {
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<io::Result<()>> {
|
||||||
|
let this = self.get_mut();
|
||||||
|
this.read_calls.fetch_add(1, Ordering::Relaxed);
|
||||||
|
let step = this
|
||||||
|
.scripted_reads
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||||
|
.pop_front()
|
||||||
|
.unwrap_or(ReadStep::Eof);
|
||||||
|
match step {
|
||||||
|
ReadStep::Data(data) => {
|
||||||
|
let n = data.len().min(buf.remaining());
|
||||||
|
buf.put_slice(&data[..n]);
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
ReadStep::Pending => Poll::Pending,
|
||||||
|
ReadStep::Eof => Poll::Ready(Ok(())),
|
||||||
|
ReadStep::Error => Poll::Ready(Err(io::Error::new(
|
||||||
|
io::ErrorKind::BrokenPipe,
|
||||||
|
"forced read failure",
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct NoopWake;
|
||||||
|
|
||||||
|
impl Wake for NoopWake {
|
||||||
|
fn wake(self: Arc<Self>) {}
|
||||||
|
}
|
||||||
|
|
||||||
struct ScriptedWriter {
|
struct ScriptedWriter {
|
||||||
scripted_writes: Arc<Mutex<VecDeque<usize>>>,
|
scripted_writes: Arc<Mutex<VecDeque<usize>>>,
|
||||||
write_calls: Arc<AtomicUsize>,
|
write_calls: Arc<AtomicUsize>,
|
||||||
@@ -80,6 +137,131 @@ fn make_stats_io_with_script(
|
|||||||
(io, stats, write_calls, quota_exceeded)
|
(io, stats, write_calls, quota_exceeded)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn make_stats_io_with_read_script(
|
||||||
|
user: &str,
|
||||||
|
quota_limit: u64,
|
||||||
|
precharged_quota: u64,
|
||||||
|
script: Vec<ReadStep>,
|
||||||
|
) -> (
|
||||||
|
StatsIo<ScriptedReader>,
|
||||||
|
Arc<Stats>,
|
||||||
|
Arc<AtomicUsize>,
|
||||||
|
Arc<AtomicBool>,
|
||||||
|
) {
|
||||||
|
let stats = Arc::new(Stats::new());
|
||||||
|
if precharged_quota > 0 {
|
||||||
|
let user_stats = stats.get_or_create_user_stats_handle(user);
|
||||||
|
stats.quota_charge_post_write(user_stats.as_ref(), precharged_quota);
|
||||||
|
}
|
||||||
|
|
||||||
|
let read_calls = Arc::new(AtomicUsize::new(0));
|
||||||
|
let quota_exceeded = Arc::new(AtomicBool::new(false));
|
||||||
|
let io = StatsIo::new(
|
||||||
|
ScriptedReader::new(script, read_calls.clone()),
|
||||||
|
Arc::new(SharedCounters::new()),
|
||||||
|
stats.clone(),
|
||||||
|
user.to_string(),
|
||||||
|
Some(quota_limit),
|
||||||
|
quota_exceeded.clone(),
|
||||||
|
Instant::now(),
|
||||||
|
);
|
||||||
|
|
||||||
|
(io, stats, read_calls, quota_exceeded)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_read_once<R: AsyncRead + Unpin>(
|
||||||
|
io: &mut StatsIo<R>,
|
||||||
|
storage: &mut [u8],
|
||||||
|
) -> Poll<io::Result<usize>> {
|
||||||
|
let waker = Arc::new(NoopWake).into();
|
||||||
|
let mut cx = Context::from_waker(&waker);
|
||||||
|
let mut read_buf = ReadBuf::new(storage);
|
||||||
|
let before = read_buf.filled().len();
|
||||||
|
match Pin::new(io).poll_read(&mut cx, &mut read_buf) {
|
||||||
|
Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len() - before)),
|
||||||
|
Poll::Ready(Err(error)) => Poll::Ready(Err(error)),
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn direct_c2s_quota_refunds_unused_on_short_read() {
|
||||||
|
let user = "direct-c2s-short-read-refund-user";
|
||||||
|
let (mut io, stats, read_calls, quota_exceeded) = make_stats_io_with_read_script(
|
||||||
|
user,
|
||||||
|
64,
|
||||||
|
0,
|
||||||
|
vec![ReadStep::Data(vec![0x11; 5])],
|
||||||
|
);
|
||||||
|
let mut storage = [0u8; 16];
|
||||||
|
|
||||||
|
let n = match poll_read_once(&mut io, &mut storage) {
|
||||||
|
Poll::Ready(Ok(n)) => n,
|
||||||
|
other => panic!("short read must complete, got {other:?}"),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(n, 5);
|
||||||
|
assert_eq!(read_calls.load(Ordering::Relaxed), 1);
|
||||||
|
assert_eq!(stats.get_user_quota_used(user), 5);
|
||||||
|
assert_eq!(stats.get_quota_refund_bytes_total(), 11);
|
||||||
|
assert!(!quota_exceeded.load(Ordering::Acquire));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn direct_c2s_quota_refunds_full_reservation_on_pending() {
|
||||||
|
let user = "direct-c2s-pending-refund-user";
|
||||||
|
let (mut io, stats, read_calls, quota_exceeded) =
|
||||||
|
make_stats_io_with_read_script(user, 64, 0, vec![ReadStep::Pending]);
|
||||||
|
let mut storage = [0u8; 16];
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
poll_read_once(&mut io, &mut storage),
|
||||||
|
Poll::Pending
|
||||||
|
));
|
||||||
|
assert_eq!(read_calls.load(Ordering::Relaxed), 1);
|
||||||
|
assert_eq!(stats.get_user_quota_used(user), 0);
|
||||||
|
assert_eq!(stats.get_quota_refund_bytes_total(), 16);
|
||||||
|
assert!(!quota_exceeded.load(Ordering::Acquire));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn direct_c2s_quota_refunds_full_reservation_on_eof() {
|
||||||
|
let user = "direct-c2s-eof-refund-user";
|
||||||
|
let (mut io, stats, read_calls, quota_exceeded) =
|
||||||
|
make_stats_io_with_read_script(user, 64, 0, vec![ReadStep::Eof]);
|
||||||
|
let mut storage = [0u8; 16];
|
||||||
|
|
||||||
|
let n = match poll_read_once(&mut io, &mut storage) {
|
||||||
|
Poll::Ready(Ok(n)) => n,
|
||||||
|
other => panic!("EOF read must complete with zero bytes, got {other:?}"),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(n, 0);
|
||||||
|
assert_eq!(read_calls.load(Ordering::Relaxed), 1);
|
||||||
|
assert_eq!(stats.get_user_quota_used(user), 0);
|
||||||
|
assert_eq!(stats.get_quota_refund_bytes_total(), 16);
|
||||||
|
assert!(!quota_exceeded.load(Ordering::Acquire));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn direct_c2s_quota_refunds_full_reservation_on_error() {
|
||||||
|
let user = "direct-c2s-error-refund-user";
|
||||||
|
let (mut io, stats, read_calls, quota_exceeded) =
|
||||||
|
make_stats_io_with_read_script(user, 64, 0, vec![ReadStep::Error]);
|
||||||
|
let mut storage = [0u8; 16];
|
||||||
|
|
||||||
|
let error = match poll_read_once(&mut io, &mut storage) {
|
||||||
|
Poll::Ready(Err(error)) => error,
|
||||||
|
other => panic!("error read must return error, got {other:?}"),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(error.kind(), io::ErrorKind::BrokenPipe);
|
||||||
|
assert_eq!(read_calls.load(Ordering::Relaxed), 1);
|
||||||
|
assert_eq!(stats.get_user_quota_used(user), 0);
|
||||||
|
assert_eq!(stats.get_quota_refund_bytes_total(), 16);
|
||||||
|
assert!(!quota_exceeded.load(Ordering::Acquire));
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn direct_partial_write_charges_only_committed_bytes_without_double_charge() {
|
async fn direct_partial_write_charges_only_committed_bytes_without_double_charge() {
|
||||||
let user = "direct-partial-charge-user";
|
let user = "direct-partial-charge-user";
|
||||||
|
|||||||
Reference in New Issue
Block a user