mirror of
https://github.com/telemt/telemt.git
synced 2026-04-25 14:34:10 +03:00
Format
This commit is contained in:
@@ -38,16 +38,13 @@ use bytes::{Bytes, BytesMut};
|
||||
use std::io::{self, Error, ErrorKind, Result};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||
|
||||
use super::state::{HeaderBuffer, StreamState, WriteBuffer, YieldBuffer};
|
||||
use crate::protocol::constants::{
|
||||
MAX_TLS_PLAINTEXT_SIZE,
|
||||
TLS_VERSION,
|
||||
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER,
|
||||
TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT,
|
||||
MAX_TLS_CIPHERTEXT_SIZE,
|
||||
MAX_TLS_CIPHERTEXT_SIZE, MAX_TLS_PLAINTEXT_SIZE, TLS_RECORD_ALERT, TLS_RECORD_APPLICATION,
|
||||
TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, TLS_VERSION,
|
||||
};
|
||||
use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer};
|
||||
|
||||
// ============= Constants =============
|
||||
|
||||
@@ -84,7 +81,11 @@ impl TlsRecordHeader {
|
||||
let record_type = header[0];
|
||||
let version = [header[1], header[2]];
|
||||
let length = u16::from_be_bytes([header[3], header[4]]);
|
||||
Some(Self { record_type, version, length })
|
||||
Some(Self {
|
||||
record_type,
|
||||
version,
|
||||
length,
|
||||
})
|
||||
}
|
||||
|
||||
/// Validate the header.
|
||||
@@ -111,8 +112,7 @@ impl TlsRecordHeader {
|
||||
ErrorKind::InvalidData,
|
||||
format!(
|
||||
"invalid TLS version for record type 0x{:02x}: {:02x?}",
|
||||
self.record_type,
|
||||
self.version
|
||||
self.record_type, self.version
|
||||
),
|
||||
));
|
||||
}
|
||||
@@ -129,8 +129,7 @@ impl TlsRecordHeader {
|
||||
ErrorKind::InvalidData,
|
||||
format!(
|
||||
"invalid TLS application data length: {} (min 1, max {})",
|
||||
len,
|
||||
MAX_TLS_CIPHERTEXT_SIZE
|
||||
len, MAX_TLS_CIPHERTEXT_SIZE
|
||||
),
|
||||
));
|
||||
}
|
||||
@@ -160,8 +159,7 @@ impl TlsRecordHeader {
|
||||
ErrorKind::InvalidData,
|
||||
format!(
|
||||
"invalid TLS handshake length: {} (min 4, max {})",
|
||||
len,
|
||||
MAX_TLS_PLAINTEXT_SIZE
|
||||
len, MAX_TLS_PLAINTEXT_SIZE
|
||||
),
|
||||
));
|
||||
}
|
||||
@@ -212,14 +210,10 @@ enum TlsReaderState {
|
||||
},
|
||||
|
||||
/// Have buffered data ready to yield to caller
|
||||
Yielding {
|
||||
buffer: YieldBuffer,
|
||||
},
|
||||
Yielding { buffer: YieldBuffer },
|
||||
|
||||
/// Stream encountered an error and cannot be used
|
||||
Poisoned {
|
||||
error: Option<io::Error>,
|
||||
},
|
||||
Poisoned { error: Option<io::Error> },
|
||||
}
|
||||
|
||||
impl StreamState for TlsReaderState {
|
||||
@@ -287,28 +281,41 @@ pub struct FakeTlsReader<R> {
|
||||
|
||||
impl<R> FakeTlsReader<R> {
|
||||
pub fn new(upstream: R) -> Self {
|
||||
Self { upstream, state: TlsReaderState::Idle }
|
||||
Self {
|
||||
upstream,
|
||||
state: TlsReaderState::Idle,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_ref(&self) -> &R { &self.upstream }
|
||||
pub fn get_mut(&mut self) -> &mut R { &mut self.upstream }
|
||||
pub fn into_inner(self) -> R { self.upstream }
|
||||
pub fn get_ref(&self) -> &R {
|
||||
&self.upstream
|
||||
}
|
||||
pub fn get_mut(&mut self) -> &mut R {
|
||||
&mut self.upstream
|
||||
}
|
||||
pub fn into_inner(self) -> R {
|
||||
self.upstream
|
||||
}
|
||||
|
||||
pub fn into_inner_with_pending_plaintext(mut self) -> (R, Vec<u8>) {
|
||||
let pending = match std::mem::replace(&mut self.state, TlsReaderState::Idle) {
|
||||
TlsReaderState::Yielding { buffer } => buffer.as_slice().to_vec(),
|
||||
TlsReaderState::ReadingBody { record_type, buffer, .. }
|
||||
if record_type == TLS_RECORD_APPLICATION =>
|
||||
{
|
||||
buffer.to_vec()
|
||||
}
|
||||
TlsReaderState::ReadingBody {
|
||||
record_type,
|
||||
buffer,
|
||||
..
|
||||
} if record_type == TLS_RECORD_APPLICATION => buffer.to_vec(),
|
||||
_ => Vec::new(),
|
||||
};
|
||||
(self.upstream, pending)
|
||||
}
|
||||
|
||||
pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() }
|
||||
pub fn state_name(&self) -> &'static str { self.state.state_name() }
|
||||
pub fn is_poisoned(&self) -> bool {
|
||||
self.state.is_poisoned()
|
||||
}
|
||||
pub fn state_name(&self) -> &'static str {
|
||||
self.state.state_name()
|
||||
}
|
||||
|
||||
fn poison(&mut self, error: io::Error) {
|
||||
self.state = TlsReaderState::Poisoned { error: Some(error) };
|
||||
@@ -316,9 +323,9 @@ impl<R> FakeTlsReader<R> {
|
||||
|
||||
fn take_poison_error(&mut self) -> io::Error {
|
||||
match &mut self.state {
|
||||
TlsReaderState::Poisoned { error } => error.take().unwrap_or_else(|| {
|
||||
io::Error::other("stream previously poisoned")
|
||||
}),
|
||||
TlsReaderState::Poisoned { error } => error
|
||||
.take()
|
||||
.unwrap_or_else(|| io::Error::other("stream previously poisoned")),
|
||||
_ => io::Error::other("stream not poisoned"),
|
||||
}
|
||||
}
|
||||
@@ -353,9 +360,8 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
||||
// Poisoned state: always return the stored error
|
||||
TlsReaderState::Poisoned { error } => {
|
||||
this.state = TlsReaderState::Poisoned { error: None };
|
||||
let err = error.unwrap_or_else(|| {
|
||||
io::Error::other("stream previously poisoned")
|
||||
});
|
||||
let err =
|
||||
error.unwrap_or_else(|| io::Error::other("stream previously poisoned"));
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
|
||||
@@ -428,12 +434,20 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
||||
}
|
||||
|
||||
// Read TLS payload
|
||||
TlsReaderState::ReadingBody { record_type, length, mut buffer } => {
|
||||
TlsReaderState::ReadingBody {
|
||||
record_type,
|
||||
length,
|
||||
mut buffer,
|
||||
} => {
|
||||
let result = poll_read_body(&mut this.upstream, cx, &mut buffer, length);
|
||||
|
||||
match result {
|
||||
BodyPollResult::Pending => {
|
||||
this.state = TlsReaderState::ReadingBody { record_type, length, buffer };
|
||||
this.state = TlsReaderState::ReadingBody {
|
||||
record_type,
|
||||
length,
|
||||
buffer,
|
||||
};
|
||||
return Poll::Pending;
|
||||
}
|
||||
BodyPollResult::Error(e) => {
|
||||
@@ -469,7 +483,10 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
||||
|
||||
TLS_RECORD_HANDSHAKE => {
|
||||
// After FakeTLS handshake is done, we do not expect any Handshake records.
|
||||
let err = Error::new(ErrorKind::InvalidData, "unexpected TLS handshake record");
|
||||
let err = Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"unexpected TLS handshake record",
|
||||
);
|
||||
this.poison(Error::new(err.kind(), err.to_string()));
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
@@ -528,7 +545,10 @@ fn poll_read_header<R: AsyncRead + Unpin>(
|
||||
let header_bytes = *header.as_array();
|
||||
match TlsRecordHeader::parse(&header_bytes) {
|
||||
Some(h) => HeaderPollResult::Complete(h),
|
||||
None => HeaderPollResult::Error(Error::new(ErrorKind::InvalidData, "failed to parse TLS header")),
|
||||
None => HeaderPollResult::Error(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"failed to parse TLS header",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -614,9 +634,7 @@ enum TlsWriterState {
|
||||
},
|
||||
|
||||
/// Stream encountered an error and cannot be used
|
||||
Poisoned {
|
||||
error: Option<io::Error>,
|
||||
},
|
||||
Poisoned { error: Option<io::Error> },
|
||||
}
|
||||
|
||||
impl StreamState for TlsWriterState {
|
||||
@@ -652,15 +670,28 @@ pub struct FakeTlsWriter<W> {
|
||||
|
||||
impl<W> FakeTlsWriter<W> {
|
||||
pub fn new(upstream: W) -> Self {
|
||||
Self { upstream, state: TlsWriterState::Idle }
|
||||
Self {
|
||||
upstream,
|
||||
state: TlsWriterState::Idle,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_ref(&self) -> &W { &self.upstream }
|
||||
pub fn get_mut(&mut self) -> &mut W { &mut self.upstream }
|
||||
pub fn into_inner(self) -> W { self.upstream }
|
||||
pub fn get_ref(&self) -> &W {
|
||||
&self.upstream
|
||||
}
|
||||
pub fn get_mut(&mut self) -> &mut W {
|
||||
&mut self.upstream
|
||||
}
|
||||
pub fn into_inner(self) -> W {
|
||||
self.upstream
|
||||
}
|
||||
|
||||
pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() }
|
||||
pub fn state_name(&self) -> &'static str { self.state.state_name() }
|
||||
pub fn is_poisoned(&self) -> bool {
|
||||
self.state.is_poisoned()
|
||||
}
|
||||
pub fn state_name(&self) -> &'static str {
|
||||
self.state.state_name()
|
||||
}
|
||||
|
||||
pub fn has_pending(&self) -> bool {
|
||||
matches!(&self.state, TlsWriterState::WritingRecord { record, .. } if !record.is_empty())
|
||||
@@ -672,9 +703,9 @@ impl<W> FakeTlsWriter<W> {
|
||||
|
||||
fn take_poison_error(&mut self) -> io::Error {
|
||||
match &mut self.state {
|
||||
TlsWriterState::Poisoned { error } => error.take().unwrap_or_else(|| {
|
||||
io::Error::other("stream previously poisoned")
|
||||
}),
|
||||
TlsWriterState::Poisoned { error } => error
|
||||
.take()
|
||||
.unwrap_or_else(|| io::Error::other("stream previously poisoned")),
|
||||
_ => io::Error::other("stream not poisoned"),
|
||||
}
|
||||
}
|
||||
@@ -725,11 +756,7 @@ impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
// Take ownership of state to avoid borrow conflicts.
|
||||
@@ -738,17 +765,21 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
match state {
|
||||
TlsWriterState::Poisoned { error } => {
|
||||
this.state = TlsWriterState::Poisoned { error: None };
|
||||
let err = error.unwrap_or_else(|| {
|
||||
Error::other("stream previously poisoned")
|
||||
});
|
||||
let err = error.unwrap_or_else(|| Error::other("stream previously poisoned"));
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
|
||||
TlsWriterState::WritingRecord { mut record, payload_size } => {
|
||||
TlsWriterState::WritingRecord {
|
||||
mut record,
|
||||
payload_size,
|
||||
} => {
|
||||
// Finish writing previous record before accepting new bytes.
|
||||
match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
|
||||
FlushResult::Pending => {
|
||||
this.state = TlsWriterState::WritingRecord { record, payload_size };
|
||||
this.state = TlsWriterState::WritingRecord {
|
||||
record,
|
||||
payload_size,
|
||||
};
|
||||
return Poll::Pending;
|
||||
}
|
||||
FlushResult::Error(e) => {
|
||||
@@ -780,9 +811,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
let record_data = Self::build_record(chunk);
|
||||
|
||||
match Pin::new(&mut this.upstream).poll_write(cx, &record_data) {
|
||||
Poll::Ready(Ok(n)) if n == record_data.len() => {
|
||||
Poll::Ready(Ok(chunk_size))
|
||||
}
|
||||
Poll::Ready(Ok(n)) if n == record_data.len() => Poll::Ready(Ok(chunk_size)),
|
||||
|
||||
Poll::Ready(Ok(n)) => {
|
||||
// Partial write of the record: store remainder.
|
||||
@@ -827,27 +856,29 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
match state {
|
||||
TlsWriterState::Poisoned { error } => {
|
||||
this.state = TlsWriterState::Poisoned { error: None };
|
||||
let err = error.unwrap_or_else(|| {
|
||||
Error::other("stream previously poisoned")
|
||||
});
|
||||
let err = error.unwrap_or_else(|| Error::other("stream previously poisoned"));
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
|
||||
TlsWriterState::WritingRecord { mut record, payload_size } => {
|
||||
match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
|
||||
FlushResult::Pending => {
|
||||
this.state = TlsWriterState::WritingRecord { record, payload_size };
|
||||
return Poll::Pending;
|
||||
}
|
||||
FlushResult::Error(e) => {
|
||||
this.poison(Error::new(e.kind(), e.to_string()));
|
||||
return Poll::Ready(Err(e));
|
||||
}
|
||||
FlushResult::Complete(_) => {
|
||||
this.state = TlsWriterState::Idle;
|
||||
}
|
||||
TlsWriterState::WritingRecord {
|
||||
mut record,
|
||||
payload_size,
|
||||
} => match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
|
||||
FlushResult::Pending => {
|
||||
this.state = TlsWriterState::WritingRecord {
|
||||
record,
|
||||
payload_size,
|
||||
};
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
FlushResult::Error(e) => {
|
||||
this.poison(Error::new(e.kind(), e.to_string()));
|
||||
return Poll::Ready(Err(e));
|
||||
}
|
||||
FlushResult::Complete(_) => {
|
||||
this.state = TlsWriterState::Idle;
|
||||
}
|
||||
},
|
||||
|
||||
TlsWriterState::Idle => {
|
||||
this.state = TlsWriterState::Idle;
|
||||
@@ -863,7 +894,10 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
|
||||
|
||||
match state {
|
||||
TlsWriterState::WritingRecord { mut record, payload_size: _ } => {
|
||||
TlsWriterState::WritingRecord {
|
||||
mut record,
|
||||
payload_size: _,
|
||||
} => {
|
||||
// Best-effort flush (do not block shutdown forever).
|
||||
let _ = Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record);
|
||||
this.state = TlsWriterState::Idle;
|
||||
@@ -905,10 +939,10 @@ mod size_adversarial_tests;
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::VecDeque;
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
|
||||
// ============= Test Helpers =============
|
||||
|
||||
|
||||
/// Build a valid TLS Application Data record
|
||||
fn build_tls_record(data: &[u8]) -> Vec<u8> {
|
||||
let mut record = vec![
|
||||
@@ -921,24 +955,25 @@ mod tests {
|
||||
record.extend_from_slice(data);
|
||||
record
|
||||
}
|
||||
|
||||
|
||||
/// Build a Change Cipher Spec record
|
||||
fn build_ccs_record() -> Vec<u8> {
|
||||
vec![
|
||||
TLS_RECORD_CHANGE_CIPHER,
|
||||
TLS_VERSION[0],
|
||||
TLS_VERSION[1],
|
||||
0x00, 0x01, // length = 1
|
||||
0x01, // CCS byte
|
||||
0x00,
|
||||
0x01, // length = 1
|
||||
0x01, // CCS byte
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
/// Mock reader that returns data in chunks
|
||||
struct ChunkedReader {
|
||||
data: VecDeque<u8>,
|
||||
chunk_size: usize,
|
||||
}
|
||||
|
||||
|
||||
impl ChunkedReader {
|
||||
fn new(data: &[u8], chunk_size: usize) -> Self {
|
||||
Self {
|
||||
@@ -947,7 +982,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl AsyncRead for ChunkedReader {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
@@ -957,92 +992,92 @@ mod tests {
|
||||
if self.data.is_empty() {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
|
||||
let to_read = self.chunk_size.min(self.data.len()).min(buf.remaining());
|
||||
for _ in 0..to_read {
|
||||
if let Some(byte) = self.data.pop_front() {
|
||||
buf.put_slice(&[byte]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ============= FakeTlsReader Tests =============
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_single_record() {
|
||||
let payload = b"Hello, TLS!";
|
||||
let record = build_tls_record(payload);
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&record, 100);
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
|
||||
assert_eq!(&buf[..], payload);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_multiple_records() {
|
||||
let payload1 = b"First record";
|
||||
let payload2 = b"Second record";
|
||||
|
||||
|
||||
let mut data = build_tls_record(payload1);
|
||||
data.extend_from_slice(&build_tls_record(payload2));
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&data, 100);
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let buf1 = tls_reader.read_exact(payload1.len()).await.unwrap();
|
||||
assert_eq!(&buf1[..], payload1);
|
||||
|
||||
|
||||
let buf2 = tls_reader.read_exact(payload2.len()).await.unwrap();
|
||||
assert_eq!(&buf2[..], payload2);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_partial_header() {
|
||||
// Read header byte by byte
|
||||
let payload = b"Test";
|
||||
let record = build_tls_record(payload);
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&record, 1); // 1 byte at a time!
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
|
||||
|
||||
|
||||
assert_eq!(&buf[..], payload);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_partial_body() {
|
||||
let payload = b"This is a longer payload that will be read in parts";
|
||||
let record = build_tls_record(payload);
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&record, 7); // Awkward chunk size
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
|
||||
|
||||
|
||||
assert_eq!(&buf[..], payload);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_skip_ccs() {
|
||||
// CCS record followed by application data
|
||||
let mut data = build_ccs_record();
|
||||
let payload = b"After CCS";
|
||||
data.extend_from_slice(&build_tls_record(payload));
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&data, 100);
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
|
||||
|
||||
|
||||
assert_eq!(&buf[..], payload);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_multiple_ccs() {
|
||||
// Multiple CCS records
|
||||
@@ -1050,127 +1085,127 @@ mod tests {
|
||||
data.extend_from_slice(&build_ccs_record());
|
||||
let payload = b"After multiple CCS";
|
||||
data.extend_from_slice(&build_tls_record(payload));
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&data, 3); // Small chunks
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
|
||||
|
||||
|
||||
assert_eq!(&buf[..], payload);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_eof() {
|
||||
let reader = ChunkedReader::new(&[], 100);
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let mut buf = vec![0u8; 10];
|
||||
let read = tls_reader.read(&mut buf).await.unwrap();
|
||||
|
||||
|
||||
assert_eq!(read, 0);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_state_names() {
|
||||
let reader = ChunkedReader::new(&[], 100);
|
||||
let tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
assert_eq!(tls_reader.state_name(), "Idle");
|
||||
assert!(!tls_reader.is_poisoned());
|
||||
}
|
||||
|
||||
|
||||
// ============= FakeTlsWriter Tests =============
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_writer_single_write() {
|
||||
let (client, mut server) = duplex(4096);
|
||||
let mut writer = FakeTlsWriter::new(client);
|
||||
|
||||
|
||||
let payload = b"Hello, TLS!";
|
||||
writer.write_all(payload).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
|
||||
// Read the TLS record from server
|
||||
let mut header = [0u8; 5];
|
||||
server.read_exact(&mut header).await.unwrap();
|
||||
|
||||
|
||||
assert_eq!(header[0], TLS_RECORD_APPLICATION);
|
||||
assert_eq!(&header[1..3], &TLS_VERSION);
|
||||
|
||||
|
||||
let length = u16::from_be_bytes([header[3], header[4]]) as usize;
|
||||
assert_eq!(length, payload.len());
|
||||
|
||||
|
||||
let mut body = vec![0u8; length];
|
||||
server.read_exact(&mut body).await.unwrap();
|
||||
assert_eq!(&body, payload);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_writer_large_data_chunking() {
|
||||
let (client, mut server) = duplex(65536);
|
||||
let mut writer = FakeTlsWriter::new(client);
|
||||
|
||||
|
||||
// Write data larger than MAX_TLS_PAYLOAD
|
||||
let payload: Vec<u8> = (0..20000).map(|i| (i % 256) as u8).collect();
|
||||
writer.write_all(&payload).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
|
||||
// Read back - should be multiple records
|
||||
let mut received = Vec::new();
|
||||
let mut records_count = 0;
|
||||
|
||||
|
||||
while received.len() < payload.len() {
|
||||
let mut header = [0u8; 5];
|
||||
if server.read_exact(&mut header).await.is_err() {
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
assert_eq!(header[0], TLS_RECORD_APPLICATION);
|
||||
records_count += 1;
|
||||
|
||||
|
||||
let length = u16::from_be_bytes([header[3], header[4]]) as usize;
|
||||
assert!(length <= MAX_TLS_PAYLOAD);
|
||||
|
||||
|
||||
let mut body = vec![0u8; length];
|
||||
server.read_exact(&mut body).await.unwrap();
|
||||
received.extend_from_slice(&body);
|
||||
}
|
||||
|
||||
|
||||
assert_eq!(received, payload);
|
||||
assert!(records_count >= 2); // Should have multiple records
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_stream_roundtrip() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
|
||||
let mut writer = FakeTlsWriter::new(client);
|
||||
let mut reader = FakeTlsReader::new(server);
|
||||
|
||||
|
||||
let original = b"Hello, fake TLS!";
|
||||
writer.write_all_tls(original).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
|
||||
let received = reader.read_exact(original.len()).await.unwrap();
|
||||
assert_eq!(&received[..], original);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_stream_roundtrip_large() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
|
||||
let mut writer = FakeTlsWriter::new(client);
|
||||
let mut reader = FakeTlsReader::new(server);
|
||||
|
||||
|
||||
let original: Vec<u8> = (0..50000).map(|i| (i % 256) as u8).collect();
|
||||
|
||||
|
||||
// Write in background
|
||||
let write_data = original.clone();
|
||||
let write_handle = tokio::spawn(async move {
|
||||
writer.write_all_tls(&write_data).await.unwrap();
|
||||
writer.shutdown().await.unwrap();
|
||||
});
|
||||
|
||||
|
||||
// Read
|
||||
let mut received = Vec::new();
|
||||
let mut buf = vec![0u8; 1024];
|
||||
@@ -1181,87 +1216,95 @@ mod tests {
|
||||
}
|
||||
received.extend_from_slice(&buf[..n]);
|
||||
}
|
||||
|
||||
|
||||
write_handle.await.unwrap();
|
||||
assert_eq!(received, original);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_writer_state_names() {
|
||||
let (client, _server) = duplex(4096);
|
||||
let writer = FakeTlsWriter::new(client);
|
||||
|
||||
|
||||
assert_eq!(writer.state_name(), "Idle");
|
||||
assert!(!writer.is_poisoned());
|
||||
assert!(!writer.has_pending());
|
||||
}
|
||||
|
||||
|
||||
// ============= Error Handling Tests =============
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_invalid_version() {
|
||||
let invalid_record = vec![
|
||||
TLS_RECORD_APPLICATION,
|
||||
0x04, 0x00, // Invalid version
|
||||
0x00, 0x05, // length = 5
|
||||
0x01, 0x02, 0x03, 0x04, 0x05,
|
||||
0x04,
|
||||
0x00, // Invalid version
|
||||
0x00,
|
||||
0x05, // length = 5
|
||||
0x01,
|
||||
0x02,
|
||||
0x03,
|
||||
0x04,
|
||||
0x05,
|
||||
];
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&invalid_record, 100);
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let mut buf = vec![0u8; 5];
|
||||
let result = tls_reader.read(&mut buf).await;
|
||||
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(tls_reader.is_poisoned());
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_unexpected_eof_header() {
|
||||
// Partial header
|
||||
let partial = vec![TLS_RECORD_APPLICATION, 0x03];
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&partial, 100);
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let mut buf = vec![0u8; 10];
|
||||
let result = tls_reader.read(&mut buf).await;
|
||||
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_unexpected_eof_body() {
|
||||
// Valid header but truncated body
|
||||
let mut record = vec![
|
||||
TLS_RECORD_APPLICATION,
|
||||
TLS_VERSION[0], TLS_VERSION[1],
|
||||
0x00, 0x10, // length = 16
|
||||
TLS_VERSION[0],
|
||||
TLS_VERSION[1],
|
||||
0x00,
|
||||
0x10, // length = 16
|
||||
];
|
||||
record.extend_from_slice(&[0u8; 8]); // Only 8 bytes of body
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&record, 100);
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let mut buf = vec![0u8; 16];
|
||||
let result = tls_reader.read(&mut buf).await;
|
||||
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
|
||||
// ============= Header Parsing Tests =============
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_tls_record_header_parse() {
|
||||
let header = [0x17, 0x03, 0x03, 0x01, 0x00];
|
||||
let parsed = TlsRecordHeader::parse(&header).unwrap();
|
||||
|
||||
|
||||
assert_eq!(parsed.record_type, TLS_RECORD_APPLICATION);
|
||||
assert_eq!(parsed.version, TLS_VERSION);
|
||||
assert_eq!(parsed.length, 256);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_tls_record_header_validate() {
|
||||
let valid = TlsRecordHeader {
|
||||
@@ -1270,14 +1313,14 @@ mod tests {
|
||||
length: 100,
|
||||
};
|
||||
assert!(valid.validate().is_ok());
|
||||
|
||||
|
||||
let invalid_version = TlsRecordHeader {
|
||||
record_type: TLS_RECORD_APPLICATION,
|
||||
version: [0x04, 0x00],
|
||||
length: 100,
|
||||
};
|
||||
assert!(invalid_version.validate().is_err());
|
||||
|
||||
|
||||
let too_large = TlsRecordHeader {
|
||||
record_type: TLS_RECORD_APPLICATION,
|
||||
version: TLS_VERSION,
|
||||
@@ -1285,7 +1328,7 @@ mod tests {
|
||||
};
|
||||
assert!(too_large.validate().is_err());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_tls_record_header_to_bytes() {
|
||||
let header = TlsRecordHeader {
|
||||
@@ -1293,7 +1336,7 @@ mod tests {
|
||||
version: TLS_VERSION,
|
||||
length: 0x1234,
|
||||
};
|
||||
|
||||
|
||||
let bytes = header.to_bytes();
|
||||
assert_eq!(bytes, [0x17, 0x03, 0x03, 0x12, 0x34]);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user