mirror of https://github.com/telemt/telemt.git
feat(proxy): implement timeout handling for client payload reads and add corresponding tests
This commit is contained in:
parent
e0d821c6b6
commit
a1caebbe6f
|
|
@ -2093,7 +2093,7 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "telemt"
|
name = "telemt"
|
||||||
version = "3.3.18"
|
version = "3.3.19"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aes",
|
"aes",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ use bytes::{Bytes, BytesMut};
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
use tokio::sync::{mpsc, oneshot, watch};
|
use tokio::sync::{mpsc, oneshot, watch};
|
||||||
|
use tokio::time::timeout;
|
||||||
use tracing::{debug, trace, warn};
|
use tracing::{debug, trace, warn};
|
||||||
|
|
||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
|
|
@ -581,6 +582,7 @@ where
|
||||||
&mut crypto_reader,
|
&mut crypto_reader,
|
||||||
proto_tag,
|
proto_tag,
|
||||||
frame_limit,
|
frame_limit,
|
||||||
|
Duration::from_secs(config.timeouts.client_handshake.max(1)),
|
||||||
&buffer_pool,
|
&buffer_pool,
|
||||||
&forensics,
|
&forensics,
|
||||||
&mut frame_counter,
|
&mut frame_counter,
|
||||||
|
|
@ -670,6 +672,7 @@ async fn read_client_payload<R>(
|
||||||
client_reader: &mut CryptoReader<R>,
|
client_reader: &mut CryptoReader<R>,
|
||||||
proto_tag: ProtoTag,
|
proto_tag: ProtoTag,
|
||||||
max_frame: usize,
|
max_frame: usize,
|
||||||
|
frame_read_timeout: Duration,
|
||||||
buffer_pool: &Arc<BufferPool>,
|
buffer_pool: &Arc<BufferPool>,
|
||||||
forensics: &RelayForensicsState,
|
forensics: &RelayForensicsState,
|
||||||
frame_counter: &mut u64,
|
frame_counter: &mut u64,
|
||||||
|
|
@ -678,23 +681,40 @@ async fn read_client_payload<R>(
|
||||||
where
|
where
|
||||||
R: AsyncRead + Unpin + Send + 'static,
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
{
|
{
|
||||||
|
async fn read_exact_with_timeout<R>(
|
||||||
|
client_reader: &mut CryptoReader<R>,
|
||||||
|
buf: &mut [u8],
|
||||||
|
frame_read_timeout: Duration,
|
||||||
|
) -> Result<()>
|
||||||
|
where
|
||||||
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
match timeout(frame_read_timeout, client_reader.read_exact(buf)).await {
|
||||||
|
Ok(Ok(_)) => Ok(()),
|
||||||
|
Ok(Err(e)) => Err(ProxyError::Io(e)),
|
||||||
|
Err(_) => Err(ProxyError::Io(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::TimedOut,
|
||||||
|
"middle-relay client frame read timeout",
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let (len, quickack, raw_len_bytes) = match proto_tag {
|
let (len, quickack, raw_len_bytes) = match proto_tag {
|
||||||
ProtoTag::Abridged => {
|
ProtoTag::Abridged => {
|
||||||
let mut first = [0u8; 1];
|
let mut first = [0u8; 1];
|
||||||
match client_reader.read_exact(&mut first).await {
|
match read_exact_with_timeout(client_reader, &mut first, frame_read_timeout).await {
|
||||||
Ok(_) => {}
|
Ok(()) => {}
|
||||||
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
|
Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
|
||||||
Err(e) => return Err(ProxyError::Io(e)),
|
return Ok(None);
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e),
|
||||||
}
|
}
|
||||||
|
|
||||||
let quickack = (first[0] & 0x80) != 0;
|
let quickack = (first[0] & 0x80) != 0;
|
||||||
let len_words = if (first[0] & 0x7f) == 0x7f {
|
let len_words = if (first[0] & 0x7f) == 0x7f {
|
||||||
let mut ext = [0u8; 3];
|
let mut ext = [0u8; 3];
|
||||||
client_reader
|
read_exact_with_timeout(client_reader, &mut ext, frame_read_timeout).await?;
|
||||||
.read_exact(&mut ext)
|
|
||||||
.await
|
|
||||||
.map_err(ProxyError::Io)?;
|
|
||||||
u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize
|
u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize
|
||||||
} else {
|
} else {
|
||||||
(first[0] & 0x7f) as usize
|
(first[0] & 0x7f) as usize
|
||||||
|
|
@ -707,10 +727,12 @@ where
|
||||||
}
|
}
|
||||||
ProtoTag::Intermediate | ProtoTag::Secure => {
|
ProtoTag::Intermediate | ProtoTag::Secure => {
|
||||||
let mut len_buf = [0u8; 4];
|
let mut len_buf = [0u8; 4];
|
||||||
match client_reader.read_exact(&mut len_buf).await {
|
match read_exact_with_timeout(client_reader, &mut len_buf, frame_read_timeout).await {
|
||||||
Ok(_) => {}
|
Ok(()) => {}
|
||||||
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
|
Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
|
||||||
Err(e) => return Err(ProxyError::Io(e)),
|
return Ok(None);
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e),
|
||||||
}
|
}
|
||||||
let quickack = (len_buf[3] & 0x80) != 0;
|
let quickack = (len_buf[3] & 0x80) != 0;
|
||||||
(
|
(
|
||||||
|
|
@ -769,10 +791,8 @@ where
|
||||||
let chunk_len = remaining.min(chunk_cap);
|
let chunk_len = remaining.min(chunk_cap);
|
||||||
let mut chunk = buffer_pool.get();
|
let mut chunk = buffer_pool.get();
|
||||||
chunk.resize(chunk_len, 0);
|
chunk.resize(chunk_len, 0);
|
||||||
client_reader
|
read_exact_with_timeout(client_reader, &mut chunk[..chunk_len], frame_read_timeout)
|
||||||
.read_exact(&mut chunk[..chunk_len])
|
.await?;
|
||||||
.await
|
|
||||||
.map_err(ProxyError::Io)?;
|
|
||||||
payload.extend_from_slice(&chunk[..chunk_len]);
|
payload.extend_from_slice(&chunk[..chunk_len]);
|
||||||
remaining -= chunk_len;
|
remaining -= chunk_len;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,12 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::crypto::AesCtr;
|
||||||
|
use crate::stats::Stats;
|
||||||
|
use crate::stream::{BufferPool, CryptoReader};
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::AtomicU64;
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
use tokio::io::duplex;
|
||||||
use tokio::time::{Duration as TokioDuration, timeout};
|
use tokio::time::{Duration as TokioDuration, timeout};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -101,3 +109,93 @@ fn desync_dedup_cache_is_bounded() {
|
||||||
"already tracked key inside dedup window must stay suppressed"
|
"already tracked key inside dedup window must stay suppressed"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn make_forensics_state() -> RelayForensicsState {
|
||||||
|
RelayForensicsState {
|
||||||
|
trace_id: 1,
|
||||||
|
conn_id: 2,
|
||||||
|
user: "test-user".to_string(),
|
||||||
|
peer: "127.0.0.1:50000".parse::<SocketAddr>().unwrap(),
|
||||||
|
peer_hash: 3,
|
||||||
|
started_at: Instant::now(),
|
||||||
|
bytes_c2me: 0,
|
||||||
|
bytes_me2c: Arc::new(AtomicU64::new(0)),
|
||||||
|
desync_all_full: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_crypto_reader(reader: tokio::io::DuplexStream) -> CryptoReader<tokio::io::DuplexStream> {
|
||||||
|
let key = [0u8; 32];
|
||||||
|
let iv = 0u128;
|
||||||
|
CryptoReader::new(reader, AesCtr::new(&key, iv))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
|
||||||
|
let key = [0u8; 32];
|
||||||
|
let iv = 0u128;
|
||||||
|
let mut cipher = AesCtr::new(&key, iv);
|
||||||
|
cipher.encrypt(plaintext)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_times_out_on_header_stall() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
let (reader, _writer) = duplex(1024);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
let result = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
1024,
|
||||||
|
TokioDuration::from_millis(25),
|
||||||
|
&buffer_pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut),
|
||||||
|
"stalled header read must time out"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_times_out_on_payload_stall() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
let (reader, mut writer) = duplex(1024);
|
||||||
|
let encrypted_len = encrypt_for_reader(&[8, 0, 0, 0]);
|
||||||
|
writer.write_all(&encrypted_len).await.unwrap();
|
||||||
|
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
let result = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
1024,
|
||||||
|
TokioDuration::from_millis(25),
|
||||||
|
&buffer_pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut),
|
||||||
|
"stalled payload body read must time out"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue