This commit is contained in:
Alexey
2026-03-21 15:45:29 +03:00
parent 7a8f946029
commit d7bbb376c9
154 changed files with 6194 additions and 3775 deletions

View File

@@ -8,8 +8,8 @@
use bytes::BytesMut;
use crossbeam_queue::ArrayQueue;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
// ============= Configuration =============
@@ -42,7 +42,7 @@ impl BufferPool {
pub fn new() -> Self {
Self::with_config(DEFAULT_BUFFER_SIZE, DEFAULT_MAX_BUFFERS)
}
/// Create a buffer pool with custom configuration
pub fn with_config(buffer_size: usize, max_buffers: usize) -> Self {
Self {
@@ -54,7 +54,7 @@ impl BufferPool {
hits: AtomicUsize::new(0),
}
}
/// Get a buffer from the pool, or create a new one if empty
pub fn get(self: &Arc<Self>) -> PooledBuffer {
match self.buffers.pop() {
@@ -76,7 +76,7 @@ impl BufferPool {
}
}
}
/// Try to get a buffer, returns None if pool is empty
pub fn try_get(self: &Arc<Self>) -> Option<PooledBuffer> {
self.buffers.pop().map(|mut buffer| {
@@ -88,12 +88,12 @@ impl BufferPool {
}
})
}
/// Return a buffer to the pool
fn return_buffer(&self, mut buffer: BytesMut) {
// Clear the buffer but keep capacity
buffer.clear();
// Only return if we haven't exceeded max and buffer is right size
if buffer.capacity() >= self.buffer_size {
// Try to push to pool, if full just drop
@@ -103,7 +103,7 @@ impl BufferPool {
// Actually we don't decrement here because the buffer might have been
// grown beyond our size - we just let it go
}
/// Get pool statistics
pub fn stats(&self) -> PoolStats {
PoolStats {
@@ -115,17 +115,21 @@ impl BufferPool {
misses: self.misses.load(Ordering::Relaxed),
}
}
/// Get buffer size
pub fn buffer_size(&self) -> usize {
self.buffer_size
}
/// Preallocate buffers to fill the pool
pub fn preallocate(&self, count: usize) {
let to_alloc = count.min(self.max_buffers);
for _ in 0..to_alloc {
if self.buffers.push(BytesMut::with_capacity(self.buffer_size)).is_err() {
if self
.buffers
.push(BytesMut::with_capacity(self.buffer_size))
.is_err()
{
break;
}
self.allocated.fetch_add(1, Ordering::Relaxed);
@@ -183,22 +187,22 @@ impl PooledBuffer {
pub fn take(mut self) -> BytesMut {
self.buffer.take().unwrap()
}
/// Get the capacity of the buffer
pub fn capacity(&self) -> usize {
self.buffer.as_ref().map(|b| b.capacity()).unwrap_or(0)
}
/// Check if buffer is empty
pub fn is_empty(&self) -> bool {
self.buffer.as_ref().map(|b| b.is_empty()).unwrap_or(true)
}
/// Get the length of data in buffer
pub fn len(&self) -> usize {
self.buffer.as_ref().map(|b| b.len()).unwrap_or(0)
}
/// Clear the buffer
pub fn clear(&mut self) {
if let Some(ref mut b) = self.buffer {
@@ -209,7 +213,7 @@ impl PooledBuffer {
impl Deref for PooledBuffer {
type Target = BytesMut;
fn deref(&self) -> &Self::Target {
self.buffer.as_ref().expect("buffer taken")
}
@@ -259,7 +263,7 @@ impl<'a> ScopedBuffer<'a> {
impl<'a> Deref for ScopedBuffer<'a> {
type Target = BytesMut;
fn deref(&self) -> &Self::Target {
self.buffer.deref()
}
@@ -280,108 +284,108 @@ impl<'a> Drop for ScopedBuffer<'a> {
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_pool_basic() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
// Get a buffer
let mut buf1 = pool.get();
buf1.extend_from_slice(b"hello");
assert_eq!(&buf1[..], b"hello");
// Drop returns to pool
drop(buf1);
let stats = pool.stats();
assert_eq!(stats.pooled, 1);
assert_eq!(stats.hits, 0);
assert_eq!(stats.misses, 1);
// Get again - should reuse
let buf2 = pool.get();
assert!(buf2.is_empty()); // Buffer was cleared
let stats = pool.stats();
assert_eq!(stats.pooled, 0);
assert_eq!(stats.hits, 1);
}
#[test]
fn test_pool_multiple_buffers() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
// Get multiple buffers
let buf1 = pool.get();
let buf2 = pool.get();
let buf3 = pool.get();
let stats = pool.stats();
assert_eq!(stats.allocated, 3);
assert_eq!(stats.pooled, 0);
// Return all
drop(buf1);
drop(buf2);
drop(buf3);
let stats = pool.stats();
assert_eq!(stats.pooled, 3);
}
#[test]
fn test_pool_overflow() {
let pool = Arc::new(BufferPool::with_config(1024, 2));
// Get 3 buffers (more than max)
let buf1 = pool.get();
let buf2 = pool.get();
let buf3 = pool.get();
// Return all - only 2 should be pooled
drop(buf1);
drop(buf2);
drop(buf3);
let stats = pool.stats();
assert_eq!(stats.pooled, 2);
}
#[test]
fn test_pool_take() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
let mut buf = pool.get();
buf.extend_from_slice(b"data");
// Take ownership, buffer should not return to pool
let taken = buf.take();
assert_eq!(&taken[..], b"data");
let stats = pool.stats();
assert_eq!(stats.pooled, 0);
}
#[test]
fn test_pool_preallocate() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
pool.preallocate(5);
let stats = pool.stats();
assert_eq!(stats.pooled, 5);
assert_eq!(stats.allocated, 5);
}
#[test]
fn test_pool_try_get() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
// Pool is empty, try_get returns None
assert!(pool.try_get().is_none());
// Add a buffer to pool
pool.preallocate(1);
// Now try_get should succeed once while the buffer is held
let buf = pool.try_get();
assert!(buf.is_some());
@@ -391,50 +395,50 @@ mod tests {
drop(buf);
assert!(pool.try_get().is_some());
}
#[test]
fn test_hit_rate() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
// First get is a miss
let buf1 = pool.get();
drop(buf1);
// Second get is a hit
let buf2 = pool.get();
drop(buf2);
// Third get is a hit
let _buf3 = pool.get();
let stats = pool.stats();
assert_eq!(stats.hits, 2);
assert_eq!(stats.misses, 1);
assert!((stats.hit_rate() - 66.67).abs() < 1.0);
}
#[test]
fn test_scoped_buffer() {
let pool = Arc::new(BufferPool::with_config(1024, 10));
let mut buf = pool.get();
{
let mut scoped = ScopedBuffer::new(&mut buf);
scoped.extend_from_slice(b"scoped data");
assert_eq!(&scoped[..], b"scoped data");
}
// After scoped is dropped, buffer is cleared
assert!(buf.is_empty());
}
#[test]
fn test_concurrent_access() {
use std::thread;
let pool = Arc::new(BufferPool::with_config(1024, 100));
let mut handles = vec![];
for _ in 0..10 {
let pool_clone = Arc::clone(&pool);
handles.push(thread::spawn(move || {
@@ -445,11 +449,11 @@ mod tests {
}
}));
}
for handle in handles {
handle.join().unwrap();
}
let stats = pool.stats();
// All buffers should be returned
assert!(stats.pooled > 0);

View File

@@ -37,7 +37,7 @@
//!
//! Backpressure
//! - pending ciphertext buffer is bounded (configurable per connection)
//! - pending is full and upstream is pending
//! - pending is full and upstream is pending
//! -> poll_write returns Poll::Pending
//! -> do not accept any plaintext
//!
@@ -59,8 +59,8 @@ use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tracing::{debug, trace};
use crate::crypto::AesCtr;
use super::state::{StreamState, YieldBuffer};
use crate::crypto::AesCtr;
// ============= Constants =============
@@ -152,9 +152,9 @@ impl<R> CryptoReader<R> {
fn take_poison_error(&mut self) -> io::Error {
match &mut self.state {
CryptoReaderState::Poisoned { error } => error.take().unwrap_or_else(|| {
io::Error::other("stream previously poisoned")
}),
CryptoReaderState::Poisoned { error } => error
.take()
.unwrap_or_else(|| io::Error::other("stream previously poisoned")),
_ => io::Error::other("stream not poisoned"),
}
}
@@ -221,7 +221,11 @@ impl<R: AsyncRead + Unpin> AsyncRead for CryptoReader<R> {
let filled = buf.filled_mut();
this.decryptor.apply(&mut filled[before..after]);
trace!(bytes_read, state = this.state_name(), "CryptoReader decrypted chunk");
trace!(
bytes_read,
state = this.state_name(),
"CryptoReader decrypted chunk"
);
return Poll::Ready(Ok(()));
}
@@ -503,9 +507,9 @@ impl<W> CryptoWriter<W> {
fn take_poison_error(&mut self) -> io::Error {
match &mut self.state {
CryptoWriterState::Poisoned { error } => error.take().unwrap_or_else(|| {
io::Error::other("stream previously poisoned")
}),
CryptoWriterState::Poisoned { error } => error
.take()
.unwrap_or_else(|| io::Error::other("stream previously poisoned")),
_ => io::Error::other("stream not poisoned"),
}
}
@@ -525,7 +529,11 @@ impl<W> CryptoWriter<W> {
}
/// Select how many plaintext bytes can be accepted in buffering path
fn select_to_accept_for_buffering(state: &CryptoWriterState, buf_len: usize, max_pending: usize) -> usize {
fn select_to_accept_for_buffering(
state: &CryptoWriterState,
buf_len: usize,
max_pending: usize,
) -> usize {
if buf_len == 0 {
return 0;
}
@@ -602,11 +610,7 @@ impl<W: AsyncWrite + Unpin> CryptoWriter<W> {
}
impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
let this = self.get_mut();
// Poisoned?
@@ -629,8 +633,11 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => {
// Upstream blocked. Apply ideal backpressure
let to_accept =
Self::select_to_accept_for_buffering(&this.state, buf.len(), this.max_pending_write);
let to_accept = Self::select_to_accept_for_buffering(
&this.state,
buf.len(),
this.max_pending_write,
);
if to_accept == 0 {
trace!(

View File

@@ -9,8 +9,8 @@ use bytes::{Bytes, BytesMut};
use std::io::Result;
use std::sync::Arc;
use crate::protocol::constants::ProtoTag;
use crate::crypto::SecureRandom;
use crate::protocol::constants::ProtoTag;
// ============= Frame Types =============
@@ -31,27 +31,27 @@ impl Frame {
meta: FrameMeta::default(),
}
}
/// Create a new frame with data and metadata
pub fn with_meta(data: Bytes, meta: FrameMeta) -> Self {
Self { data, meta }
}
/// Create an empty frame
pub fn empty() -> Self {
Self::new(Bytes::new())
}
/// Check if frame is empty
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
/// Get frame length
pub fn len(&self) -> usize {
self.data.len()
}
/// Create a QuickAck request frame
pub fn quickack(data: Bytes) -> Self {
Self {
@@ -62,7 +62,7 @@ impl Frame {
},
}
}
/// Create a simple ACK frame
pub fn simple_ack(data: Bytes) -> Self {
Self {
@@ -91,25 +91,25 @@ impl FrameMeta {
pub fn new() -> Self {
Self::default()
}
/// Create with quickack flag
pub fn with_quickack(mut self) -> Self {
self.quickack = true;
self
}
/// Create with simple_ack flag
pub fn with_simple_ack(mut self) -> Self {
self.simple_ack = true;
self
}
/// Create with padding length
pub fn with_padding(mut self, len: u8) -> Self {
self.padding_len = len;
self
}
/// Check if any special flags are set
pub fn has_flags(&self) -> bool {
self.quickack || self.simple_ack
@@ -122,12 +122,12 @@ impl FrameMeta {
pub trait FrameCodec: Send + Sync {
/// Get the protocol tag for this codec
fn proto_tag(&self) -> ProtoTag;
/// Encode a frame into the destination buffer
///
/// Returns the number of bytes written.
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> Result<usize>;
/// Try to decode a frame from the source buffer
///
/// Returns:
@@ -137,10 +137,10 @@ pub trait FrameCodec: Send + Sync {
///
/// On success, the consumed bytes are removed from `src`.
fn decode(&self, src: &mut BytesMut) -> Result<Option<Frame>>;
/// Get the minimum bytes needed to determine frame length
fn min_header_size(&self) -> usize;
/// Get the maximum allowed frame size
fn max_frame_size(&self) -> usize {
// Default: 16MB
@@ -162,30 +162,28 @@ pub fn create_codec(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Box<dyn Fram
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_frame_creation() {
let frame = Frame::new(Bytes::from_static(b"test"));
assert_eq!(frame.len(), 4);
assert!(!frame.is_empty());
assert!(!frame.meta.quickack);
let frame = Frame::empty();
assert!(frame.is_empty());
let frame = Frame::quickack(Bytes::from_static(b"ack"));
assert!(frame.meta.quickack);
}
#[test]
fn test_frame_meta() {
let meta = FrameMeta::new()
.with_quickack()
.with_padding(3);
let meta = FrameMeta::new().with_quickack().with_padding(3);
assert!(meta.quickack);
assert!(!meta.simple_ack);
assert_eq!(meta.padding_len, 3);
assert!(meta.has_flags());
}
}
}

View File

@@ -5,16 +5,16 @@
#![allow(dead_code)]
use bytes::{Bytes, BytesMut, BufMut};
use bytes::{BufMut, Bytes, BytesMut};
use std::io::{self, Error, ErrorKind};
use std::sync::Arc;
use tokio_util::codec::{Decoder, Encoder};
use super::frame::{Frame, FrameCodec as FrameCodecTrait, FrameMeta};
use crate::crypto::SecureRandom;
use crate::protocol::constants::{
ProtoTag, is_valid_secure_payload_len, secure_padding_len, secure_payload_len_from_wire_len,
};
use crate::crypto::SecureRandom;
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
// ============= Unified Codec =============
@@ -40,13 +40,13 @@ impl FrameCodec {
rng,
}
}
/// Set maximum frame size
pub fn with_max_frame_size(mut self, size: usize) -> Self {
self.max_frame_size = size;
self
}
/// Get protocol tag
pub fn proto_tag(&self) -> ProtoTag {
self.proto_tag
@@ -56,7 +56,7 @@ impl FrameCodec {
impl Decoder for FrameCodec {
type Item = Frame;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
match self.proto_tag {
ProtoTag::Abridged => decode_abridged(src, self.max_frame_size),
@@ -68,7 +68,7 @@ impl Decoder for FrameCodec {
impl Encoder<Frame> for FrameCodec {
type Error = io::Error;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
match self.proto_tag {
ProtoTag::Abridged => encode_abridged(&frame, dst),
@@ -84,18 +84,18 @@ fn decode_abridged(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Fra
if src.is_empty() {
return Ok(None);
}
let mut meta = FrameMeta::new();
let first_byte = src[0];
// Extract length and quickack flag
let mut len_words = (first_byte & 0x7f) as usize;
if first_byte >= 0x80 {
meta.quickack = true;
}
let header_len;
if len_words == 0x7f {
// Extended length (3 more bytes needed)
if src.len() < 4 {
@@ -106,46 +106,49 @@ fn decode_abridged(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Fra
} else {
header_len = 1;
}
// Length is in 4-byte words
let byte_len = len_words.checked_mul(4).ok_or_else(|| {
Error::new(ErrorKind::InvalidData, "frame length overflow")
})?;
let byte_len = len_words
.checked_mul(4)
.ok_or_else(|| Error::new(ErrorKind::InvalidData, "frame length overflow"))?;
// Validate size
if byte_len > max_size {
return Err(Error::new(
ErrorKind::InvalidData,
format!("frame too large: {} bytes (max {})", byte_len, max_size)
format!("frame too large: {} bytes (max {})", byte_len, max_size),
));
}
let total_len = header_len + byte_len;
if src.len() < total_len {
// Reserve space for the rest of the frame
src.reserve(total_len - src.len());
return Ok(None);
}
// Extract data
let _ = src.split_to(header_len);
let data = src.split_to(byte_len).freeze();
Ok(Some(Frame::with_meta(data, meta)))
}
fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
let data = &frame.data;
// Validate alignment
if !data.len().is_multiple_of(4) {
return Err(Error::new(
ErrorKind::InvalidInput,
format!("abridged frame must be 4-byte aligned, got {} bytes", data.len())
format!(
"abridged frame must be 4-byte aligned, got {} bytes",
data.len()
),
));
}
// Simple ACK: send reversed data without header
if frame.meta.simple_ack {
dst.reserve(data.len());
@@ -154,9 +157,9 @@ fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
}
return Ok(());
}
let len_words = data.len() / 4;
if len_words < 0x7f {
// Short header
dst.reserve(1 + data.len());
@@ -178,10 +181,10 @@ fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
} else {
return Err(Error::new(
ErrorKind::InvalidInput,
format!("frame too large: {} bytes", data.len())
format!("frame too large: {} bytes", data.len()),
));
}
dst.extend_from_slice(data);
Ok(())
}
@@ -192,58 +195,58 @@ fn decode_intermediate(src: &mut BytesMut, max_size: usize) -> io::Result<Option
if src.len() < 4 {
return Ok(None);
}
let mut meta = FrameMeta::new();
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
// Check QuickACK flag
if len >= 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
// Validate size
if len > max_size {
return Err(Error::new(
ErrorKind::InvalidData,
format!("frame too large: {} bytes (max {})", len, max_size)
format!("frame too large: {} bytes (max {})", len, max_size),
));
}
let total_len = 4 + len;
if src.len() < total_len {
src.reserve(total_len - src.len());
return Ok(None);
}
// Extract data
let _ = src.split_to(4);
let data = src.split_to(len).freeze();
Ok(Some(Frame::with_meta(data, meta)))
}
fn encode_intermediate(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
let data = &frame.data;
// Simple ACK: just send data
if frame.meta.simple_ack {
dst.reserve(data.len());
dst.extend_from_slice(data);
return Ok(());
}
dst.reserve(4 + data.len());
let mut len = data.len() as u32;
if frame.meta.quickack {
len |= 0x80000000;
}
dst.extend_from_slice(&len.to_le_bytes());
dst.extend_from_slice(data);
Ok(())
}
@@ -253,31 +256,31 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame
if src.len() < 4 {
return Ok(None);
}
let mut meta = FrameMeta::new();
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
// Check QuickACK flag
if len >= 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
// Validate size
if len > max_size {
return Err(Error::new(
ErrorKind::InvalidData,
format!("frame too large: {} bytes (max {})", len, max_size)
format!("frame too large: {} bytes (max {})", len, max_size),
));
}
let total_len = 4 + len;
if src.len() < total_len {
src.reserve(total_len - src.len());
return Ok(None);
}
let data_len = secure_payload_len_from_wire_len(len).ok_or_else(|| {
Error::new(
ErrorKind::InvalidData,
@@ -285,28 +288,28 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame
)
})?;
let padding_len = len - data_len;
meta.padding_len = padding_len as u8;
// Extract data (excluding padding)
let _ = src.split_to(4);
let all_data = src.split_to(len);
// Copy only the data portion, excluding padding
let data = Bytes::copy_from_slice(&all_data[..data_len]);
Ok(Some(Frame::with_meta(data, meta)))
}
fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::Result<()> {
let data = &frame.data;
// Simple ACK: just send data
if frame.meta.simple_ack {
dst.reserve(data.len());
dst.extend_from_slice(data);
return Ok(());
}
if !is_valid_secure_payload_len(data.len()) {
return Err(Error::new(
ErrorKind::InvalidData,
@@ -316,23 +319,23 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::R
// Generate padding that keeps total length non-divisible by 4.
let padding_len = secure_padding_len(data.len(), rng);
let total_len = data.len() + padding_len;
dst.reserve(4 + total_len);
let mut len = total_len as u32;
if frame.meta.quickack {
len |= 0x80000000;
}
dst.extend_from_slice(&len.to_le_bytes());
dst.extend_from_slice(data);
if padding_len > 0 {
let padding = rng.bytes(padding_len);
dst.extend_from_slice(&padding);
}
Ok(())
}
@@ -360,7 +363,7 @@ impl Default for AbridgedCodec {
impl Decoder for AbridgedCodec {
type Item = Frame;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
decode_abridged(src, self.max_frame_size)
}
@@ -368,7 +371,7 @@ impl Decoder for AbridgedCodec {
impl Encoder<Frame> for AbridgedCodec {
type Error = io::Error;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
encode_abridged(&frame, dst)
}
@@ -378,17 +381,17 @@ impl FrameCodecTrait for AbridgedCodec {
fn proto_tag(&self) -> ProtoTag {
ProtoTag::Abridged
}
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
let before = dst.len();
encode_abridged(frame, dst)?;
Ok(dst.len() - before)
}
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
decode_abridged(src, self.max_frame_size)
}
fn min_header_size(&self) -> usize {
1
}
@@ -416,7 +419,7 @@ impl Default for IntermediateCodec {
impl Decoder for IntermediateCodec {
type Item = Frame;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
decode_intermediate(src, self.max_frame_size)
}
@@ -424,7 +427,7 @@ impl Decoder for IntermediateCodec {
impl Encoder<Frame> for IntermediateCodec {
type Error = io::Error;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
encode_intermediate(&frame, dst)
}
@@ -434,17 +437,17 @@ impl FrameCodecTrait for IntermediateCodec {
fn proto_tag(&self) -> ProtoTag {
ProtoTag::Intermediate
}
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
let before = dst.len();
encode_intermediate(frame, dst)?;
Ok(dst.len() - before)
}
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
decode_intermediate(src, self.max_frame_size)
}
fn min_header_size(&self) -> usize {
4
}
@@ -474,7 +477,7 @@ impl Default for SecureCodec {
impl Decoder for SecureCodec {
type Item = Frame;
type Error = io::Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
decode_secure(src, self.max_frame_size)
}
@@ -482,7 +485,7 @@ impl Decoder for SecureCodec {
impl Encoder<Frame> for SecureCodec {
type Error = io::Error;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
encode_secure(&frame, dst, &self.rng)
}
@@ -492,17 +495,17 @@ impl FrameCodecTrait for SecureCodec {
fn proto_tag(&self) -> ProtoTag {
ProtoTag::Secure
}
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
let before = dst.len();
encode_secure(frame, dst, &self.rng)?;
Ok(dst.len() - before)
}
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
decode_secure(src, self.max_frame_size)
}
fn min_header_size(&self) -> usize {
4
}
@@ -513,121 +516,127 @@ impl FrameCodecTrait for SecureCodec {
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashSet;
use tokio_util::codec::{FramedRead, FramedWrite};
use tokio::io::duplex;
use futures::{SinkExt, StreamExt};
use crate::crypto::SecureRandom;
use futures::{SinkExt, StreamExt};
use std::collections::HashSet;
use std::sync::Arc;
use tokio::io::duplex;
use tokio_util::codec::{FramedRead, FramedWrite};
#[tokio::test]
async fn test_framed_abridged() {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, AbridgedCodec::new());
let mut reader = FramedRead::new(server, AbridgedCodec::new());
// Write a frame
let frame = Frame::new(Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]));
writer.send(frame).await.unwrap();
// Read it back
let received = reader.next().await.unwrap().unwrap();
assert_eq!(&received.data[..], &[1, 2, 3, 4, 5, 6, 7, 8]);
}
#[tokio::test]
async fn test_framed_intermediate() {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
let mut reader = FramedRead::new(server, IntermediateCodec::new());
let frame = Frame::new(Bytes::from_static(b"hello world"));
writer.send(frame).await.unwrap();
let received = reader.next().await.unwrap().unwrap();
assert_eq!(&received.data[..], b"hello world");
}
#[tokio::test]
async fn test_framed_secure() {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, SecureCodec::new(Arc::new(SecureRandom::new())));
let mut reader = FramedRead::new(server, SecureCodec::new(Arc::new(SecureRandom::new())));
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
let frame = Frame::new(original.clone());
writer.send(frame).await.unwrap();
let received = reader.next().await.unwrap().unwrap();
assert_eq!(&received.data[..], &original[..]);
}
#[tokio::test]
async fn test_unified_codec() {
for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
let mut writer = FramedWrite::new(
client,
FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())),
);
let mut reader = FramedRead::new(
server,
FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())),
);
// Use 4-byte aligned data for abridged compatibility
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
let frame = Frame::new(original.clone());
writer.send(frame).await.unwrap();
let received = reader.next().await.unwrap().unwrap();
assert_eq!(received.data.len(), 8);
}
}
#[tokio::test]
async fn test_multiple_frames() {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
let mut reader = FramedRead::new(server, IntermediateCodec::new());
// Send multiple frames
for i in 0..10 {
let data: Vec<u8> = (0..((i + 1) * 10)).map(|j| (j % 256) as u8).collect();
let frame = Frame::new(Bytes::from(data));
writer.send(frame).await.unwrap();
}
// Receive them
for i in 0..10 {
let received = reader.next().await.unwrap().unwrap();
assert_eq!(received.data.len(), (i + 1) * 10);
}
}
#[tokio::test]
async fn test_quickack_flag() {
let (client, server) = duplex(4096);
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
let mut reader = FramedRead::new(server, IntermediateCodec::new());
let frame = Frame::quickack(Bytes::from_static(b"urgent"));
writer.send(frame).await.unwrap();
let received = reader.next().await.unwrap().unwrap();
assert!(received.meta.quickack);
}
#[test]
fn test_frame_too_large() {
let mut codec = FrameCodec::new(ProtoTag::Intermediate, Arc::new(SecureRandom::new()))
.with_max_frame_size(100);
// Create a "frame" that claims to be very large
let mut buf = BytesMut::new();
buf.extend_from_slice(&1000u32.to_le_bytes()); // length = 1000
buf.extend_from_slice(&[0u8; 10]); // partial data
let result = codec.decode(&mut buf);
assert!(result.is_err());
}

View File

@@ -2,13 +2,13 @@
#![allow(dead_code)]
use super::traits::{FrameMeta, LayeredStream};
use crate::crypto::{SecureRandom, crc32};
use crate::protocol::constants::*;
use bytes::Bytes;
use std::io::{Error, ErrorKind, Result};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use crate::protocol::constants::*;
use crate::crypto::{crc32, SecureRandom};
use std::sync::Arc;
use super::traits::{FrameMeta, LayeredStream};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
// ============= Abridged (Compact) Frame =============
@@ -27,41 +27,47 @@ impl<R: AsyncRead + Unpin> AbridgedFrameReader<R> {
/// Read a frame and return (data, metadata)
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
let mut meta = FrameMeta::new();
// Read length byte
let mut len_byte = [0u8];
self.upstream.read_exact(&mut len_byte).await?;
let mut len = len_byte[0] as usize;
// Check QuickACK flag (high bit)
if len >= 0x80 {
meta.quickack = true;
len -= 0x80;
}
// Extended length (3 bytes)
if len == 0x7f {
let mut len_bytes = [0u8; 3];
self.upstream.read_exact(&mut len_bytes).await?;
len = u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], 0]) as usize;
}
// Length is in 4-byte words
let byte_len = len * 4;
// Read data
let mut data = vec![0u8; byte_len];
self.upstream.read_exact(&mut data).await?;
Ok((Bytes::from(data), meta))
}
}
impl<R> LayeredStream<R> for AbridgedFrameReader<R> {
fn upstream(&self) -> &R { &self.upstream }
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
fn into_upstream(self) -> R { self.upstream }
fn upstream(&self) -> &R {
&self.upstream
}
fn upstream_mut(&mut self) -> &mut R {
&mut self.upstream
}
fn into_upstream(self) -> R {
self.upstream
}
}
/// Writer for abridged MTProto framing
@@ -81,19 +87,22 @@ impl<W: AsyncWrite + Unpin> AbridgedFrameWriter<W> {
if !data.len().is_multiple_of(4) {
return Err(Error::new(
ErrorKind::InvalidInput,
format!("Abridged frame must be aligned to 4 bytes, got {}", data.len()),
format!(
"Abridged frame must be aligned to 4 bytes, got {}",
data.len()
),
));
}
// Simple ACK: send reversed data
if meta.simple_ack {
let reversed: Vec<u8> = data.iter().rev().copied().collect();
self.upstream.write_all(&reversed).await?;
return Ok(());
}
let len_div_4 = data.len() / 4;
if len_div_4 < 0x7f {
// Short length (1 byte)
self.upstream.write_all(&[len_div_4 as u8]).await?;
@@ -108,20 +117,26 @@ impl<W: AsyncWrite + Unpin> AbridgedFrameWriter<W> {
format!("Frame too large: {} bytes", data.len()),
));
}
self.upstream.write_all(data).await?;
Ok(())
}
pub async fn flush(&mut self) -> Result<()> {
self.upstream.flush().await
}
}
impl<W> LayeredStream<W> for AbridgedFrameWriter<W> {
fn upstream(&self) -> &W { &self.upstream }
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
fn into_upstream(self) -> W { self.upstream }
fn upstream(&self) -> &W {
&self.upstream
}
fn upstream_mut(&mut self) -> &mut W {
&mut self.upstream
}
fn into_upstream(self) -> W {
self.upstream
}
}
// ============= Intermediate Frame =============
@@ -140,31 +155,37 @@ impl<R> IntermediateFrameReader<R> {
impl<R: AsyncRead + Unpin> IntermediateFrameReader<R> {
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
let mut meta = FrameMeta::new();
// Read 4-byte length
let mut len_bytes = [0u8; 4];
self.upstream.read_exact(&mut len_bytes).await?;
let mut len = u32::from_le_bytes(len_bytes) as usize;
// Check QuickACK flag (high bit)
if len > 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
// Read data
let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?;
Ok((Bytes::from(data), meta))
}
}
impl<R> LayeredStream<R> for IntermediateFrameReader<R> {
fn upstream(&self) -> &R { &self.upstream }
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
fn into_upstream(self) -> R { self.upstream }
fn upstream(&self) -> &R {
&self.upstream
}
fn upstream_mut(&mut self) -> &mut R {
&mut self.upstream
}
fn into_upstream(self) -> R {
self.upstream
}
}
/// Writer for intermediate MTProto framing
@@ -189,16 +210,22 @@ impl<W: AsyncWrite + Unpin> IntermediateFrameWriter<W> {
}
Ok(())
}
pub async fn flush(&mut self) -> Result<()> {
self.upstream.flush().await
}
}
impl<W> LayeredStream<W> for IntermediateFrameWriter<W> {
fn upstream(&self) -> &W { &self.upstream }
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
fn into_upstream(self) -> W { self.upstream }
fn upstream(&self) -> &W {
&self.upstream
}
fn upstream_mut(&mut self) -> &mut W {
&mut self.upstream
}
fn into_upstream(self) -> W {
self.upstream
}
}
// ============= Secure Intermediate Frame =============
@@ -217,23 +244,23 @@ impl<R> SecureIntermediateFrameReader<R> {
impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
let mut meta = FrameMeta::new();
// Read 4-byte length
let mut len_bytes = [0u8; 4];
self.upstream.read_exact(&mut len_bytes).await?;
let mut len = u32::from_le_bytes(len_bytes) as usize;
// Check QuickACK flag
if len > 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
// Read data (including padding)
let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?;
let payload_len = secure_payload_len_from_wire_len(len).ok_or_else(|| {
Error::new(
ErrorKind::InvalidData,
@@ -241,15 +268,21 @@ impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
)
})?;
data.truncate(payload_len);
Ok((Bytes::from(data), meta))
}
}
impl<R> LayeredStream<R> for SecureIntermediateFrameReader<R> {
fn upstream(&self) -> &R { &self.upstream }
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
fn into_upstream(self) -> R { self.upstream }
fn upstream(&self) -> &R {
&self.upstream
}
fn upstream_mut(&mut self) -> &mut R {
&mut self.upstream
}
fn into_upstream(self) -> R {
self.upstream
}
}
/// Writer for secure intermediate MTProto framing
@@ -270,7 +303,7 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
self.upstream.write_all(data).await?;
return Ok(());
}
if !is_valid_secure_payload_len(data.len()) {
return Err(Error::new(
ErrorKind::InvalidData,
@@ -281,26 +314,32 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
// Add padding so total length is never divisible by 4 (MTProto Secure)
let padding_len = secure_padding_len(data.len(), &self.rng);
let padding = self.rng.bytes(padding_len);
let total_len = data.len() + padding_len;
let len_bytes = (total_len as u32).to_le_bytes();
self.upstream.write_all(&len_bytes).await?;
self.upstream.write_all(data).await?;
self.upstream.write_all(&padding).await?;
Ok(())
}
pub async fn flush(&mut self) -> Result<()> {
self.upstream.flush().await
}
}
impl<W> LayeredStream<W> for SecureIntermediateFrameWriter<W> {
fn upstream(&self) -> &W { &self.upstream }
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
fn into_upstream(self) -> W { self.upstream }
fn upstream(&self) -> &W {
&self.upstream
}
fn upstream_mut(&mut self) -> &mut W {
&mut self.upstream
}
fn into_upstream(self) -> W {
self.upstream
}
}
// ============= Full MTProto Frame (with CRC) =============
@@ -313,7 +352,10 @@ pub struct MtprotoFrameReader<R> {
impl<R> MtprotoFrameReader<R> {
pub fn new(upstream: R, start_seq: i32) -> Self {
Self { upstream, seq_no: start_seq }
Self {
upstream,
seq_no: start_seq,
}
}
}
@@ -324,57 +366,65 @@ impl<R: AsyncRead + Unpin> MtprotoFrameReader<R> {
let mut len_bytes = [0u8; 4];
self.upstream.read_exact(&mut len_bytes).await?;
let len = u32::from_le_bytes(len_bytes) as usize;
// Skip padding-only messages
if len == 4 {
continue;
}
// Validate length
if !(MIN_MSG_LEN..=MAX_MSG_LEN).contains(&len) || !len.is_multiple_of(PADDING_FILLER.len()) {
if !(MIN_MSG_LEN..=MAX_MSG_LEN).contains(&len)
|| !len.is_multiple_of(PADDING_FILLER.len())
{
return Err(Error::new(
ErrorKind::InvalidData,
format!("Invalid message length: {}", len),
));
}
// Read sequence number
let mut seq_bytes = [0u8; 4];
self.upstream.read_exact(&mut seq_bytes).await?;
let msg_seq = i32::from_le_bytes(seq_bytes);
if msg_seq != self.seq_no {
return Err(Error::new(
ErrorKind::InvalidData,
format!("Sequence mismatch: expected {}, got {}", self.seq_no, msg_seq),
format!(
"Sequence mismatch: expected {}, got {}",
self.seq_no, msg_seq
),
));
}
self.seq_no += 1;
// Read data (length - 4 len - 4 seq - 4 crc = len - 12)
let data_len = len - 12;
let mut data = vec![0u8; data_len];
self.upstream.read_exact(&mut data).await?;
// Read and verify CRC32
let mut crc_bytes = [0u8; 4];
self.upstream.read_exact(&mut crc_bytes).await?;
let expected_crc = u32::from_le_bytes(crc_bytes);
// Compute CRC over len + seq + data
let mut crc_input = Vec::with_capacity(8 + data_len);
crc_input.extend_from_slice(&len_bytes);
crc_input.extend_from_slice(&seq_bytes);
crc_input.extend_from_slice(&data);
let computed_crc = crc32(&crc_input);
if computed_crc != expected_crc {
return Err(Error::new(
ErrorKind::InvalidData,
format!("CRC mismatch: expected {:08x}, got {:08x}", expected_crc, computed_crc),
format!(
"CRC mismatch: expected {:08x}, got {:08x}",
expected_crc, computed_crc
),
));
}
return Ok(Bytes::from(data));
}
}
@@ -388,7 +438,10 @@ pub struct MtprotoFrameWriter<W> {
impl<W> MtprotoFrameWriter<W> {
pub fn new(upstream: W, start_seq: i32) -> Self {
Self { upstream, seq_no: start_seq }
Self {
upstream,
seq_no: start_seq,
}
}
}
@@ -396,11 +449,11 @@ impl<W: AsyncWrite + Unpin> MtprotoFrameWriter<W> {
pub async fn write_frame(&mut self, msg: &[u8]) -> Result<()> {
// Total length: 4 (len) + 4 (seq) + data + 4 (crc)
let len = msg.len() + 12;
let len_bytes = (len as u32).to_le_bytes();
let seq_bytes = self.seq_no.to_le_bytes();
self.seq_no += 1;
// Compute CRC
let mut crc_input = Vec::with_capacity(8 + msg.len());
crc_input.extend_from_slice(&len_bytes);
@@ -408,25 +461,25 @@ impl<W: AsyncWrite + Unpin> MtprotoFrameWriter<W> {
crc_input.extend_from_slice(msg);
let checksum = crc32(&crc_input);
let crc_bytes = checksum.to_le_bytes();
// Calculate padding for CBC alignment
let total_len = len_bytes.len() + seq_bytes.len() + msg.len() + crc_bytes.len();
let padding_needed = (CBC_PADDING - (total_len % CBC_PADDING)) % CBC_PADDING;
let padding_count = padding_needed / PADDING_FILLER.len();
// Write everything
self.upstream.write_all(&len_bytes).await?;
self.upstream.write_all(&seq_bytes).await?;
self.upstream.write_all(msg).await?;
self.upstream.write_all(&crc_bytes).await?;
for _ in 0..padding_count {
self.upstream.write_all(&PADDING_FILLER).await?;
}
Ok(())
}
pub async fn flush(&mut self) -> Result<()> {
self.upstream.flush().await
}
@@ -445,11 +498,15 @@ impl<R: AsyncRead + Unpin> FrameReaderKind<R> {
pub fn new(upstream: R, proto_tag: ProtoTag) -> Self {
match proto_tag {
ProtoTag::Abridged => FrameReaderKind::Abridged(AbridgedFrameReader::new(upstream)),
ProtoTag::Intermediate => FrameReaderKind::Intermediate(IntermediateFrameReader::new(upstream)),
ProtoTag::Secure => FrameReaderKind::SecureIntermediate(SecureIntermediateFrameReader::new(upstream)),
ProtoTag::Intermediate => {
FrameReaderKind::Intermediate(IntermediateFrameReader::new(upstream))
}
ProtoTag::Secure => {
FrameReaderKind::SecureIntermediate(SecureIntermediateFrameReader::new(upstream))
}
}
}
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
match self {
FrameReaderKind::Abridged(r) => r.read_frame().await,
@@ -469,11 +526,15 @@ impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
pub fn new(upstream: W, proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
match proto_tag {
ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)),
ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)),
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream, rng)),
ProtoTag::Intermediate => {
FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream))
}
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(
SecureIntermediateFrameWriter::new(upstream, rng),
),
}
}
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
match self {
FrameWriterKind::Abridged(w) => w.write_frame(data, meta).await,
@@ -481,7 +542,7 @@ impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
FrameWriterKind::SecureIntermediate(w) => w.write_frame(data, meta).await,
}
}
pub async fn flush(&mut self) -> Result<()> {
match self {
FrameWriterKind::Abridged(w) => w.flush().await,
@@ -494,103 +555,110 @@ impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::duplex;
use std::sync::Arc;
use crate::crypto::SecureRandom;
use std::sync::Arc;
use tokio::io::duplex;
#[tokio::test]
async fn test_abridged_roundtrip() {
let (client, server) = duplex(1024);
let mut writer = AbridgedFrameWriter::new(client);
let mut reader = AbridgedFrameReader::new(server);
// Short frame
let data = vec![1u8, 2, 3, 4]; // 4 bytes = 1 word
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
writer.flush().await.unwrap();
let (received, _meta) = reader.read_frame().await.unwrap();
assert_eq!(&received[..], &data[..]);
}
#[tokio::test]
async fn test_abridged_long_frame() {
let (client, server) = duplex(65536);
let mut writer = AbridgedFrameWriter::new(client);
let mut reader = AbridgedFrameReader::new(server);
// Long frame (> 0x7f words = 508 bytes)
let data: Vec<u8> = (0..1000).map(|i| (i % 256) as u8).collect();
let padded_len = (data.len() + 3) / 4 * 4;
let mut padded = data.clone();
padded.resize(padded_len, 0);
writer.write_frame(&padded, &FrameMeta::new()).await.unwrap();
writer
.write_frame(&padded, &FrameMeta::new())
.await
.unwrap();
writer.flush().await.unwrap();
let (received, _meta) = reader.read_frame().await.unwrap();
assert_eq!(&received[..], &padded[..]);
}
#[tokio::test]
async fn test_intermediate_roundtrip() {
let (client, server) = duplex(1024);
let mut writer = IntermediateFrameWriter::new(client);
let mut reader = IntermediateFrameReader::new(server);
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
writer.flush().await.unwrap();
let (received, _meta) = reader.read_frame().await.unwrap();
assert_eq!(&received[..], &data[..]);
}
#[tokio::test]
async fn test_secure_intermediate_padding() {
let (client, server) = duplex(1024);
let mut writer = SecureIntermediateFrameWriter::new(client, Arc::new(SecureRandom::new()));
let mut reader = SecureIntermediateFrameReader::new(server);
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
writer.flush().await.unwrap();
let (received, _meta) = reader.read_frame().await.unwrap();
assert_eq!(received.len(), data.len());
}
#[tokio::test]
async fn test_mtproto_frame_roundtrip() {
let (client, server) = duplex(1024);
let mut writer = MtprotoFrameWriter::new(client, 0);
let mut reader = MtprotoFrameReader::new(server, 0);
// Message must be padded properly
let data = vec![0u8; 16]; // Aligned to 4 and CBC_PADDING
writer.write_frame(&data).await.unwrap();
writer.flush().await.unwrap();
let received = reader.read_frame().await.unwrap();
assert_eq!(&received[..], &data[..]);
}
#[tokio::test]
async fn test_frame_reader_kind() {
let (client, server) = duplex(1024);
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate, Arc::new(SecureRandom::new()));
let mut writer = FrameWriterKind::new(
client,
ProtoTag::Intermediate,
Arc::new(SecureRandom::new()),
);
let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate);
let data = vec![1u8, 2, 3, 4];
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
writer.flush().await.unwrap();
let (received, _) = reader.read_frame().await.unwrap();
assert_eq!(&received[..], &data[..]);
}

View File

@@ -1,12 +1,12 @@
//! Stream wrappers for MTProto protocol layers
pub mod state;
pub mod buffer_pool;
pub mod traits;
pub mod crypto_stream;
pub mod tls_stream;
pub mod frame;
pub mod frame_codec;
pub mod state;
pub mod tls_stream;
pub mod traits;
// Legacy compatibility - will be removed later
pub mod frame_stream;
@@ -14,13 +14,12 @@ pub mod frame_stream;
// Re-export state machine types
#[allow(unused_imports)]
pub use state::{
StreamState, Transition, PollResult,
ReadBuffer, WriteBuffer, HeaderBuffer, YieldBuffer,
HeaderBuffer, PollResult, ReadBuffer, StreamState, Transition, WriteBuffer, YieldBuffer,
};
// Re-export buffer pool
#[allow(unused_imports)]
pub use buffer_pool::{BufferPool, PooledBuffer, PoolStats};
pub use buffer_pool::{BufferPool, PoolStats, PooledBuffer};
// Re-export stream implementations
#[allow(unused_imports)]
@@ -29,21 +28,16 @@ pub use tls_stream::{FakeTlsReader, FakeTlsWriter};
// Re-export frame types
#[allow(unused_imports)]
pub use frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait, create_codec};
pub use frame::{Frame, FrameCodec as FrameCodecTrait, FrameMeta, create_codec};
// Re-export tokio-util compatible codecs
#[allow(unused_imports)]
pub use frame_codec::{
FrameCodec,
AbridgedCodec, IntermediateCodec, SecureCodec,
};
pub use frame_codec::{AbridgedCodec, FrameCodec, IntermediateCodec, SecureCodec};
// Legacy re-exports for compatibility
#[allow(unused_imports)]
pub use frame_stream::{
AbridgedFrameReader, AbridgedFrameWriter,
IntermediateFrameReader, IntermediateFrameWriter,
AbridgedFrameReader, AbridgedFrameWriter, FrameReaderKind, FrameWriterKind,
IntermediateFrameReader, IntermediateFrameWriter, MtprotoFrameReader, MtprotoFrameWriter,
SecureIntermediateFrameReader, SecureIntermediateFrameWriter,
MtprotoFrameReader, MtprotoFrameWriter,
FrameReaderKind, FrameWriterKind,
};

View File

@@ -14,10 +14,10 @@ use std::io;
pub trait StreamState: Sized {
/// Check if this is a terminal state (no more transitions possible)
fn is_terminal(&self) -> bool;
/// Check if stream is in poisoned/error state
fn is_poisoned(&self) -> bool;
/// Get human-readable state name for debugging
fn state_name(&self) -> &'static str;
}
@@ -44,7 +44,7 @@ impl<S, O> Transition<S, O> {
pub fn has_output(&self) -> bool {
matches!(self, Transition::Complete(_) | Transition::Yield(_, _))
}
/// Map the output value
pub fn map_output<U, F: FnOnce(O) -> U>(self, f: F) -> Transition<S, U> {
match self {
@@ -55,7 +55,7 @@ impl<S, O> Transition<S, O> {
Transition::Error(e) => Transition::Error(e),
}
}
/// Map the state value
pub fn map_state<T, F: FnOnce(S) -> T>(self, f: F) -> Transition<T, O> {
match self {
@@ -90,12 +90,12 @@ impl<T> PollResult<T> {
pub fn is_ready(&self) -> bool {
matches!(self, PollResult::Ready(_))
}
/// Check if result indicates EOF
pub fn is_eof(&self) -> bool {
matches!(self, PollResult::Eof)
}
/// Convert to Option, discarding non-ready states
pub fn ok(self) -> Option<T> {
match self {
@@ -103,7 +103,7 @@ impl<T> PollResult<T> {
_ => None,
}
}
/// Map the value
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> PollResult<U> {
match self {
@@ -146,7 +146,7 @@ impl ReadBuffer {
target: None,
}
}
/// Create with specific capacity
pub fn with_capacity(capacity: usize) -> Self {
Self {
@@ -154,7 +154,7 @@ impl ReadBuffer {
target: None,
}
}
/// Create with target size
pub fn with_target(target: usize) -> Self {
Self {
@@ -162,17 +162,17 @@ impl ReadBuffer {
target: Some(target),
}
}
/// Get current buffer length
pub fn len(&self) -> usize {
self.buffer.len()
}
/// Check if buffer is empty
pub fn is_empty(&self) -> bool {
self.buffer.is_empty()
}
/// Check if target is reached
pub fn is_complete(&self) -> bool {
match self.target {
@@ -180,7 +180,7 @@ impl ReadBuffer {
None => false,
}
}
/// Get remaining bytes needed
pub fn remaining(&self) -> usize {
match self.target {
@@ -188,18 +188,18 @@ impl ReadBuffer {
None => 0,
}
}
/// Append data to buffer
pub fn extend(&mut self, data: &[u8]) {
self.buffer.extend_from_slice(data);
}
/// Take all data from buffer
pub fn take(&mut self) -> Bytes {
self.target = None;
self.buffer.split().freeze()
}
/// Take exactly n bytes
pub fn take_exact(&mut self, n: usize) -> Option<Bytes> {
if self.buffer.len() >= n {
@@ -208,23 +208,23 @@ impl ReadBuffer {
None
}
}
/// Get a slice of the buffer
pub fn as_slice(&self) -> &[u8] {
&self.buffer
}
/// Get mutable access to underlying BytesMut
pub fn as_bytes_mut(&mut self) -> &mut BytesMut {
&mut self.buffer
}
/// Clear the buffer
pub fn clear(&mut self) {
self.buffer.clear();
self.target = None;
}
/// Set new target
pub fn set_target(&mut self, target: usize) {
self.target = Some(target);
@@ -253,7 +253,7 @@ impl WriteBuffer {
pub fn new() -> Self {
Self::with_max_size(256 * 1024)
}
/// Create with specific max size
pub fn with_max_size(max_size: usize) -> Self {
Self {
@@ -262,27 +262,27 @@ impl WriteBuffer {
max_size,
}
}
/// Get pending bytes count
pub fn len(&self) -> usize {
self.buffer.len() - self.position
}
/// Check if buffer is empty (all written)
pub fn is_empty(&self) -> bool {
self.position >= self.buffer.len()
}
/// Check if buffer is full
pub fn is_full(&self) -> bool {
self.buffer.len() >= self.max_size
}
/// Get remaining capacity
pub fn remaining_capacity(&self) -> usize {
self.max_size.saturating_sub(self.buffer.len())
}
/// Append data to buffer
pub fn extend(&mut self, data: &[u8]) -> Result<(), ()> {
if self.buffer.len() + data.len() > self.max_size {
@@ -291,23 +291,23 @@ impl WriteBuffer {
self.buffer.extend_from_slice(data);
Ok(())
}
/// Get slice of data to write
pub fn pending(&self) -> &[u8] {
&self.buffer[self.position..]
}
/// Advance position by n bytes (after successful write)
pub fn advance(&mut self, n: usize) {
self.position += n;
// If all data written, reset buffer
if self.position >= self.buffer.len() {
self.buffer.clear();
self.position = 0;
}
}
/// Clear the buffer
pub fn clear(&mut self) {
self.buffer.clear();
@@ -340,38 +340,38 @@ impl<const N: usize> HeaderBuffer<N> {
filled: 0,
}
}
/// Get slice for reading into
pub fn unfilled_mut(&mut self) -> &mut [u8] {
&mut self.data[self.filled..]
}
/// Advance filled count
pub fn advance(&mut self, n: usize) {
self.filled = (self.filled + n).min(N);
}
/// Check if completely filled
pub fn is_complete(&self) -> bool {
self.filled >= N
}
/// Get remaining bytes needed
pub fn remaining(&self) -> usize {
N - self.filled
}
/// Get filled bytes as slice
pub fn as_slice(&self) -> &[u8] {
&self.data[..self.filled]
}
/// Get complete buffer (panics if not complete)
pub fn as_array(&self) -> &[u8; N] {
assert!(self.is_complete());
&self.data
}
/// Take the buffer, resetting state
pub fn take(&mut self) -> [u8; N] {
let data = self.data;
@@ -379,7 +379,7 @@ impl<const N: usize> HeaderBuffer<N> {
self.filled = 0;
data
}
/// Reset to empty state
pub fn reset(&mut self) {
self.filled = 0;
@@ -406,17 +406,17 @@ impl YieldBuffer {
pub fn new(data: Bytes) -> Self {
Self { data, position: 0 }
}
/// Check if all data has been yielded
pub fn is_empty(&self) -> bool {
self.position >= self.data.len()
}
/// Get remaining bytes
pub fn remaining(&self) -> usize {
self.data.len() - self.position
}
/// Copy data to output slice, return bytes copied
pub fn copy_to(&mut self, dst: &mut [u8]) -> usize {
let available = &self.data[self.position..];
@@ -425,7 +425,7 @@ impl YieldBuffer {
self.position += to_copy;
to_copy
}
/// Get remaining data as slice
pub fn as_slice(&self) -> &[u8] {
&self.data[self.position..]
@@ -468,106 +468,106 @@ macro_rules! ready_or_pending {
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_read_buffer_basic() {
let mut buf = ReadBuffer::with_target(10);
assert_eq!(buf.remaining(), 10);
assert!(!buf.is_complete());
buf.extend(b"hello");
assert_eq!(buf.len(), 5);
assert_eq!(buf.remaining(), 5);
assert!(!buf.is_complete());
buf.extend(b"world");
assert_eq!(buf.len(), 10);
assert!(buf.is_complete());
}
#[test]
fn test_read_buffer_take() {
let mut buf = ReadBuffer::new();
buf.extend(b"test data");
let data = buf.take();
assert_eq!(&data[..], b"test data");
assert!(buf.is_empty());
}
#[test]
fn test_write_buffer_basic() {
let mut buf = WriteBuffer::with_max_size(100);
assert!(buf.is_empty());
buf.extend(b"hello").unwrap();
assert_eq!(buf.len(), 5);
assert!(!buf.is_empty());
buf.advance(3);
assert_eq!(buf.len(), 2);
assert_eq!(buf.pending(), b"lo");
}
#[test]
fn test_write_buffer_overflow() {
let mut buf = WriteBuffer::with_max_size(10);
assert!(buf.extend(b"short").is_ok());
assert!(buf.extend(b"toolong").is_err());
}
#[test]
fn test_header_buffer() {
let mut buf = HeaderBuffer::<5>::new();
assert!(!buf.is_complete());
assert_eq!(buf.remaining(), 5);
buf.unfilled_mut()[..3].copy_from_slice(b"hel");
buf.advance(3);
assert_eq!(buf.remaining(), 2);
buf.unfilled_mut()[..2].copy_from_slice(b"lo");
buf.advance(2);
assert!(buf.is_complete());
assert_eq!(buf.as_array(), b"hello");
}
#[test]
fn test_yield_buffer() {
let mut buf = YieldBuffer::new(Bytes::from_static(b"hello world"));
let mut dst = [0u8; 5];
assert_eq!(buf.copy_to(&mut dst), 5);
assert_eq!(&dst, b"hello");
assert_eq!(buf.remaining(), 6);
let mut dst = [0u8; 10];
assert_eq!(buf.copy_to(&mut dst), 6);
assert_eq!(&dst[..6], b" world");
assert!(buf.is_empty());
}
#[test]
fn test_transition_map() {
let t: Transition<i32, String> = Transition::Complete("hello".to_string());
let t = t.map_output(|s| s.len());
match t {
Transition::Complete(5) => {}
_ => panic!("Expected Complete(5)"),
}
}
#[test]
fn test_poll_result() {
let r: PollResult<i32> = PollResult::Ready(42);
assert!(r.is_ready());
assert_eq!(r.ok(), Some(42));
let r: PollResult<i32> = PollResult::Eof;
assert!(r.is_eof());
assert_eq!(r.ok(), None);
}
}
}

View File

@@ -38,16 +38,13 @@ use bytes::{Bytes, BytesMut};
use std::io::{self, Error, ErrorKind, Result};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use super::state::{HeaderBuffer, StreamState, WriteBuffer, YieldBuffer};
use crate::protocol::constants::{
MAX_TLS_PLAINTEXT_SIZE,
TLS_VERSION,
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER,
TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT,
MAX_TLS_CIPHERTEXT_SIZE,
MAX_TLS_CIPHERTEXT_SIZE, MAX_TLS_PLAINTEXT_SIZE, TLS_RECORD_ALERT, TLS_RECORD_APPLICATION,
TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, TLS_VERSION,
};
use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer};
// ============= Constants =============
@@ -84,7 +81,11 @@ impl TlsRecordHeader {
let record_type = header[0];
let version = [header[1], header[2]];
let length = u16::from_be_bytes([header[3], header[4]]);
Some(Self { record_type, version, length })
Some(Self {
record_type,
version,
length,
})
}
/// Validate the header.
@@ -111,8 +112,7 @@ impl TlsRecordHeader {
ErrorKind::InvalidData,
format!(
"invalid TLS version for record type 0x{:02x}: {:02x?}",
self.record_type,
self.version
self.record_type, self.version
),
));
}
@@ -129,8 +129,7 @@ impl TlsRecordHeader {
ErrorKind::InvalidData,
format!(
"invalid TLS application data length: {} (min 1, max {})",
len,
MAX_TLS_CIPHERTEXT_SIZE
len, MAX_TLS_CIPHERTEXT_SIZE
),
));
}
@@ -160,8 +159,7 @@ impl TlsRecordHeader {
ErrorKind::InvalidData,
format!(
"invalid TLS handshake length: {} (min 4, max {})",
len,
MAX_TLS_PLAINTEXT_SIZE
len, MAX_TLS_PLAINTEXT_SIZE
),
));
}
@@ -212,14 +210,10 @@ enum TlsReaderState {
},
/// Have buffered data ready to yield to caller
Yielding {
buffer: YieldBuffer,
},
Yielding { buffer: YieldBuffer },
/// Stream encountered an error and cannot be used
Poisoned {
error: Option<io::Error>,
},
Poisoned { error: Option<io::Error> },
}
impl StreamState for TlsReaderState {
@@ -287,28 +281,41 @@ pub struct FakeTlsReader<R> {
impl<R> FakeTlsReader<R> {
pub fn new(upstream: R) -> Self {
Self { upstream, state: TlsReaderState::Idle }
Self {
upstream,
state: TlsReaderState::Idle,
}
}
pub fn get_ref(&self) -> &R { &self.upstream }
pub fn get_mut(&mut self) -> &mut R { &mut self.upstream }
pub fn into_inner(self) -> R { self.upstream }
pub fn get_ref(&self) -> &R {
&self.upstream
}
pub fn get_mut(&mut self) -> &mut R {
&mut self.upstream
}
pub fn into_inner(self) -> R {
self.upstream
}
pub fn into_inner_with_pending_plaintext(mut self) -> (R, Vec<u8>) {
let pending = match std::mem::replace(&mut self.state, TlsReaderState::Idle) {
TlsReaderState::Yielding { buffer } => buffer.as_slice().to_vec(),
TlsReaderState::ReadingBody { record_type, buffer, .. }
if record_type == TLS_RECORD_APPLICATION =>
{
buffer.to_vec()
}
TlsReaderState::ReadingBody {
record_type,
buffer,
..
} if record_type == TLS_RECORD_APPLICATION => buffer.to_vec(),
_ => Vec::new(),
};
(self.upstream, pending)
}
pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() }
pub fn state_name(&self) -> &'static str { self.state.state_name() }
pub fn is_poisoned(&self) -> bool {
self.state.is_poisoned()
}
pub fn state_name(&self) -> &'static str {
self.state.state_name()
}
fn poison(&mut self, error: io::Error) {
self.state = TlsReaderState::Poisoned { error: Some(error) };
@@ -316,9 +323,9 @@ impl<R> FakeTlsReader<R> {
fn take_poison_error(&mut self) -> io::Error {
match &mut self.state {
TlsReaderState::Poisoned { error } => error.take().unwrap_or_else(|| {
io::Error::other("stream previously poisoned")
}),
TlsReaderState::Poisoned { error } => error
.take()
.unwrap_or_else(|| io::Error::other("stream previously poisoned")),
_ => io::Error::other("stream not poisoned"),
}
}
@@ -353,9 +360,8 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
// Poisoned state: always return the stored error
TlsReaderState::Poisoned { error } => {
this.state = TlsReaderState::Poisoned { error: None };
let err = error.unwrap_or_else(|| {
io::Error::other("stream previously poisoned")
});
let err =
error.unwrap_or_else(|| io::Error::other("stream previously poisoned"));
return Poll::Ready(Err(err));
}
@@ -428,12 +434,20 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
}
// Read TLS payload
TlsReaderState::ReadingBody { record_type, length, mut buffer } => {
TlsReaderState::ReadingBody {
record_type,
length,
mut buffer,
} => {
let result = poll_read_body(&mut this.upstream, cx, &mut buffer, length);
match result {
BodyPollResult::Pending => {
this.state = TlsReaderState::ReadingBody { record_type, length, buffer };
this.state = TlsReaderState::ReadingBody {
record_type,
length,
buffer,
};
return Poll::Pending;
}
BodyPollResult::Error(e) => {
@@ -469,7 +483,10 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
TLS_RECORD_HANDSHAKE => {
// After FakeTLS handshake is done, we do not expect any Handshake records.
let err = Error::new(ErrorKind::InvalidData, "unexpected TLS handshake record");
let err = Error::new(
ErrorKind::InvalidData,
"unexpected TLS handshake record",
);
this.poison(Error::new(err.kind(), err.to_string()));
return Poll::Ready(Err(err));
}
@@ -528,7 +545,10 @@ fn poll_read_header<R: AsyncRead + Unpin>(
let header_bytes = *header.as_array();
match TlsRecordHeader::parse(&header_bytes) {
Some(h) => HeaderPollResult::Complete(h),
None => HeaderPollResult::Error(Error::new(ErrorKind::InvalidData, "failed to parse TLS header")),
None => HeaderPollResult::Error(Error::new(
ErrorKind::InvalidData,
"failed to parse TLS header",
)),
}
}
@@ -614,9 +634,7 @@ enum TlsWriterState {
},
/// Stream encountered an error and cannot be used
Poisoned {
error: Option<io::Error>,
},
Poisoned { error: Option<io::Error> },
}
impl StreamState for TlsWriterState {
@@ -652,15 +670,28 @@ pub struct FakeTlsWriter<W> {
impl<W> FakeTlsWriter<W> {
pub fn new(upstream: W) -> Self {
Self { upstream, state: TlsWriterState::Idle }
Self {
upstream,
state: TlsWriterState::Idle,
}
}
pub fn get_ref(&self) -> &W { &self.upstream }
pub fn get_mut(&mut self) -> &mut W { &mut self.upstream }
pub fn into_inner(self) -> W { self.upstream }
pub fn get_ref(&self) -> &W {
&self.upstream
}
pub fn get_mut(&mut self) -> &mut W {
&mut self.upstream
}
pub fn into_inner(self) -> W {
self.upstream
}
pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() }
pub fn state_name(&self) -> &'static str { self.state.state_name() }
pub fn is_poisoned(&self) -> bool {
self.state.is_poisoned()
}
pub fn state_name(&self) -> &'static str {
self.state.state_name()
}
pub fn has_pending(&self) -> bool {
matches!(&self.state, TlsWriterState::WritingRecord { record, .. } if !record.is_empty())
@@ -672,9 +703,9 @@ impl<W> FakeTlsWriter<W> {
fn take_poison_error(&mut self) -> io::Error {
match &mut self.state {
TlsWriterState::Poisoned { error } => error.take().unwrap_or_else(|| {
io::Error::other("stream previously poisoned")
}),
TlsWriterState::Poisoned { error } => error
.take()
.unwrap_or_else(|| io::Error::other("stream previously poisoned")),
_ => io::Error::other("stream not poisoned"),
}
}
@@ -725,11 +756,7 @@ impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
}
impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
let this = self.get_mut();
// Take ownership of state to avoid borrow conflicts.
@@ -738,17 +765,21 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
match state {
TlsWriterState::Poisoned { error } => {
this.state = TlsWriterState::Poisoned { error: None };
let err = error.unwrap_or_else(|| {
Error::other("stream previously poisoned")
});
let err = error.unwrap_or_else(|| Error::other("stream previously poisoned"));
return Poll::Ready(Err(err));
}
TlsWriterState::WritingRecord { mut record, payload_size } => {
TlsWriterState::WritingRecord {
mut record,
payload_size,
} => {
// Finish writing previous record before accepting new bytes.
match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
FlushResult::Pending => {
this.state = TlsWriterState::WritingRecord { record, payload_size };
this.state = TlsWriterState::WritingRecord {
record,
payload_size,
};
return Poll::Pending;
}
FlushResult::Error(e) => {
@@ -780,9 +811,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
let record_data = Self::build_record(chunk);
match Pin::new(&mut this.upstream).poll_write(cx, &record_data) {
Poll::Ready(Ok(n)) if n == record_data.len() => {
Poll::Ready(Ok(chunk_size))
}
Poll::Ready(Ok(n)) if n == record_data.len() => Poll::Ready(Ok(chunk_size)),
Poll::Ready(Ok(n)) => {
// Partial write of the record: store remainder.
@@ -827,27 +856,29 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
match state {
TlsWriterState::Poisoned { error } => {
this.state = TlsWriterState::Poisoned { error: None };
let err = error.unwrap_or_else(|| {
Error::other("stream previously poisoned")
});
let err = error.unwrap_or_else(|| Error::other("stream previously poisoned"));
return Poll::Ready(Err(err));
}
TlsWriterState::WritingRecord { mut record, payload_size } => {
match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
FlushResult::Pending => {
this.state = TlsWriterState::WritingRecord { record, payload_size };
return Poll::Pending;
}
FlushResult::Error(e) => {
this.poison(Error::new(e.kind(), e.to_string()));
return Poll::Ready(Err(e));
}
FlushResult::Complete(_) => {
this.state = TlsWriterState::Idle;
}
TlsWriterState::WritingRecord {
mut record,
payload_size,
} => match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
FlushResult::Pending => {
this.state = TlsWriterState::WritingRecord {
record,
payload_size,
};
return Poll::Pending;
}
}
FlushResult::Error(e) => {
this.poison(Error::new(e.kind(), e.to_string()));
return Poll::Ready(Err(e));
}
FlushResult::Complete(_) => {
this.state = TlsWriterState::Idle;
}
},
TlsWriterState::Idle => {
this.state = TlsWriterState::Idle;
@@ -863,7 +894,10 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
match state {
TlsWriterState::WritingRecord { mut record, payload_size: _ } => {
TlsWriterState::WritingRecord {
mut record,
payload_size: _,
} => {
// Best-effort flush (do not block shutdown forever).
let _ = Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record);
this.state = TlsWriterState::Idle;
@@ -905,10 +939,10 @@ mod size_adversarial_tests;
mod tests {
use super::*;
use std::collections::VecDeque;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
// ============= Test Helpers =============
/// Build a valid TLS Application Data record
fn build_tls_record(data: &[u8]) -> Vec<u8> {
let mut record = vec![
@@ -921,24 +955,25 @@ mod tests {
record.extend_from_slice(data);
record
}
/// Build a Change Cipher Spec record
fn build_ccs_record() -> Vec<u8> {
vec![
TLS_RECORD_CHANGE_CIPHER,
TLS_VERSION[0],
TLS_VERSION[1],
0x00, 0x01, // length = 1
0x01, // CCS byte
0x00,
0x01, // length = 1
0x01, // CCS byte
]
}
/// Mock reader that returns data in chunks
struct ChunkedReader {
data: VecDeque<u8>,
chunk_size: usize,
}
impl ChunkedReader {
fn new(data: &[u8], chunk_size: usize) -> Self {
Self {
@@ -947,7 +982,7 @@ mod tests {
}
}
}
impl AsyncRead for ChunkedReader {
fn poll_read(
mut self: Pin<&mut Self>,
@@ -957,92 +992,92 @@ mod tests {
if self.data.is_empty() {
return Poll::Ready(Ok(()));
}
let to_read = self.chunk_size.min(self.data.len()).min(buf.remaining());
for _ in 0..to_read {
if let Some(byte) = self.data.pop_front() {
buf.put_slice(&[byte]);
}
}
Poll::Ready(Ok(()))
}
}
// ============= FakeTlsReader Tests =============
#[tokio::test]
async fn test_tls_reader_single_record() {
let payload = b"Hello, TLS!";
let record = build_tls_record(payload);
let reader = ChunkedReader::new(&record, 100);
let mut tls_reader = FakeTlsReader::new(reader);
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
assert_eq!(&buf[..], payload);
}
#[tokio::test]
async fn test_tls_reader_multiple_records() {
let payload1 = b"First record";
let payload2 = b"Second record";
let mut data = build_tls_record(payload1);
data.extend_from_slice(&build_tls_record(payload2));
let reader = ChunkedReader::new(&data, 100);
let mut tls_reader = FakeTlsReader::new(reader);
let buf1 = tls_reader.read_exact(payload1.len()).await.unwrap();
assert_eq!(&buf1[..], payload1);
let buf2 = tls_reader.read_exact(payload2.len()).await.unwrap();
assert_eq!(&buf2[..], payload2);
}
#[tokio::test]
async fn test_tls_reader_partial_header() {
// Read header byte by byte
let payload = b"Test";
let record = build_tls_record(payload);
let reader = ChunkedReader::new(&record, 1); // 1 byte at a time!
let mut tls_reader = FakeTlsReader::new(reader);
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
assert_eq!(&buf[..], payload);
}
#[tokio::test]
async fn test_tls_reader_partial_body() {
let payload = b"This is a longer payload that will be read in parts";
let record = build_tls_record(payload);
let reader = ChunkedReader::new(&record, 7); // Awkward chunk size
let mut tls_reader = FakeTlsReader::new(reader);
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
assert_eq!(&buf[..], payload);
}
#[tokio::test]
async fn test_tls_reader_skip_ccs() {
// CCS record followed by application data
let mut data = build_ccs_record();
let payload = b"After CCS";
data.extend_from_slice(&build_tls_record(payload));
let reader = ChunkedReader::new(&data, 100);
let mut tls_reader = FakeTlsReader::new(reader);
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
assert_eq!(&buf[..], payload);
}
#[tokio::test]
async fn test_tls_reader_multiple_ccs() {
// Multiple CCS records
@@ -1050,127 +1085,127 @@ mod tests {
data.extend_from_slice(&build_ccs_record());
let payload = b"After multiple CCS";
data.extend_from_slice(&build_tls_record(payload));
let reader = ChunkedReader::new(&data, 3); // Small chunks
let mut tls_reader = FakeTlsReader::new(reader);
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
assert_eq!(&buf[..], payload);
}
#[tokio::test]
async fn test_tls_reader_eof() {
let reader = ChunkedReader::new(&[], 100);
let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; 10];
let read = tls_reader.read(&mut buf).await.unwrap();
assert_eq!(read, 0);
}
#[tokio::test]
async fn test_tls_reader_state_names() {
let reader = ChunkedReader::new(&[], 100);
let tls_reader = FakeTlsReader::new(reader);
assert_eq!(tls_reader.state_name(), "Idle");
assert!(!tls_reader.is_poisoned());
}
// ============= FakeTlsWriter Tests =============
#[tokio::test]
async fn test_tls_writer_single_write() {
let (client, mut server) = duplex(4096);
let mut writer = FakeTlsWriter::new(client);
let payload = b"Hello, TLS!";
writer.write_all(payload).await.unwrap();
writer.flush().await.unwrap();
// Read the TLS record from server
let mut header = [0u8; 5];
server.read_exact(&mut header).await.unwrap();
assert_eq!(header[0], TLS_RECORD_APPLICATION);
assert_eq!(&header[1..3], &TLS_VERSION);
let length = u16::from_be_bytes([header[3], header[4]]) as usize;
assert_eq!(length, payload.len());
let mut body = vec![0u8; length];
server.read_exact(&mut body).await.unwrap();
assert_eq!(&body, payload);
}
#[tokio::test]
async fn test_tls_writer_large_data_chunking() {
let (client, mut server) = duplex(65536);
let mut writer = FakeTlsWriter::new(client);
// Write data larger than MAX_TLS_PAYLOAD
let payload: Vec<u8> = (0..20000).map(|i| (i % 256) as u8).collect();
writer.write_all(&payload).await.unwrap();
writer.flush().await.unwrap();
// Read back - should be multiple records
let mut received = Vec::new();
let mut records_count = 0;
while received.len() < payload.len() {
let mut header = [0u8; 5];
if server.read_exact(&mut header).await.is_err() {
break;
}
assert_eq!(header[0], TLS_RECORD_APPLICATION);
records_count += 1;
let length = u16::from_be_bytes([header[3], header[4]]) as usize;
assert!(length <= MAX_TLS_PAYLOAD);
let mut body = vec![0u8; length];
server.read_exact(&mut body).await.unwrap();
received.extend_from_slice(&body);
}
assert_eq!(received, payload);
assert!(records_count >= 2); // Should have multiple records
}
#[tokio::test]
async fn test_tls_stream_roundtrip() {
let (client, server) = duplex(4096);
let mut writer = FakeTlsWriter::new(client);
let mut reader = FakeTlsReader::new(server);
let original = b"Hello, fake TLS!";
writer.write_all_tls(original).await.unwrap();
writer.flush().await.unwrap();
let received = reader.read_exact(original.len()).await.unwrap();
assert_eq!(&received[..], original);
}
#[tokio::test]
async fn test_tls_stream_roundtrip_large() {
let (client, server) = duplex(4096);
let mut writer = FakeTlsWriter::new(client);
let mut reader = FakeTlsReader::new(server);
let original: Vec<u8> = (0..50000).map(|i| (i % 256) as u8).collect();
// Write in background
let write_data = original.clone();
let write_handle = tokio::spawn(async move {
writer.write_all_tls(&write_data).await.unwrap();
writer.shutdown().await.unwrap();
});
// Read
let mut received = Vec::new();
let mut buf = vec![0u8; 1024];
@@ -1181,87 +1216,95 @@ mod tests {
}
received.extend_from_slice(&buf[..n]);
}
write_handle.await.unwrap();
assert_eq!(received, original);
}
#[tokio::test]
async fn test_tls_writer_state_names() {
let (client, _server) = duplex(4096);
let writer = FakeTlsWriter::new(client);
assert_eq!(writer.state_name(), "Idle");
assert!(!writer.is_poisoned());
assert!(!writer.has_pending());
}
// ============= Error Handling Tests =============
#[tokio::test]
async fn test_tls_reader_invalid_version() {
let invalid_record = vec![
TLS_RECORD_APPLICATION,
0x04, 0x00, // Invalid version
0x00, 0x05, // length = 5
0x01, 0x02, 0x03, 0x04, 0x05,
0x04,
0x00, // Invalid version
0x00,
0x05, // length = 5
0x01,
0x02,
0x03,
0x04,
0x05,
];
let reader = ChunkedReader::new(&invalid_record, 100);
let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; 5];
let result = tls_reader.read(&mut buf).await;
assert!(result.is_err());
assert!(tls_reader.is_poisoned());
}
#[tokio::test]
async fn test_tls_reader_unexpected_eof_header() {
// Partial header
let partial = vec![TLS_RECORD_APPLICATION, 0x03];
let reader = ChunkedReader::new(&partial, 100);
let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; 10];
let result = tls_reader.read(&mut buf).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_tls_reader_unexpected_eof_body() {
// Valid header but truncated body
let mut record = vec![
TLS_RECORD_APPLICATION,
TLS_VERSION[0], TLS_VERSION[1],
0x00, 0x10, // length = 16
TLS_VERSION[0],
TLS_VERSION[1],
0x00,
0x10, // length = 16
];
record.extend_from_slice(&[0u8; 8]); // Only 8 bytes of body
let reader = ChunkedReader::new(&record, 100);
let mut tls_reader = FakeTlsReader::new(reader);
let mut buf = vec![0u8; 16];
let result = tls_reader.read(&mut buf).await;
assert!(result.is_err());
}
// ============= Header Parsing Tests =============
#[test]
fn test_tls_record_header_parse() {
let header = [0x17, 0x03, 0x03, 0x01, 0x00];
let parsed = TlsRecordHeader::parse(&header).unwrap();
assert_eq!(parsed.record_type, TLS_RECORD_APPLICATION);
assert_eq!(parsed.version, TLS_VERSION);
assert_eq!(parsed.length, 256);
}
#[test]
fn test_tls_record_header_validate() {
let valid = TlsRecordHeader {
@@ -1270,14 +1313,14 @@ mod tests {
length: 100,
};
assert!(valid.validate().is_ok());
let invalid_version = TlsRecordHeader {
record_type: TLS_RECORD_APPLICATION,
version: [0x04, 0x00],
length: 100,
};
assert!(invalid_version.validate().is_err());
let too_large = TlsRecordHeader {
record_type: TLS_RECORD_APPLICATION,
version: TLS_VERSION,
@@ -1285,7 +1328,7 @@ mod tests {
};
assert!(too_large.validate().is_err());
}
#[test]
fn test_tls_record_header_to_bytes() {
let header = TlsRecordHeader {
@@ -1293,7 +1336,7 @@ mod tests {
version: TLS_VERSION,
length: 0x1234,
};
let bytes = header.to_bytes();
assert_eq!(bytes, [0x17, 0x03, 0x03, 0x12, 0x34]);
}

View File

@@ -13,8 +13,7 @@ fn reading_body_pending_application_plaintext_is_preserved_on_into_inner() {
let (_inner, pending) = reader.into_inner_with_pending_plaintext();
assert_eq!(
pending,
sample,
pending, sample,
"partial application-data body must survive into fallback path"
);
}
@@ -59,7 +58,10 @@ fn partial_header_state_does_not_produce_plaintext() {
reader.state = TlsReaderState::ReadingHeader { header };
let (_inner, pending) = reader.into_inner_with_pending_plaintext();
assert!(pending.is_empty(), "partial header bytes are not plaintext payload");
assert!(
pending.is_empty(),
"partial header bytes are not plaintext payload"
);
}
#[test]
@@ -83,7 +85,10 @@ fn adversarial_poisoned_state_never_leaks_pending_bytes() {
};
let (_inner, pending) = reader.into_inner_with_pending_plaintext();
assert!(pending.is_empty(), "poisoned state must fail-closed for fallback payload");
assert!(
pending.is_empty(),
"poisoned state must fail-closed for fallback payload"
);
}
#[test]
@@ -101,7 +106,10 @@ fn stress_large_application_fragment_survives_state_extraction() {
};
let (_inner, pending) = reader.into_inner_with_pending_plaintext();
assert_eq!(pending, payload, "large pending application plaintext must be preserved exactly");
assert_eq!(
pending, payload,
"large pending application plaintext must be preserved exactly"
);
}
#[test]

View File

@@ -308,52 +308,228 @@ macro_rules! expect_accept {
};
}
expect_reject!(appdata_zero_len_must_be_rejected, TLS_RECORD_APPLICATION, TLS_VERSION, 0);
expect_accept!(appdata_one_len_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, 1);
expect_accept!(appdata_small_len_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, 32);
expect_accept!(appdata_medium_len_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, 1024);
expect_accept!(appdata_plaintext_limit_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, MAX_TLS_PLAINTEXT_SIZE as u16);
expect_accept!(appdata_ciphertext_limit_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, MAX_TLS_CIPHERTEXT_SIZE as u16);
expect_reject!(appdata_ciphertext_plus_one_must_be_rejected, TLS_RECORD_APPLICATION, TLS_VERSION, (MAX_TLS_CIPHERTEXT_SIZE as u16) + 1);
expect_reject!(
appdata_zero_len_must_be_rejected,
TLS_RECORD_APPLICATION,
TLS_VERSION,
0
);
expect_accept!(
appdata_one_len_is_accepted,
TLS_RECORD_APPLICATION,
TLS_VERSION,
1
);
expect_accept!(
appdata_small_len_is_accepted,
TLS_RECORD_APPLICATION,
TLS_VERSION,
32
);
expect_accept!(
appdata_medium_len_is_accepted,
TLS_RECORD_APPLICATION,
TLS_VERSION,
1024
);
expect_accept!(
appdata_plaintext_limit_is_accepted,
TLS_RECORD_APPLICATION,
TLS_VERSION,
MAX_TLS_PLAINTEXT_SIZE as u16
);
expect_accept!(
appdata_ciphertext_limit_is_accepted,
TLS_RECORD_APPLICATION,
TLS_VERSION,
MAX_TLS_CIPHERTEXT_SIZE as u16
);
expect_reject!(
appdata_ciphertext_plus_one_must_be_rejected,
TLS_RECORD_APPLICATION,
TLS_VERSION,
(MAX_TLS_CIPHERTEXT_SIZE as u16) + 1
);
expect_reject!(appdata_tls10_header_len_one_must_be_rejected, TLS_RECORD_APPLICATION, [0x03, 0x01], 1);
expect_reject!(appdata_tls10_header_medium_must_be_rejected, TLS_RECORD_APPLICATION, [0x03, 0x01], 1024);
expect_reject!(appdata_tls10_header_ciphertext_limit_must_be_rejected, TLS_RECORD_APPLICATION, [0x03, 0x01], MAX_TLS_CIPHERTEXT_SIZE as u16);
expect_reject!(
appdata_tls10_header_len_one_must_be_rejected,
TLS_RECORD_APPLICATION,
[0x03, 0x01],
1
);
expect_reject!(
appdata_tls10_header_medium_must_be_rejected,
TLS_RECORD_APPLICATION,
[0x03, 0x01],
1024
);
expect_reject!(
appdata_tls10_header_ciphertext_limit_must_be_rejected,
TLS_RECORD_APPLICATION,
[0x03, 0x01],
MAX_TLS_CIPHERTEXT_SIZE as u16
);
expect_reject!(ccs_tls10_header_len_one_must_be_rejected, TLS_RECORD_CHANGE_CIPHER, [0x03, 0x01], 1);
expect_reject!(ccs_tls10_header_len_zero_must_be_rejected, TLS_RECORD_CHANGE_CIPHER, [0x03, 0x01], 0);
expect_reject!(ccs_tls10_header_len_two_must_be_rejected, TLS_RECORD_CHANGE_CIPHER, [0x03, 0x01], 2);
expect_reject!(
ccs_tls10_header_len_one_must_be_rejected,
TLS_RECORD_CHANGE_CIPHER,
[0x03, 0x01],
1
);
expect_reject!(
ccs_tls10_header_len_zero_must_be_rejected,
TLS_RECORD_CHANGE_CIPHER,
[0x03, 0x01],
0
);
expect_reject!(
ccs_tls10_header_len_two_must_be_rejected,
TLS_RECORD_CHANGE_CIPHER,
[0x03, 0x01],
2
);
expect_reject!(alert_tls10_header_len_two_must_be_rejected, TLS_RECORD_ALERT, [0x03, 0x01], 2);
expect_reject!(alert_tls10_header_len_one_must_be_rejected, TLS_RECORD_ALERT, [0x03, 0x01], 1);
expect_reject!(alert_tls10_header_len_three_must_be_rejected, TLS_RECORD_ALERT, [0x03, 0x01], 3);
expect_reject!(
alert_tls10_header_len_two_must_be_rejected,
TLS_RECORD_ALERT,
[0x03, 0x01],
2
);
expect_reject!(
alert_tls10_header_len_one_must_be_rejected,
TLS_RECORD_ALERT,
[0x03, 0x01],
1
);
expect_reject!(
alert_tls10_header_len_three_must_be_rejected,
TLS_RECORD_ALERT,
[0x03, 0x01],
3
);
expect_accept!(handshake_tls10_header_min_len_is_accepted, TLS_RECORD_HANDSHAKE, [0x03, 0x01], 4);
expect_accept!(handshake_tls10_header_plaintext_limit_is_accepted, TLS_RECORD_HANDSHAKE, [0x03, 0x01], MAX_TLS_PLAINTEXT_SIZE as u16);
expect_reject!(handshake_tls10_header_too_small_must_be_rejected, TLS_RECORD_HANDSHAKE, [0x03, 0x01], 3);
expect_reject!(handshake_tls10_header_too_large_must_be_rejected, TLS_RECORD_HANDSHAKE, [0x03, 0x01], (MAX_TLS_PLAINTEXT_SIZE as u16) + 1);
expect_accept!(
handshake_tls10_header_min_len_is_accepted,
TLS_RECORD_HANDSHAKE,
[0x03, 0x01],
4
);
expect_accept!(
handshake_tls10_header_plaintext_limit_is_accepted,
TLS_RECORD_HANDSHAKE,
[0x03, 0x01],
MAX_TLS_PLAINTEXT_SIZE as u16
);
expect_reject!(
handshake_tls10_header_too_small_must_be_rejected,
TLS_RECORD_HANDSHAKE,
[0x03, 0x01],
3
);
expect_reject!(
handshake_tls10_header_too_large_must_be_rejected,
TLS_RECORD_HANDSHAKE,
[0x03, 0x01],
(MAX_TLS_PLAINTEXT_SIZE as u16) + 1
);
expect_reject!(unknown_type_tls13_zero_must_be_rejected, 0x00, TLS_VERSION, 0);
expect_reject!(unknown_type_tls13_small_must_be_rejected, 0x13, TLS_VERSION, 32);
expect_reject!(unknown_type_tls13_large_must_be_rejected, 0xfe, TLS_VERSION, MAX_TLS_CIPHERTEXT_SIZE as u16);
expect_reject!(unknown_type_tls10_small_must_be_rejected, 0x13, [0x03, 0x01], 32);
expect_reject!(
unknown_type_tls13_zero_must_be_rejected,
0x00,
TLS_VERSION,
0
);
expect_reject!(
unknown_type_tls13_small_must_be_rejected,
0x13,
TLS_VERSION,
32
);
expect_reject!(
unknown_type_tls13_large_must_be_rejected,
0xfe,
TLS_VERSION,
MAX_TLS_CIPHERTEXT_SIZE as u16
);
expect_reject!(
unknown_type_tls10_small_must_be_rejected,
0x13,
[0x03, 0x01],
32
);
expect_reject!(appdata_invalid_version_0302_must_be_rejected, TLS_RECORD_APPLICATION, [0x03, 0x02], 128);
expect_reject!(handshake_invalid_version_0302_must_be_rejected, TLS_RECORD_HANDSHAKE, [0x03, 0x02], 128);
expect_reject!(alert_invalid_version_0302_must_be_rejected, TLS_RECORD_ALERT, [0x03, 0x02], 2);
expect_reject!(ccs_invalid_version_0302_must_be_rejected, TLS_RECORD_CHANGE_CIPHER, [0x03, 0x02], 1);
expect_reject!(
appdata_invalid_version_0302_must_be_rejected,
TLS_RECORD_APPLICATION,
[0x03, 0x02],
128
);
expect_reject!(
handshake_invalid_version_0302_must_be_rejected,
TLS_RECORD_HANDSHAKE,
[0x03, 0x02],
128
);
expect_reject!(
alert_invalid_version_0302_must_be_rejected,
TLS_RECORD_ALERT,
[0x03, 0x02],
2
);
expect_reject!(
ccs_invalid_version_0302_must_be_rejected,
TLS_RECORD_CHANGE_CIPHER,
[0x03, 0x02],
1
);
expect_reject!(appdata_invalid_version_0304_must_be_rejected, TLS_RECORD_APPLICATION, [0x03, 0x04], 128);
expect_reject!(handshake_invalid_version_0304_must_be_rejected, TLS_RECORD_HANDSHAKE, [0x03, 0x04], 128);
expect_reject!(alert_invalid_version_0304_must_be_rejected, TLS_RECORD_ALERT, [0x03, 0x04], 2);
expect_reject!(ccs_invalid_version_0304_must_be_rejected, TLS_RECORD_CHANGE_CIPHER, [0x03, 0x04], 1);
expect_reject!(
appdata_invalid_version_0304_must_be_rejected,
TLS_RECORD_APPLICATION,
[0x03, 0x04],
128
);
expect_reject!(
handshake_invalid_version_0304_must_be_rejected,
TLS_RECORD_HANDSHAKE,
[0x03, 0x04],
128
);
expect_reject!(
alert_invalid_version_0304_must_be_rejected,
TLS_RECORD_ALERT,
[0x03, 0x04],
2
);
expect_reject!(
ccs_invalid_version_0304_must_be_rejected,
TLS_RECORD_CHANGE_CIPHER,
[0x03, 0x04],
1
);
expect_accept!(handshake_tls13_len_5_is_accepted, TLS_RECORD_HANDSHAKE, TLS_VERSION, 5);
expect_accept!(appdata_tls13_len_16385_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, (MAX_TLS_PLAINTEXT_SIZE as u16) + 1);
expect_accept!(
handshake_tls13_len_5_is_accepted,
TLS_RECORD_HANDSHAKE,
TLS_VERSION,
5
);
expect_accept!(
appdata_tls13_len_16385_is_accepted,
TLS_RECORD_APPLICATION,
TLS_VERSION,
(MAX_TLS_PLAINTEXT_SIZE as u16) + 1
);
#[test]
fn matrix_version_policy_is_strict_and_deterministic() {
let versions = [[0x03, 0x01], TLS_VERSION, [0x03, 0x02], [0x03, 0x04], [0x00, 0x00]];
let versions = [
[0x03, 0x01],
TLS_VERSION,
[0x03, 0x02],
[0x03, 0x04],
[0x00, 0x00],
];
let record_types = [
TLS_RECORD_APPLICATION,
TLS_RECORD_CHANGE_CIPHER,
@@ -389,22 +565,62 @@ fn matrix_version_policy_is_strict_and_deterministic() {
#[test]
fn appdata_partition_property_holds_for_all_u16_edges() {
for len in [0u16, 1, 2, 3, 64, 255, 1024, 4096, 8192, 16_384, 16_385, 16_640, 16_641, u16::MAX] {
for len in [
0u16,
1,
2,
3,
64,
255,
1024,
4096,
8192,
16_384,
16_385,
16_640,
16_641,
u16::MAX,
] {
let accepted = validates(TLS_RECORD_APPLICATION, TLS_VERSION, len);
let expected = len >= 1 && usize::from(len) <= MAX_TLS_CIPHERTEXT_SIZE;
assert_eq!(accepted, expected, "unexpected appdata decision for len={len}");
assert_eq!(
accepted, expected,
"unexpected appdata decision for len={len}"
);
}
}
#[test]
fn handshake_partition_property_holds_for_all_u16_edges() {
for len in [0u16, 1, 2, 3, 4, 5, 64, 255, 1024, 4096, 8192, 16_383, 16_384, 16_385, u16::MAX] {
for len in [
0u16,
1,
2,
3,
4,
5,
64,
255,
1024,
4096,
8192,
16_383,
16_384,
16_385,
u16::MAX,
] {
let accepted_tls13 = validates(TLS_RECORD_HANDSHAKE, TLS_VERSION, len);
let accepted_tls10 = validates(TLS_RECORD_HANDSHAKE, [0x03, 0x01], len);
let expected = (4..=MAX_TLS_PLAINTEXT_SIZE).contains(&usize::from(len));
assert_eq!(accepted_tls13, expected, "TLS1.3 handshake mismatch for len={len}");
assert_eq!(accepted_tls10, expected, "TLS1.0 compat handshake mismatch for len={len}");
assert_eq!(
accepted_tls13, expected,
"TLS1.3 handshake mismatch for len={len}"
);
assert_eq!(
accepted_tls10, expected,
"TLS1.0 compat handshake mismatch for len={len}"
);
}
}
@@ -419,14 +635,24 @@ fn control_record_exact_lengths_are_enforced_under_fuzzed_lengths() {
let alert_ok = validates(TLS_RECORD_ALERT, TLS_VERSION, len);
assert_eq!(ccs_ok, len == 1, "ccs length gate mismatch for len={len}");
assert_eq!(alert_ok, len == 2, "alert length gate mismatch for len={len}");
assert_eq!(
alert_ok,
len == 2,
"alert length gate mismatch for len={len}"
);
}
}
#[test]
fn unknown_record_types_never_validate_under_supported_versions() {
for record_type in 0u8..=255 {
if matches!(record_type, TLS_RECORD_APPLICATION | TLS_RECORD_CHANGE_CIPHER | TLS_RECORD_ALERT | TLS_RECORD_HANDSHAKE) {
if matches!(
record_type,
TLS_RECORD_APPLICATION
| TLS_RECORD_CHANGE_CIPHER
| TLS_RECORD_ALERT
| TLS_RECORD_HANDSHAKE
) {
continue;
}
@@ -458,9 +684,15 @@ async fn reader_rejects_tls10_appdata_header_before_payload_processing() {
#[tokio::test]
async fn reader_rejects_zero_len_appdata_record() {
let (mut tx, rx) = tokio::io::duplex(128);
tx.write_all(&[TLS_RECORD_APPLICATION, TLS_VERSION[0], TLS_VERSION[1], 0x00, 0x00])
.await
.unwrap();
tx.write_all(&[
TLS_RECORD_APPLICATION,
TLS_VERSION[0],
TLS_VERSION[1],
0x00,
0x00,
])
.await
.unwrap();
tx.shutdown().await.unwrap();
let mut reader = FakeTlsReader::new(rx);
@@ -472,9 +704,16 @@ async fn reader_rejects_zero_len_appdata_record() {
#[tokio::test]
async fn reader_accepts_single_byte_tls13_appdata_and_yields_payload() {
let (mut tx, rx) = tokio::io::duplex(128);
tx.write_all(&[TLS_RECORD_APPLICATION, TLS_VERSION[0], TLS_VERSION[1], 0x00, 0x01, 0x5A])
.await
.unwrap();
tx.write_all(&[
TLS_RECORD_APPLICATION,
TLS_VERSION[0],
TLS_VERSION[1],
0x00,
0x01,
0x5A,
])
.await
.unwrap();
tx.shutdown().await.unwrap();
let mut reader = FakeTlsReader::new(rx);

View File

@@ -23,12 +23,12 @@ impl FrameMeta {
pub fn new() -> Self {
Self::default()
}
pub fn with_quickack(mut self) -> Self {
self.quickack = true;
self
}
pub fn with_simple_ack(mut self) -> Self {
self.simple_ack = true;
self
@@ -48,10 +48,10 @@ pub enum ReadFrameResult {
pub trait LayeredStream<U> {
/// Get reference to upstream
fn upstream(&self) -> &U;
/// Get mutable reference to upstream
fn upstream_mut(&mut self) -> &mut U;
/// Consume self and return upstream
fn into_upstream(self) -> U;
}
@@ -65,7 +65,7 @@ impl<R> ReadHalf<R> {
pub fn new(inner: R) -> Self {
Self { inner }
}
pub fn into_inner(self) -> R {
self.inner
}
@@ -90,7 +90,7 @@ impl<W> WriteHalf<W> {
pub fn new(inner: W) -> Self {
Self { inner }
}
pub fn into_inner(self) -> W {
self.inner
}
@@ -104,12 +104,12 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for WriteHalf<W> {
) -> Poll<Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}
}
}