diff --git a/Cargo.lock b/Cargo.lock index b4cfbca..89eefd6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2093,7 +2093,7 @@ dependencies = [ [[package]] name = "telemt" -version = "3.3.18" +version = "3.3.19" dependencies = [ "aes", "anyhow", diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 0aaa016..ba01c74 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -11,6 +11,7 @@ use bytes::{Bytes, BytesMut}; use dashmap::DashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::{mpsc, oneshot, watch}; +use tokio::time::timeout; use tracing::{debug, trace, warn}; use crate::config::ProxyConfig; @@ -581,6 +582,7 @@ where &mut crypto_reader, proto_tag, frame_limit, + Duration::from_secs(config.timeouts.client_handshake.max(1)), &buffer_pool, &forensics, &mut frame_counter, @@ -670,6 +672,7 @@ async fn read_client_payload( client_reader: &mut CryptoReader, proto_tag: ProtoTag, max_frame: usize, + frame_read_timeout: Duration, buffer_pool: &Arc, forensics: &RelayForensicsState, frame_counter: &mut u64, @@ -678,23 +681,40 @@ async fn read_client_payload( where R: AsyncRead + Unpin + Send + 'static, { + async fn read_exact_with_timeout( + client_reader: &mut CryptoReader, + buf: &mut [u8], + frame_read_timeout: Duration, + ) -> Result<()> + where + R: AsyncRead + Unpin + Send + 'static, + { + match timeout(frame_read_timeout, client_reader.read_exact(buf)).await { + Ok(Ok(_)) => Ok(()), + Ok(Err(e)) => Err(ProxyError::Io(e)), + Err(_) => Err(ProxyError::Io(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "middle-relay client frame read timeout", + ))), + } + } + loop { let (len, quickack, raw_len_bytes) = match proto_tag { ProtoTag::Abridged => { let mut first = [0u8; 1]; - match client_reader.read_exact(&mut first).await { - Ok(_) => {} - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ProxyError::Io(e)), + match read_exact_with_timeout(client_reader, &mut first, frame_read_timeout).await { + Ok(()) => {} + Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + return Ok(None); + } + Err(e) => return Err(e), } let quickack = (first[0] & 0x80) != 0; let len_words = if (first[0] & 0x7f) == 0x7f { let mut ext = [0u8; 3]; - client_reader - .read_exact(&mut ext) - .await - .map_err(ProxyError::Io)?; + read_exact_with_timeout(client_reader, &mut ext, frame_read_timeout).await?; u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize } else { (first[0] & 0x7f) as usize @@ -707,10 +727,12 @@ where } ProtoTag::Intermediate | ProtoTag::Secure => { let mut len_buf = [0u8; 4]; - match client_reader.read_exact(&mut len_buf).await { - Ok(_) => {} - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ProxyError::Io(e)), + match read_exact_with_timeout(client_reader, &mut len_buf, frame_read_timeout).await { + Ok(()) => {} + Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + return Ok(None); + } + Err(e) => return Err(e), } let quickack = (len_buf[3] & 0x80) != 0; ( @@ -769,10 +791,8 @@ where let chunk_len = remaining.min(chunk_cap); let mut chunk = buffer_pool.get(); chunk.resize(chunk_len, 0); - client_reader - .read_exact(&mut chunk[..chunk_len]) - .await - .map_err(ProxyError::Io)?; + read_exact_with_timeout(client_reader, &mut chunk[..chunk_len], frame_read_timeout) + .await?; payload.extend_from_slice(&chunk[..chunk_len]); remaining -= chunk_len; } diff --git a/src/proxy/middle_relay_security_tests.rs b/src/proxy/middle_relay_security_tests.rs index d7d1243..a2f89f8 100644 --- a/src/proxy/middle_relay_security_tests.rs +++ b/src/proxy/middle_relay_security_tests.rs @@ -1,4 +1,12 @@ use super::*; +use crate::crypto::AesCtr; +use crate::stats::Stats; +use crate::stream::{BufferPool, CryptoReader}; +use std::net::SocketAddr; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; +use tokio::io::AsyncWriteExt; +use tokio::io::duplex; use tokio::time::{Duration as TokioDuration, timeout}; #[test] @@ -101,3 +109,93 @@ fn desync_dedup_cache_is_bounded() { "already tracked key inside dedup window must stay suppressed" ); } + +fn make_forensics_state() -> RelayForensicsState { + RelayForensicsState { + trace_id: 1, + conn_id: 2, + user: "test-user".to_string(), + peer: "127.0.0.1:50000".parse::().unwrap(), + peer_hash: 3, + started_at: Instant::now(), + bytes_c2me: 0, + bytes_me2c: Arc::new(AtomicU64::new(0)), + desync_all_full: false, + } +} + +fn make_crypto_reader(reader: tokio::io::DuplexStream) -> CryptoReader { + let key = [0u8; 32]; + let iv = 0u128; + CryptoReader::new(reader, AesCtr::new(&key, iv)) +} + +fn encrypt_for_reader(plaintext: &[u8]) -> Vec { + let key = [0u8; 32]; + let iv = 0u128; + let mut cipher = AesCtr::new(&key, iv); + cipher.encrypt(plaintext) +} + +#[tokio::test] +async fn read_client_payload_times_out_on_header_stall() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + let (reader, _writer) = duplex(1024); + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let result = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + TokioDuration::from_millis(25), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut), + "stalled header read must time out" + ); +} + +#[tokio::test] +async fn read_client_payload_times_out_on_payload_stall() { + let _guard = desync_dedup_test_lock() + .lock() + .expect("middle relay test lock must be available"); + let (reader, mut writer) = duplex(1024); + let encrypted_len = encrypt_for_reader(&[8, 0, 0, 0]); + writer.write_all(&encrypted_len).await.unwrap(); + + let mut crypto_reader = make_crypto_reader(reader); + let buffer_pool = Arc::new(BufferPool::new()); + let stats = Stats::new(); + let forensics = make_forensics_state(); + let mut frame_counter = 0; + + let result = read_client_payload( + &mut crypto_reader, + ProtoTag::Intermediate, + 1024, + TokioDuration::from_millis(25), + &buffer_pool, + &forensics, + &mut frame_counter, + &stats, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut), + "stalled payload body read must time out" + ); +}