This commit is contained in:
Alexey
2026-03-21 15:45:29 +03:00
parent 7a8f946029
commit d7bbb376c9
154 changed files with 6194 additions and 3775 deletions

View File

@@ -38,16 +38,13 @@ use bytes::{Bytes, BytesMut};
use std::io::{self, Error, ErrorKind, Result};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use super::state::{HeaderBuffer, StreamState, WriteBuffer, YieldBuffer};
use crate::protocol::constants::{
MAX_TLS_PLAINTEXT_SIZE,
TLS_VERSION,
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER,
TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT,
MAX_TLS_CIPHERTEXT_SIZE,
MAX_TLS_CIPHERTEXT_SIZE, MAX_TLS_PLAINTEXT_SIZE, TLS_RECORD_ALERT, TLS_RECORD_APPLICATION,
TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, TLS_VERSION,
};
use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer};
// ============= Constants =============
@@ -84,7 +81,11 @@ impl TlsRecordHeader {
let record_type = header[0];
let version = [header[1], header[2]];
let length = u16::from_be_bytes([header[3], header[4]]);
Some(Self { record_type, version, length })
Some(Self {
record_type,
version,
length,
})
}
/// Validate the header.
@@ -111,8 +112,7 @@ impl TlsRecordHeader {
ErrorKind::InvalidData,
format!(
"invalid TLS version for record type 0x{:02x}: {:02x?}",
self.record_type,
self.version
self.record_type, self.version
),
));
}
@@ -129,8 +129,7 @@ impl TlsRecordHeader {
ErrorKind::InvalidData,
format!(
"invalid TLS application data length: {} (min 1, max {})",
len,
MAX_TLS_CIPHERTEXT_SIZE
len, MAX_TLS_CIPHERTEXT_SIZE
),
));
}
@@ -160,8 +159,7 @@ impl TlsRecordHeader {
ErrorKind::InvalidData,
format!(
"invalid TLS handshake length: {} (min 4, max {})",
len,
MAX_TLS_PLAINTEXT_SIZE
len, MAX_TLS_PLAINTEXT_SIZE
),
));
}
@@ -212,14 +210,10 @@ enum TlsReaderState {
},
/// Have buffered data ready to yield to caller
Yielding {
buffer: YieldBuffer,
},
Yielding { buffer: YieldBuffer },
/// Stream encountered an error and cannot be used
Poisoned {
error: Option<io::Error>,
},
Poisoned { error: Option<io::Error> },
}
impl StreamState for TlsReaderState {
@@ -287,28 +281,41 @@ pub struct FakeTlsReader<R> {
impl<R> FakeTlsReader<R> {
pub fn new(upstream: R) -> Self {
Self { upstream, state: TlsReaderState::Idle }
Self {
upstream,
state: TlsReaderState::Idle,
}
}
pub fn get_ref(&self) -> &R { &self.upstream }
pub fn get_mut(&mut self) -> &mut R { &mut self.upstream }
pub fn into_inner(self) -> R { self.upstream }
pub fn get_ref(&self) -> &R {
&self.upstream
}
pub fn get_mut(&mut self) -> &mut R {
&mut self.upstream
}
pub fn into_inner(self) -> R {
self.upstream
}
pub fn into_inner_with_pending_plaintext(mut self) -> (R, Vec<u8>) {
let pending = match std::mem::replace(&mut self.state, TlsReaderState::Idle) {
TlsReaderState::Yielding { buffer } => buffer.as_slice().to_vec(),
TlsReaderState::ReadingBody { record_type, buffer, .. }
if record_type == TLS_RECORD_APPLICATION =>
{
buffer.to_vec()
}
TlsReaderState::ReadingBody {
record_type,
buffer,
..
} if record_type == TLS_RECORD_APPLICATION => buffer.to_vec(),
_ => Vec::new(),
};
(self.upstream, pending)
}
pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() }
pub fn state_name(&self) -> &'static str { self.state.state_name() }
pub fn is_poisoned(&self) -> bool {
self.state.is_poisoned()
}
pub fn state_name(&self) -> &'static str {
self.state.state_name()
}
fn poison(&mut self, error: io::Error) {
self.state = TlsReaderState::Poisoned { error: Some(error) };
@@ -316,9 +323,9 @@ impl<R> FakeTlsReader<R> {
fn take_poison_error(&mut self) -> io::Error {
match &mut self.state {
TlsReaderState::Poisoned { error } => error.take().unwrap_or_else(|| {
io::Error::other("stream previously poisoned")
}),
TlsReaderState::Poisoned { error } => error
.take()
.unwrap_or_else(|| io::Error::other("stream previously poisoned")),
_ => io::Error::other("stream not poisoned"),
}
}
@@ -353,9 +360,8 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
// Poisoned state: always return the stored error
TlsReaderState::Poisoned { error } => {
this.state = TlsReaderState::Poisoned { error: None };
let err = error.unwrap_or_else(|| {
io::Error::other("stream previously poisoned")
});
let err =
error.unwrap_or_else(|| io::Error::other("stream previously poisoned"));
return Poll::Ready(Err(err));
}
@@ -428,12 +434,20 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
}
// Read TLS payload
TlsReaderState::ReadingBody { record_type, length, mut buffer } => {
TlsReaderState::ReadingBody {
record_type,
length,
mut buffer,
} => {
let result = poll_read_body(&mut this.upstream, cx, &mut buffer, length);
match result {
BodyPollResult::Pending => {
this.state = TlsReaderState::ReadingBody { record_type, length, buffer };
this.state = TlsReaderState::ReadingBody {
record_type,
length,
buffer,
};
return Poll::Pending;
}
BodyPollResult::Error(e) => {
@@ -469,7 +483,10 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
TLS_RECORD_HANDSHAKE => {
// After FakeTLS handshake is done, we do not expect any Handshake records.
let err = Error::new(ErrorKind::InvalidData, "unexpected TLS handshake record");
let err = Error::new(
ErrorKind::InvalidData,
"unexpected TLS handshake record",
);
this.poison(Error::new(err.kind(), err.to_string()));
return Poll::Ready(Err(err));
}
@@ -528,7 +545,10 @@ fn poll_read_header<R: AsyncRead + Unpin>(
let header_bytes = *header.as_array();
match TlsRecordHeader::parse(&header_bytes) {
Some(h) => HeaderPollResult::Complete(h),
None => HeaderPollResult::Error(Error::new(ErrorKind::InvalidData, "failed to parse TLS header")),
None => HeaderPollResult::Error(Error::new(
ErrorKind::InvalidData,
"failed to parse TLS header",
)),
}
}
@@ -614,9 +634,7 @@ enum TlsWriterState {
},
/// Stream encountered an error and cannot be used
Poisoned {
error: Option<io::Error>,
},
Poisoned { error: Option<io::Error> },
}
impl StreamState for TlsWriterState {
@@ -652,15 +670,28 @@ pub struct FakeTlsWriter<W> {
impl<W> FakeTlsWriter<W> {
pub fn new(upstream: W) -> Self {
Self { upstream, state: TlsWriterState::Idle }
Self {
upstream,
state: TlsWriterState::Idle,
}
}
pub fn get_ref(&self) -> &W { &self.upstream }
pub fn get_mut(&mut self) -> &mut W { &mut self.upstream }
pub fn into_inner(self) -> W { self.upstream }
pub fn get_ref(&self) -> &W {
&self.upstream
}
pub fn get_mut(&mut self) -> &mut W {
&mut self.upstream
}
pub fn into_inner(self) -> W {
self.upstream
}
pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() }
pub fn state_name(&self) -> &'static str { self.state.state_name() }
pub fn is_poisoned(&self) -> bool {
self.state.is_poisoned()
}
pub fn state_name(&self) -> &'static str {
self.state.state_name()
}
pub fn has_pending(&self) -> bool {
matches!(&self.state, TlsWriterState::WritingRecord { record, .. } if !record.is_empty())
@@ -672,9 +703,9 @@ impl<W> FakeTlsWriter<W> {
fn take_poison_error(&mut self) -> io::Error {
match &mut self.state {
TlsWriterState::Poisoned { error } => error.take().unwrap_or_else(|| {
io::Error::other("stream previously poisoned")
}),
TlsWriterState::Poisoned { error } => error
.take()
.unwrap_or_else(|| io::Error::other("stream previously poisoned")),
_ => io::Error::other("stream not poisoned"),
}
}
@@ -725,11 +756,7 @@ impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
}
impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
let this = self.get_mut();
// Take ownership of state to avoid borrow conflicts.
@@ -738,17 +765,21 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
match state {
TlsWriterState::Poisoned { error } => {
this.state = TlsWriterState::Poisoned { error: None };
let err = error.unwrap_or_else(|| {
Error::other("stream previously poisoned")
});
let err = error.unwrap_or_else(|| Error::other("stream previously poisoned"));
return Poll::Ready(Err(err));
}
TlsWriterState::WritingRecord { mut record, payload_size } => {
TlsWriterState::WritingRecord {
mut record,
payload_size,
} => {
// Finish writing previous record before accepting new bytes.
match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
FlushResult::Pending => {
this.state = TlsWriterState::WritingRecord { record, payload_size };
this.state = TlsWriterState::WritingRecord {
record,
payload_size,
};
return Poll::Pending;
}
FlushResult::Error(e) => {
@@ -780,9 +811,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
let record_data = Self::build_record(chunk);
match Pin::new(&mut this.upstream).poll_write(cx, &record_data) {
Poll::Ready(Ok(n)) if n == record_data.len() => {
Poll::Ready(Ok(chunk_size))
}
Poll::Ready(Ok(n)) if n == record_data.len() => Poll::Ready(Ok(chunk_size)),
Poll::Ready(Ok(n)) => {
// Partial write of the record: store remainder.
@@ -827,27 +856,29 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
match state {
TlsWriterState::Poisoned { error } => {
this.state = TlsWriterState::Poisoned { error: None };
let err = error.unwrap_or_else(|| {
Error::other("stream previously poisoned")
});
let err = error.unwrap_or_else(|| Error::other("stream previously poisoned"));
return Poll::Ready(Err(err));
}
TlsWriterState::WritingRecord { mut record, payload_size } => {
match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
FlushResult::Pending => {
this.state = TlsWriterState::WritingRecord { record, payload_size };
return Poll::Pending;
}
FlushResult::Error(e) => {
this.poison(Error::new(e.kind(), e.to_string()));
return Poll::Ready(Err(e));
}
FlushResult::Complete(_) => {
this.state = TlsWriterState::Idle;
}
TlsWriterState::WritingRecord {
mut record,
payload_size,
} => match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
FlushResult::Pending => {
this.state = TlsWriterState::WritingRecord {
record,
payload_size,
};
return Poll::Pending;
}
}
FlushResult::Error(e) => {
this.poison(Error::new(e.kind(), e.to_string()));
return Poll::Ready(Err(e));
}
FlushResult::Complete(_) => {
this.state = TlsWriterState::Idle;
}
},
TlsWriterState::Idle => {
this.state = TlsWriterState::Idle;
@@ -863,7 +894,10 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
match state {
TlsWriterState::WritingRecord { mut record, payload_size: _ } => {
TlsWriterState::WritingRecord {
mut record,
payload_size: _,
} => {
// Best-effort flush (do not block shutdown forever).
let _ = Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record);
this.state = TlsWriterState::Idle;
@@ -905,10 +939,10 @@ mod size_adversarial_tests;
mod tests {
use super::*;
use std::collections::VecDeque;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
// ============= Test Helpers =============
/// Build a valid TLS Application Data record
fn build_tls_record(data: &[u8]) -> Vec<u8> {
let mut record = vec![
@@ -921,24 +955,25 @@ mod tests {
record.extend_from_slice(data);
record
}
/// Build a Change Cipher Spec record
fn build_ccs_record() -> Vec<u8> {
vec![
TLS_RECORD_CHANGE_CIPHER,
TLS_VERSION[0],
TLS_VERSION[1],
0x00, 0x01, // length = 1
0x01, // CCS byte
0x00,
0x01, // length = 1
0x01, // CCS byte
]
}
/// Mock reader that returns data in chunks
struct ChunkedReader {
data: VecDeque<u8>,
chunk_size: usize,
}
impl ChunkedReader {
fn new(data: &[u8], chunk_size: usize) -> Self {
Self {
@@ -947,7 +982,7 @@ mod tests {
}
}
}
impl AsyncRead for ChunkedReader {
fn poll_read(
mut self: Pin<&mut Self>,
@@ -957,92 +992,92 @@ mod tests {
if self.data.is_empty() {
return Poll::Ready(Ok(()));
}
let to_read = self.chunk_size.min(self.data.len()).min(buf.remaining());
for _ in 0..to_read {
if let Some(byte) = self.data.pop_front() {
buf.put_slice(&[byte]);
}
}
Poll::Ready(Ok(()))
}
}
// ============= FakeTlsReader Tests =============
#[tokio::test]
async fn test_tls_reader_single_record() {
let payload = b"Hello, TLS!";
let record = build_tls_record(payload);
let reader = ChunkedReader::new(&record, 100);
let mut tls_reader = FakeTlsReader::new(reader);
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
assert_eq!(&buf[..], payload);
}
#[tokio::test]
async fn test_tls_reader_multiple_records() {
let payload1 = b"First record";
let payload2 = b"Second record";
let mut data = build_tls_record(payload1);
data.extend_from_slice(&build_tls_record(payload2));
let reader = ChunkedReader::new(&data, 100);
let mut tls_reader = FakeTlsReader::new(reader);
let buf1 = tls_reader.read_exact(payload1.len()).await.unwrap();
assert_eq!(&buf1[..], payload1);
let buf2 = tls_reader.read_exact(payload2.len()).await.unwrap();
assert_eq!(&buf2[..], payload2);
}
#[tokio::test]
async fn test_tls_reader_partial_header() {
// Read header byte by byte
let payload = b"Test";
let record = build_tls_record(payload);
let reader = ChunkedReader::new(&record, 1); // 1 byte at a time!
let mut tls_reader = FakeTlsReader::new(reader);
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
assert_eq!(&buf[..], payload);
}
#[tokio::test]
async fn test_tls_reader_partial_body() {
let payload = b"This is a longer payload that will be read in parts";
let record = build_tls_record(payload);
let reader = ChunkedReader::new(&record, 7); // Awkward chunk size
let mut tls_reader = FakeTlsReader::new(reader);
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
assert_eq!(&buf[..], payload);
}
#[tokio::test]
async fn test_tls_reader_skip_ccs() {
// CCS record followed by application data
let mut data = build_ccs_record();
let payload = b"After CCS";
data.extend_from_slice(&build_tls_record(payload));
let reader = ChunkedReader::new(&data, 100);
let mut tls_reader = FakeTlsReader::new(reader);
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
assert_eq!(&buf[..], payload);
}
#[tokio::test]
async fn test_tls_reader_multiple_ccs() {
// Multiple CCS records
@@ -1050,127 +1085,127 @@ mod tests {
data.extend_from_slice(&build_ccs_record());
let payload = b"After multiple CCS";
data.extend_from_slice(&build_tls_record(payload));
let reader = ChunkedReader::new(&data, 3); // Small chunks
let mut tls_reader = FakeTlsReader::new(reader);
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
assert_eq!(&buf[..], payload);
}
#[tokio::test]
async fn test_tls_reader_eof() {
let reader = ChunkedReader::new(&[], 100);
let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; 10];
let read = tls_reader.read(&mut buf).await.unwrap();
assert_eq!(read, 0);
}
#[tokio::test]
async fn test_tls_reader_state_names() {
let reader = ChunkedReader::new(&[], 100);
let tls_reader = FakeTlsReader::new(reader);
assert_eq!(tls_reader.state_name(), "Idle");
assert!(!tls_reader.is_poisoned());
}
// ============= FakeTlsWriter Tests =============
#[tokio::test]
async fn test_tls_writer_single_write() {
let (client, mut server) = duplex(4096);
let mut writer = FakeTlsWriter::new(client);
let payload = b"Hello, TLS!";
writer.write_all(payload).await.unwrap();
writer.flush().await.unwrap();
// Read the TLS record from server
let mut header = [0u8; 5];
server.read_exact(&mut header).await.unwrap();
assert_eq!(header[0], TLS_RECORD_APPLICATION);
assert_eq!(&header[1..3], &TLS_VERSION);
let length = u16::from_be_bytes([header[3], header[4]]) as usize;
assert_eq!(length, payload.len());
let mut body = vec![0u8; length];
server.read_exact(&mut body).await.unwrap();
assert_eq!(&body, payload);
}
#[tokio::test]
async fn test_tls_writer_large_data_chunking() {
let (client, mut server) = duplex(65536);
let mut writer = FakeTlsWriter::new(client);
// Write data larger than MAX_TLS_PAYLOAD
let payload: Vec<u8> = (0..20000).map(|i| (i % 256) as u8).collect();
writer.write_all(&payload).await.unwrap();
writer.flush().await.unwrap();
// Read back - should be multiple records
let mut received = Vec::new();
let mut records_count = 0;
while received.len() < payload.len() {
let mut header = [0u8; 5];
if server.read_exact(&mut header).await.is_err() {
break;
}
assert_eq!(header[0], TLS_RECORD_APPLICATION);
records_count += 1;
let length = u16::from_be_bytes([header[3], header[4]]) as usize;
assert!(length <= MAX_TLS_PAYLOAD);
let mut body = vec![0u8; length];
server.read_exact(&mut body).await.unwrap();
received.extend_from_slice(&body);
}
assert_eq!(received, payload);
assert!(records_count >= 2); // Should have multiple records
}
#[tokio::test]
async fn test_tls_stream_roundtrip() {
let (client, server) = duplex(4096);
let mut writer = FakeTlsWriter::new(client);
let mut reader = FakeTlsReader::new(server);
let original = b"Hello, fake TLS!";
writer.write_all_tls(original).await.unwrap();
writer.flush().await.unwrap();
let received = reader.read_exact(original.len()).await.unwrap();
assert_eq!(&received[..], original);
}
#[tokio::test]
async fn test_tls_stream_roundtrip_large() {
let (client, server) = duplex(4096);
let mut writer = FakeTlsWriter::new(client);
let mut reader = FakeTlsReader::new(server);
let original: Vec<u8> = (0..50000).map(|i| (i % 256) as u8).collect();
// Write in background
let write_data = original.clone();
let write_handle = tokio::spawn(async move {
writer.write_all_tls(&write_data).await.unwrap();
writer.shutdown().await.unwrap();
});
// Read
let mut received = Vec::new();
let mut buf = vec![0u8; 1024];
@@ -1181,87 +1216,95 @@ mod tests {
}
received.extend_from_slice(&buf[..n]);
}
write_handle.await.unwrap();
assert_eq!(received, original);
}
#[tokio::test]
async fn test_tls_writer_state_names() {
let (client, _server) = duplex(4096);
let writer = FakeTlsWriter::new(client);
assert_eq!(writer.state_name(), "Idle");
assert!(!writer.is_poisoned());
assert!(!writer.has_pending());
}
// ============= Error Handling Tests =============
#[tokio::test]
async fn test_tls_reader_invalid_version() {
let invalid_record = vec![
TLS_RECORD_APPLICATION,
0x04, 0x00, // Invalid version
0x00, 0x05, // length = 5
0x01, 0x02, 0x03, 0x04, 0x05,
0x04,
0x00, // Invalid version
0x00,
0x05, // length = 5
0x01,
0x02,
0x03,
0x04,
0x05,
];
let reader = ChunkedReader::new(&invalid_record, 100);
let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; 5];
let result = tls_reader.read(&mut buf).await;
assert!(result.is_err());
assert!(tls_reader.is_poisoned());
}
#[tokio::test]
async fn test_tls_reader_unexpected_eof_header() {
// Partial header
let partial = vec![TLS_RECORD_APPLICATION, 0x03];
let reader = ChunkedReader::new(&partial, 100);
let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; 10];
let result = tls_reader.read(&mut buf).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_tls_reader_unexpected_eof_body() {
// Valid header but truncated body
let mut record = vec![
TLS_RECORD_APPLICATION,
TLS_VERSION[0], TLS_VERSION[1],
0x00, 0x10, // length = 16
TLS_VERSION[0],
TLS_VERSION[1],
0x00,
0x10, // length = 16
];
record.extend_from_slice(&[0u8; 8]); // Only 8 bytes of body
let reader = ChunkedReader::new(&record, 100);
let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; 16];
let result = tls_reader.read(&mut buf).await;
assert!(result.is_err());
}
// ============= Header Parsing Tests =============
#[test]
fn test_tls_record_header_parse() {
let header = [0x17, 0x03, 0x03, 0x01, 0x00];
let parsed = TlsRecordHeader::parse(&header).unwrap();
assert_eq!(parsed.record_type, TLS_RECORD_APPLICATION);
assert_eq!(parsed.version, TLS_VERSION);
assert_eq!(parsed.length, 256);
}
#[test]
fn test_tls_record_header_validate() {
let valid = TlsRecordHeader {
@@ -1270,14 +1313,14 @@ mod tests {
length: 100,
};
assert!(valid.validate().is_ok());
let invalid_version = TlsRecordHeader {
record_type: TLS_RECORD_APPLICATION,
version: [0x04, 0x00],
length: 100,
};
assert!(invalid_version.validate().is_err());
let too_large = TlsRecordHeader {
record_type: TLS_RECORD_APPLICATION,
version: TLS_VERSION,
@@ -1285,7 +1328,7 @@ mod tests {
};
assert!(too_large.validate().is_err());
}
#[test]
fn test_tls_record_header_to_bytes() {
let header = TlsRecordHeader {
@@ -1293,7 +1336,7 @@ mod tests {
version: TLS_VERSION,
length: 0x1234,
};
let bytes = header.to_bytes();
assert_eq!(bytes, [0x17, 0x03, 0x03, 0x12, 0x34]);
}