feat(proxy): implement timeout handling for client payload reads and add corresponding tests

This commit is contained in:
David Osipov 2026-03-17 01:53:44 +04:00
parent e0d821c6b6
commit a1caebbe6f
No known key found for this signature in database
GPG Key ID: 0E55C4A47454E82E
3 changed files with 135 additions and 17 deletions

2
Cargo.lock generated
View File

@ -2093,7 +2093,7 @@ dependencies = [
[[package]] [[package]]
name = "telemt" name = "telemt"
version = "3.3.18" version = "3.3.19"
dependencies = [ dependencies = [
"aes", "aes",
"anyhow", "anyhow",

View File

@ -11,6 +11,7 @@ use bytes::{Bytes, BytesMut};
use dashmap::DashMap; use dashmap::DashMap;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot, watch}; use tokio::sync::{mpsc, oneshot, watch};
use tokio::time::timeout;
use tracing::{debug, trace, warn}; use tracing::{debug, trace, warn};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
@ -581,6 +582,7 @@ where
&mut crypto_reader, &mut crypto_reader,
proto_tag, proto_tag,
frame_limit, frame_limit,
Duration::from_secs(config.timeouts.client_handshake.max(1)),
&buffer_pool, &buffer_pool,
&forensics, &forensics,
&mut frame_counter, &mut frame_counter,
@ -670,6 +672,7 @@ async fn read_client_payload<R>(
client_reader: &mut CryptoReader<R>, client_reader: &mut CryptoReader<R>,
proto_tag: ProtoTag, proto_tag: ProtoTag,
max_frame: usize, max_frame: usize,
frame_read_timeout: Duration,
buffer_pool: &Arc<BufferPool>, buffer_pool: &Arc<BufferPool>,
forensics: &RelayForensicsState, forensics: &RelayForensicsState,
frame_counter: &mut u64, frame_counter: &mut u64,
@ -678,23 +681,40 @@ async fn read_client_payload<R>(
where where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
{ {
async fn read_exact_with_timeout<R>(
client_reader: &mut CryptoReader<R>,
buf: &mut [u8],
frame_read_timeout: Duration,
) -> Result<()>
where
R: AsyncRead + Unpin + Send + 'static,
{
match timeout(frame_read_timeout, client_reader.read_exact(buf)).await {
Ok(Ok(_)) => Ok(()),
Ok(Err(e)) => Err(ProxyError::Io(e)),
Err(_) => Err(ProxyError::Io(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"middle-relay client frame read timeout",
))),
}
}
loop { loop {
let (len, quickack, raw_len_bytes) = match proto_tag { let (len, quickack, raw_len_bytes) = match proto_tag {
ProtoTag::Abridged => { ProtoTag::Abridged => {
let mut first = [0u8; 1]; let mut first = [0u8; 1];
match client_reader.read_exact(&mut first).await { match read_exact_with_timeout(client_reader, &mut first, frame_read_timeout).await {
Ok(_) => {} Ok(()) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
Err(e) => return Err(ProxyError::Io(e)), return Ok(None);
}
Err(e) => return Err(e),
} }
let quickack = (first[0] & 0x80) != 0; let quickack = (first[0] & 0x80) != 0;
let len_words = if (first[0] & 0x7f) == 0x7f { let len_words = if (first[0] & 0x7f) == 0x7f {
let mut ext = [0u8; 3]; let mut ext = [0u8; 3];
client_reader read_exact_with_timeout(client_reader, &mut ext, frame_read_timeout).await?;
.read_exact(&mut ext)
.await
.map_err(ProxyError::Io)?;
u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize
} else { } else {
(first[0] & 0x7f) as usize (first[0] & 0x7f) as usize
@ -707,10 +727,12 @@ where
} }
ProtoTag::Intermediate | ProtoTag::Secure => { ProtoTag::Intermediate | ProtoTag::Secure => {
let mut len_buf = [0u8; 4]; let mut len_buf = [0u8; 4];
match client_reader.read_exact(&mut len_buf).await { match read_exact_with_timeout(client_reader, &mut len_buf, frame_read_timeout).await {
Ok(_) => {} Ok(()) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
Err(e) => return Err(ProxyError::Io(e)), return Ok(None);
}
Err(e) => return Err(e),
} }
let quickack = (len_buf[3] & 0x80) != 0; let quickack = (len_buf[3] & 0x80) != 0;
( (
@ -769,10 +791,8 @@ where
let chunk_len = remaining.min(chunk_cap); let chunk_len = remaining.min(chunk_cap);
let mut chunk = buffer_pool.get(); let mut chunk = buffer_pool.get();
chunk.resize(chunk_len, 0); chunk.resize(chunk_len, 0);
client_reader read_exact_with_timeout(client_reader, &mut chunk[..chunk_len], frame_read_timeout)
.read_exact(&mut chunk[..chunk_len]) .await?;
.await
.map_err(ProxyError::Io)?;
payload.extend_from_slice(&chunk[..chunk_len]); payload.extend_from_slice(&chunk[..chunk_len]);
remaining -= chunk_len; remaining -= chunk_len;
} }

View File

@ -1,4 +1,12 @@
use super::*; use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader};
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use tokio::io::AsyncWriteExt;
use tokio::io::duplex;
use tokio::time::{Duration as TokioDuration, timeout}; use tokio::time::{Duration as TokioDuration, timeout};
#[test] #[test]
@ -101,3 +109,93 @@ fn desync_dedup_cache_is_bounded() {
"already tracked key inside dedup window must stay suppressed" "already tracked key inside dedup window must stay suppressed"
); );
} }
fn make_forensics_state() -> RelayForensicsState {
RelayForensicsState {
trace_id: 1,
conn_id: 2,
user: "test-user".to_string(),
peer: "127.0.0.1:50000".parse::<SocketAddr>().unwrap(),
peer_hash: 3,
started_at: Instant::now(),
bytes_c2me: 0,
bytes_me2c: Arc::new(AtomicU64::new(0)),
desync_all_full: false,
}
}
fn make_crypto_reader(reader: tokio::io::DuplexStream) -> CryptoReader<tokio::io::DuplexStream> {
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
let mut cipher = AesCtr::new(&key, iv);
cipher.encrypt(plaintext)
}
#[tokio::test]
async fn read_client_payload_times_out_on_header_stall() {
let _guard = desync_dedup_test_lock()
.lock()
.expect("middle relay test lock must be available");
let (reader, _writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let forensics = make_forensics_state();
let mut frame_counter = 0;
let result = read_client_payload(
&mut crypto_reader,
ProtoTag::Intermediate,
1024,
TokioDuration::from_millis(25),
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
)
.await;
assert!(
matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut),
"stalled header read must time out"
);
}
#[tokio::test]
async fn read_client_payload_times_out_on_payload_stall() {
let _guard = desync_dedup_test_lock()
.lock()
.expect("middle relay test lock must be available");
let (reader, mut writer) = duplex(1024);
let encrypted_len = encrypt_for_reader(&[8, 0, 0, 0]);
writer.write_all(&encrypted_len).await.unwrap();
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let forensics = make_forensics_state();
let mut frame_counter = 0;
let result = read_client_payload(
&mut crypto_reader,
ProtoTag::Intermediate,
1024,
TokioDuration::from_millis(25),
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
)
.await;
assert!(
matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut),
"stalled payload body read must time out"
);
}