mirror of
https://github.com/telemt/telemt.git
synced 2026-05-02 01:44:10 +03:00
feat(proxy): implement timeout handling for client payload reads and add corresponding tests
This commit is contained in:
@@ -11,6 +11,7 @@ use bytes::{Bytes, BytesMut};
|
||||
use dashmap::DashMap;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::sync::{mpsc, oneshot, watch};
|
||||
use tokio::time::timeout;
|
||||
use tracing::{debug, trace, warn};
|
||||
|
||||
use crate::config::ProxyConfig;
|
||||
@@ -581,6 +582,7 @@ where
|
||||
&mut crypto_reader,
|
||||
proto_tag,
|
||||
frame_limit,
|
||||
Duration::from_secs(config.timeouts.client_handshake.max(1)),
|
||||
&buffer_pool,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
@@ -670,6 +672,7 @@ async fn read_client_payload<R>(
|
||||
client_reader: &mut CryptoReader<R>,
|
||||
proto_tag: ProtoTag,
|
||||
max_frame: usize,
|
||||
frame_read_timeout: Duration,
|
||||
buffer_pool: &Arc<BufferPool>,
|
||||
forensics: &RelayForensicsState,
|
||||
frame_counter: &mut u64,
|
||||
@@ -678,23 +681,40 @@ async fn read_client_payload<R>(
|
||||
where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
async fn read_exact_with_timeout<R>(
|
||||
client_reader: &mut CryptoReader<R>,
|
||||
buf: &mut [u8],
|
||||
frame_read_timeout: Duration,
|
||||
) -> Result<()>
|
||||
where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
{
|
||||
match timeout(frame_read_timeout, client_reader.read_exact(buf)).await {
|
||||
Ok(Ok(_)) => Ok(()),
|
||||
Ok(Err(e)) => Err(ProxyError::Io(e)),
|
||||
Err(_) => Err(ProxyError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::TimedOut,
|
||||
"middle-relay client frame read timeout",
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
let (len, quickack, raw_len_bytes) = match proto_tag {
|
||||
ProtoTag::Abridged => {
|
||||
let mut first = [0u8; 1];
|
||||
match client_reader.read_exact(&mut first).await {
|
||||
Ok(_) => {}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(ProxyError::Io(e)),
|
||||
match read_exact_with_timeout(client_reader, &mut first, frame_read_timeout).await {
|
||||
Ok(()) => {}
|
||||
Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
|
||||
return Ok(None);
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
|
||||
let quickack = (first[0] & 0x80) != 0;
|
||||
let len_words = if (first[0] & 0x7f) == 0x7f {
|
||||
let mut ext = [0u8; 3];
|
||||
client_reader
|
||||
.read_exact(&mut ext)
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
read_exact_with_timeout(client_reader, &mut ext, frame_read_timeout).await?;
|
||||
u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize
|
||||
} else {
|
||||
(first[0] & 0x7f) as usize
|
||||
@@ -707,10 +727,12 @@ where
|
||||
}
|
||||
ProtoTag::Intermediate | ProtoTag::Secure => {
|
||||
let mut len_buf = [0u8; 4];
|
||||
match client_reader.read_exact(&mut len_buf).await {
|
||||
Ok(_) => {}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
|
||||
Err(e) => return Err(ProxyError::Io(e)),
|
||||
match read_exact_with_timeout(client_reader, &mut len_buf, frame_read_timeout).await {
|
||||
Ok(()) => {}
|
||||
Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
|
||||
return Ok(None);
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
let quickack = (len_buf[3] & 0x80) != 0;
|
||||
(
|
||||
@@ -769,10 +791,8 @@ where
|
||||
let chunk_len = remaining.min(chunk_cap);
|
||||
let mut chunk = buffer_pool.get();
|
||||
chunk.resize(chunk_len, 0);
|
||||
client_reader
|
||||
.read_exact(&mut chunk[..chunk_len])
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
read_exact_with_timeout(client_reader, &mut chunk[..chunk_len], frame_read_timeout)
|
||||
.await?;
|
||||
payload.extend_from_slice(&chunk[..chunk_len]);
|
||||
remaining -= chunk_len;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user