mirror of
https://github.com/telemt/telemt.git
synced 2026-04-25 14:34:10 +03:00
Format
This commit is contained in:
@@ -8,8 +8,8 @@
|
||||
use bytes::BytesMut;
|
||||
use crossbeam_queue::ArrayQueue;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
// ============= Configuration =============
|
||||
|
||||
@@ -42,7 +42,7 @@ impl BufferPool {
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(DEFAULT_BUFFER_SIZE, DEFAULT_MAX_BUFFERS)
|
||||
}
|
||||
|
||||
|
||||
/// Create a buffer pool with custom configuration
|
||||
pub fn with_config(buffer_size: usize, max_buffers: usize) -> Self {
|
||||
Self {
|
||||
@@ -54,7 +54,7 @@ impl BufferPool {
|
||||
hits: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Get a buffer from the pool, or create a new one if empty
|
||||
pub fn get(self: &Arc<Self>) -> PooledBuffer {
|
||||
match self.buffers.pop() {
|
||||
@@ -76,7 +76,7 @@ impl BufferPool {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Try to get a buffer, returns None if pool is empty
|
||||
pub fn try_get(self: &Arc<Self>) -> Option<PooledBuffer> {
|
||||
self.buffers.pop().map(|mut buffer| {
|
||||
@@ -88,12 +88,12 @@ impl BufferPool {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
/// Return a buffer to the pool
|
||||
fn return_buffer(&self, mut buffer: BytesMut) {
|
||||
// Clear the buffer but keep capacity
|
||||
buffer.clear();
|
||||
|
||||
|
||||
// Only return if we haven't exceeded max and buffer is right size
|
||||
if buffer.capacity() >= self.buffer_size {
|
||||
// Try to push to pool, if full just drop
|
||||
@@ -103,7 +103,7 @@ impl BufferPool {
|
||||
// Actually we don't decrement here because the buffer might have been
|
||||
// grown beyond our size - we just let it go
|
||||
}
|
||||
|
||||
|
||||
/// Get pool statistics
|
||||
pub fn stats(&self) -> PoolStats {
|
||||
PoolStats {
|
||||
@@ -115,17 +115,21 @@ impl BufferPool {
|
||||
misses: self.misses.load(Ordering::Relaxed),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Get buffer size
|
||||
pub fn buffer_size(&self) -> usize {
|
||||
self.buffer_size
|
||||
}
|
||||
|
||||
|
||||
/// Preallocate buffers to fill the pool
|
||||
pub fn preallocate(&self, count: usize) {
|
||||
let to_alloc = count.min(self.max_buffers);
|
||||
for _ in 0..to_alloc {
|
||||
if self.buffers.push(BytesMut::with_capacity(self.buffer_size)).is_err() {
|
||||
if self
|
||||
.buffers
|
||||
.push(BytesMut::with_capacity(self.buffer_size))
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
self.allocated.fetch_add(1, Ordering::Relaxed);
|
||||
@@ -183,22 +187,22 @@ impl PooledBuffer {
|
||||
pub fn take(mut self) -> BytesMut {
|
||||
self.buffer.take().unwrap()
|
||||
}
|
||||
|
||||
|
||||
/// Get the capacity of the buffer
|
||||
pub fn capacity(&self) -> usize {
|
||||
self.buffer.as_ref().map(|b| b.capacity()).unwrap_or(0)
|
||||
}
|
||||
|
||||
|
||||
/// Check if buffer is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.buffer.as_ref().map(|b| b.is_empty()).unwrap_or(true)
|
||||
}
|
||||
|
||||
|
||||
/// Get the length of data in buffer
|
||||
pub fn len(&self) -> usize {
|
||||
self.buffer.as_ref().map(|b| b.len()).unwrap_or(0)
|
||||
}
|
||||
|
||||
|
||||
/// Clear the buffer
|
||||
pub fn clear(&mut self) {
|
||||
if let Some(ref mut b) = self.buffer {
|
||||
@@ -209,7 +213,7 @@ impl PooledBuffer {
|
||||
|
||||
impl Deref for PooledBuffer {
|
||||
type Target = BytesMut;
|
||||
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.buffer.as_ref().expect("buffer taken")
|
||||
}
|
||||
@@ -259,7 +263,7 @@ impl<'a> ScopedBuffer<'a> {
|
||||
|
||||
impl<'a> Deref for ScopedBuffer<'a> {
|
||||
type Target = BytesMut;
|
||||
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.buffer.deref()
|
||||
}
|
||||
@@ -280,108 +284,108 @@ impl<'a> Drop for ScopedBuffer<'a> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_pool_basic() {
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||
|
||||
|
||||
// Get a buffer
|
||||
let mut buf1 = pool.get();
|
||||
buf1.extend_from_slice(b"hello");
|
||||
assert_eq!(&buf1[..], b"hello");
|
||||
|
||||
|
||||
// Drop returns to pool
|
||||
drop(buf1);
|
||||
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.pooled, 1);
|
||||
assert_eq!(stats.hits, 0);
|
||||
assert_eq!(stats.misses, 1);
|
||||
|
||||
|
||||
// Get again - should reuse
|
||||
let buf2 = pool.get();
|
||||
assert!(buf2.is_empty()); // Buffer was cleared
|
||||
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.pooled, 0);
|
||||
assert_eq!(stats.hits, 1);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_pool_multiple_buffers() {
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||
|
||||
|
||||
// Get multiple buffers
|
||||
let buf1 = pool.get();
|
||||
let buf2 = pool.get();
|
||||
let buf3 = pool.get();
|
||||
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.allocated, 3);
|
||||
assert_eq!(stats.pooled, 0);
|
||||
|
||||
|
||||
// Return all
|
||||
drop(buf1);
|
||||
drop(buf2);
|
||||
drop(buf3);
|
||||
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.pooled, 3);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_pool_overflow() {
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 2));
|
||||
|
||||
|
||||
// Get 3 buffers (more than max)
|
||||
let buf1 = pool.get();
|
||||
let buf2 = pool.get();
|
||||
let buf3 = pool.get();
|
||||
|
||||
|
||||
// Return all - only 2 should be pooled
|
||||
drop(buf1);
|
||||
drop(buf2);
|
||||
drop(buf3);
|
||||
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.pooled, 2);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_pool_take() {
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||
|
||||
|
||||
let mut buf = pool.get();
|
||||
buf.extend_from_slice(b"data");
|
||||
|
||||
|
||||
// Take ownership, buffer should not return to pool
|
||||
let taken = buf.take();
|
||||
assert_eq!(&taken[..], b"data");
|
||||
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.pooled, 0);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_pool_preallocate() {
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||
pool.preallocate(5);
|
||||
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.pooled, 5);
|
||||
assert_eq!(stats.allocated, 5);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_pool_try_get() {
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||
|
||||
|
||||
// Pool is empty, try_get returns None
|
||||
assert!(pool.try_get().is_none());
|
||||
|
||||
|
||||
// Add a buffer to pool
|
||||
pool.preallocate(1);
|
||||
|
||||
|
||||
// Now try_get should succeed once while the buffer is held
|
||||
let buf = pool.try_get();
|
||||
assert!(buf.is_some());
|
||||
@@ -391,50 +395,50 @@ mod tests {
|
||||
drop(buf);
|
||||
assert!(pool.try_get().is_some());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_hit_rate() {
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||
|
||||
|
||||
// First get is a miss
|
||||
let buf1 = pool.get();
|
||||
drop(buf1);
|
||||
|
||||
|
||||
// Second get is a hit
|
||||
let buf2 = pool.get();
|
||||
drop(buf2);
|
||||
|
||||
|
||||
// Third get is a hit
|
||||
let _buf3 = pool.get();
|
||||
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.hits, 2);
|
||||
assert_eq!(stats.misses, 1);
|
||||
assert!((stats.hit_rate() - 66.67).abs() < 1.0);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_scoped_buffer() {
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||
let mut buf = pool.get();
|
||||
|
||||
|
||||
{
|
||||
let mut scoped = ScopedBuffer::new(&mut buf);
|
||||
scoped.extend_from_slice(b"scoped data");
|
||||
assert_eq!(&scoped[..], b"scoped data");
|
||||
}
|
||||
|
||||
|
||||
// After scoped is dropped, buffer is cleared
|
||||
assert!(buf.is_empty());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_concurrent_access() {
|
||||
use std::thread;
|
||||
|
||||
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 100));
|
||||
let mut handles = vec![];
|
||||
|
||||
|
||||
for _ in 0..10 {
|
||||
let pool_clone = Arc::clone(&pool);
|
||||
handles.push(thread::spawn(move || {
|
||||
@@ -445,11 +449,11 @@ mod tests {
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
|
||||
for handle in handles {
|
||||
handle.join().unwrap();
|
||||
}
|
||||
|
||||
|
||||
let stats = pool.stats();
|
||||
// All buffers should be returned
|
||||
assert!(stats.pooled > 0);
|
||||
|
||||
@@ -37,7 +37,7 @@
|
||||
//!
|
||||
//! Backpressure
|
||||
//! - pending ciphertext buffer is bounded (configurable per connection)
|
||||
//! - pending is full and upstream is pending
|
||||
//! - pending is full and upstream is pending
|
||||
//! -> poll_write returns Poll::Pending
|
||||
//! -> do not accept any plaintext
|
||||
//!
|
||||
@@ -59,8 +59,8 @@ use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tracing::{debug, trace};
|
||||
|
||||
use crate::crypto::AesCtr;
|
||||
use super::state::{StreamState, YieldBuffer};
|
||||
use crate::crypto::AesCtr;
|
||||
|
||||
// ============= Constants =============
|
||||
|
||||
@@ -152,9 +152,9 @@ impl<R> CryptoReader<R> {
|
||||
|
||||
fn take_poison_error(&mut self) -> io::Error {
|
||||
match &mut self.state {
|
||||
CryptoReaderState::Poisoned { error } => error.take().unwrap_or_else(|| {
|
||||
io::Error::other("stream previously poisoned")
|
||||
}),
|
||||
CryptoReaderState::Poisoned { error } => error
|
||||
.take()
|
||||
.unwrap_or_else(|| io::Error::other("stream previously poisoned")),
|
||||
_ => io::Error::other("stream not poisoned"),
|
||||
}
|
||||
}
|
||||
@@ -221,7 +221,11 @@ impl<R: AsyncRead + Unpin> AsyncRead for CryptoReader<R> {
|
||||
let filled = buf.filled_mut();
|
||||
this.decryptor.apply(&mut filled[before..after]);
|
||||
|
||||
trace!(bytes_read, state = this.state_name(), "CryptoReader decrypted chunk");
|
||||
trace!(
|
||||
bytes_read,
|
||||
state = this.state_name(),
|
||||
"CryptoReader decrypted chunk"
|
||||
);
|
||||
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
@@ -503,9 +507,9 @@ impl<W> CryptoWriter<W> {
|
||||
|
||||
fn take_poison_error(&mut self) -> io::Error {
|
||||
match &mut self.state {
|
||||
CryptoWriterState::Poisoned { error } => error.take().unwrap_or_else(|| {
|
||||
io::Error::other("stream previously poisoned")
|
||||
}),
|
||||
CryptoWriterState::Poisoned { error } => error
|
||||
.take()
|
||||
.unwrap_or_else(|| io::Error::other("stream previously poisoned")),
|
||||
_ => io::Error::other("stream not poisoned"),
|
||||
}
|
||||
}
|
||||
@@ -525,7 +529,11 @@ impl<W> CryptoWriter<W> {
|
||||
}
|
||||
|
||||
/// Select how many plaintext bytes can be accepted in buffering path
|
||||
fn select_to_accept_for_buffering(state: &CryptoWriterState, buf_len: usize, max_pending: usize) -> usize {
|
||||
fn select_to_accept_for_buffering(
|
||||
state: &CryptoWriterState,
|
||||
buf_len: usize,
|
||||
max_pending: usize,
|
||||
) -> usize {
|
||||
if buf_len == 0 {
|
||||
return 0;
|
||||
}
|
||||
@@ -602,11 +610,7 @@ impl<W: AsyncWrite + Unpin> CryptoWriter<W> {
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
// Poisoned?
|
||||
@@ -629,8 +633,11 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Pending => {
|
||||
// Upstream blocked. Apply ideal backpressure
|
||||
let to_accept =
|
||||
Self::select_to_accept_for_buffering(&this.state, buf.len(), this.max_pending_write);
|
||||
let to_accept = Self::select_to_accept_for_buffering(
|
||||
&this.state,
|
||||
buf.len(),
|
||||
this.max_pending_write,
|
||||
);
|
||||
|
||||
if to_accept == 0 {
|
||||
trace!(
|
||||
|
||||
@@ -9,8 +9,8 @@ use bytes::{Bytes, BytesMut};
|
||||
use std::io::Result;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::protocol::constants::ProtoTag;
|
||||
use crate::crypto::SecureRandom;
|
||||
use crate::protocol::constants::ProtoTag;
|
||||
|
||||
// ============= Frame Types =============
|
||||
|
||||
@@ -31,27 +31,27 @@ impl Frame {
|
||||
meta: FrameMeta::default(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Create a new frame with data and metadata
|
||||
pub fn with_meta(data: Bytes, meta: FrameMeta) -> Self {
|
||||
Self { data, meta }
|
||||
}
|
||||
|
||||
|
||||
/// Create an empty frame
|
||||
pub fn empty() -> Self {
|
||||
Self::new(Bytes::new())
|
||||
}
|
||||
|
||||
|
||||
/// Check if frame is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.data.is_empty()
|
||||
}
|
||||
|
||||
|
||||
/// Get frame length
|
||||
pub fn len(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
|
||||
|
||||
/// Create a QuickAck request frame
|
||||
pub fn quickack(data: Bytes) -> Self {
|
||||
Self {
|
||||
@@ -62,7 +62,7 @@ impl Frame {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Create a simple ACK frame
|
||||
pub fn simple_ack(data: Bytes) -> Self {
|
||||
Self {
|
||||
@@ -91,25 +91,25 @@ impl FrameMeta {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
|
||||
/// Create with quickack flag
|
||||
pub fn with_quickack(mut self) -> Self {
|
||||
self.quickack = true;
|
||||
self
|
||||
}
|
||||
|
||||
|
||||
/// Create with simple_ack flag
|
||||
pub fn with_simple_ack(mut self) -> Self {
|
||||
self.simple_ack = true;
|
||||
self
|
||||
}
|
||||
|
||||
|
||||
/// Create with padding length
|
||||
pub fn with_padding(mut self, len: u8) -> Self {
|
||||
self.padding_len = len;
|
||||
self
|
||||
}
|
||||
|
||||
|
||||
/// Check if any special flags are set
|
||||
pub fn has_flags(&self) -> bool {
|
||||
self.quickack || self.simple_ack
|
||||
@@ -122,12 +122,12 @@ impl FrameMeta {
|
||||
pub trait FrameCodec: Send + Sync {
|
||||
/// Get the protocol tag for this codec
|
||||
fn proto_tag(&self) -> ProtoTag;
|
||||
|
||||
|
||||
/// Encode a frame into the destination buffer
|
||||
///
|
||||
/// Returns the number of bytes written.
|
||||
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> Result<usize>;
|
||||
|
||||
|
||||
/// Try to decode a frame from the source buffer
|
||||
///
|
||||
/// Returns:
|
||||
@@ -137,10 +137,10 @@ pub trait FrameCodec: Send + Sync {
|
||||
///
|
||||
/// On success, the consumed bytes are removed from `src`.
|
||||
fn decode(&self, src: &mut BytesMut) -> Result<Option<Frame>>;
|
||||
|
||||
|
||||
/// Get the minimum bytes needed to determine frame length
|
||||
fn min_header_size(&self) -> usize;
|
||||
|
||||
|
||||
/// Get the maximum allowed frame size
|
||||
fn max_frame_size(&self) -> usize {
|
||||
// Default: 16MB
|
||||
@@ -162,30 +162,28 @@ pub fn create_codec(proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Box<dyn Fram
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_frame_creation() {
|
||||
let frame = Frame::new(Bytes::from_static(b"test"));
|
||||
assert_eq!(frame.len(), 4);
|
||||
assert!(!frame.is_empty());
|
||||
assert!(!frame.meta.quickack);
|
||||
|
||||
|
||||
let frame = Frame::empty();
|
||||
assert!(frame.is_empty());
|
||||
|
||||
|
||||
let frame = Frame::quickack(Bytes::from_static(b"ack"));
|
||||
assert!(frame.meta.quickack);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_frame_meta() {
|
||||
let meta = FrameMeta::new()
|
||||
.with_quickack()
|
||||
.with_padding(3);
|
||||
|
||||
let meta = FrameMeta::new().with_quickack().with_padding(3);
|
||||
|
||||
assert!(meta.quickack);
|
||||
assert!(!meta.simple_ack);
|
||||
assert_eq!(meta.padding_len, 3);
|
||||
assert!(meta.has_flags());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,16 +5,16 @@
|
||||
|
||||
#![allow(dead_code)]
|
||||
|
||||
use bytes::{Bytes, BytesMut, BufMut};
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use std::io::{self, Error, ErrorKind};
|
||||
use std::sync::Arc;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
|
||||
use super::frame::{Frame, FrameCodec as FrameCodecTrait, FrameMeta};
|
||||
use crate::crypto::SecureRandom;
|
||||
use crate::protocol::constants::{
|
||||
ProtoTag, is_valid_secure_payload_len, secure_padding_len, secure_payload_len_from_wire_len,
|
||||
};
|
||||
use crate::crypto::SecureRandom;
|
||||
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
|
||||
|
||||
// ============= Unified Codec =============
|
||||
|
||||
@@ -40,13 +40,13 @@ impl FrameCodec {
|
||||
rng,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Set maximum frame size
|
||||
pub fn with_max_frame_size(mut self, size: usize) -> Self {
|
||||
self.max_frame_size = size;
|
||||
self
|
||||
}
|
||||
|
||||
|
||||
/// Get protocol tag
|
||||
pub fn proto_tag(&self) -> ProtoTag {
|
||||
self.proto_tag
|
||||
@@ -56,7 +56,7 @@ impl FrameCodec {
|
||||
impl Decoder for FrameCodec {
|
||||
type Item = Frame;
|
||||
type Error = io::Error;
|
||||
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
match self.proto_tag {
|
||||
ProtoTag::Abridged => decode_abridged(src, self.max_frame_size),
|
||||
@@ -68,7 +68,7 @@ impl Decoder for FrameCodec {
|
||||
|
||||
impl Encoder<Frame> for FrameCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
|
||||
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
match self.proto_tag {
|
||||
ProtoTag::Abridged => encode_abridged(&frame, dst),
|
||||
@@ -84,18 +84,18 @@ fn decode_abridged(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Fra
|
||||
if src.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
|
||||
let mut meta = FrameMeta::new();
|
||||
let first_byte = src[0];
|
||||
|
||||
|
||||
// Extract length and quickack flag
|
||||
let mut len_words = (first_byte & 0x7f) as usize;
|
||||
if first_byte >= 0x80 {
|
||||
meta.quickack = true;
|
||||
}
|
||||
|
||||
|
||||
let header_len;
|
||||
|
||||
|
||||
if len_words == 0x7f {
|
||||
// Extended length (3 more bytes needed)
|
||||
if src.len() < 4 {
|
||||
@@ -106,46 +106,49 @@ fn decode_abridged(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Fra
|
||||
} else {
|
||||
header_len = 1;
|
||||
}
|
||||
|
||||
|
||||
// Length is in 4-byte words
|
||||
let byte_len = len_words.checked_mul(4).ok_or_else(|| {
|
||||
Error::new(ErrorKind::InvalidData, "frame length overflow")
|
||||
})?;
|
||||
|
||||
let byte_len = len_words
|
||||
.checked_mul(4)
|
||||
.ok_or_else(|| Error::new(ErrorKind::InvalidData, "frame length overflow"))?;
|
||||
|
||||
// Validate size
|
||||
if byte_len > max_size {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("frame too large: {} bytes (max {})", byte_len, max_size)
|
||||
format!("frame too large: {} bytes (max {})", byte_len, max_size),
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
let total_len = header_len + byte_len;
|
||||
|
||||
|
||||
if src.len() < total_len {
|
||||
// Reserve space for the rest of the frame
|
||||
src.reserve(total_len - src.len());
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
|
||||
// Extract data
|
||||
let _ = src.split_to(header_len);
|
||||
let data = src.split_to(byte_len).freeze();
|
||||
|
||||
|
||||
Ok(Some(Frame::with_meta(data, meta)))
|
||||
}
|
||||
|
||||
fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
|
||||
let data = &frame.data;
|
||||
|
||||
|
||||
// Validate alignment
|
||||
if !data.len().is_multiple_of(4) {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("abridged frame must be 4-byte aligned, got {} bytes", data.len())
|
||||
format!(
|
||||
"abridged frame must be 4-byte aligned, got {} bytes",
|
||||
data.len()
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
// Simple ACK: send reversed data without header
|
||||
if frame.meta.simple_ack {
|
||||
dst.reserve(data.len());
|
||||
@@ -154,9 +157,9 @@ fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
||||
let len_words = data.len() / 4;
|
||||
|
||||
|
||||
if len_words < 0x7f {
|
||||
// Short header
|
||||
dst.reserve(1 + data.len());
|
||||
@@ -178,10 +181,10 @@ fn encode_abridged(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
|
||||
} else {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("frame too large: {} bytes", data.len())
|
||||
format!("frame too large: {} bytes", data.len()),
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
dst.extend_from_slice(data);
|
||||
Ok(())
|
||||
}
|
||||
@@ -192,58 +195,58 @@ fn decode_intermediate(src: &mut BytesMut, max_size: usize) -> io::Result<Option
|
||||
if src.len() < 4 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
|
||||
let mut meta = FrameMeta::new();
|
||||
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
|
||||
|
||||
|
||||
// Check QuickACK flag
|
||||
if len >= 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
|
||||
|
||||
// Validate size
|
||||
if len > max_size {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("frame too large: {} bytes (max {})", len, max_size)
|
||||
format!("frame too large: {} bytes (max {})", len, max_size),
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
let total_len = 4 + len;
|
||||
|
||||
|
||||
if src.len() < total_len {
|
||||
src.reserve(total_len - src.len());
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
|
||||
// Extract data
|
||||
let _ = src.split_to(4);
|
||||
let data = src.split_to(len).freeze();
|
||||
|
||||
|
||||
Ok(Some(Frame::with_meta(data, meta)))
|
||||
}
|
||||
|
||||
fn encode_intermediate(frame: &Frame, dst: &mut BytesMut) -> io::Result<()> {
|
||||
let data = &frame.data;
|
||||
|
||||
|
||||
// Simple ACK: just send data
|
||||
if frame.meta.simple_ack {
|
||||
dst.reserve(data.len());
|
||||
dst.extend_from_slice(data);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
||||
dst.reserve(4 + data.len());
|
||||
|
||||
|
||||
let mut len = data.len() as u32;
|
||||
if frame.meta.quickack {
|
||||
len |= 0x80000000;
|
||||
}
|
||||
|
||||
|
||||
dst.extend_from_slice(&len.to_le_bytes());
|
||||
dst.extend_from_slice(data);
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -253,31 +256,31 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame
|
||||
if src.len() < 4 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
|
||||
let mut meta = FrameMeta::new();
|
||||
let mut len = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
|
||||
|
||||
|
||||
// Check QuickACK flag
|
||||
if len >= 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
|
||||
|
||||
// Validate size
|
||||
if len > max_size {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("frame too large: {} bytes (max {})", len, max_size)
|
||||
format!("frame too large: {} bytes (max {})", len, max_size),
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
let total_len = 4 + len;
|
||||
|
||||
|
||||
if src.len() < total_len {
|
||||
src.reserve(total_len - src.len());
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
|
||||
let data_len = secure_payload_len_from_wire_len(len).ok_or_else(|| {
|
||||
Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
@@ -285,28 +288,28 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame
|
||||
)
|
||||
})?;
|
||||
let padding_len = len - data_len;
|
||||
|
||||
|
||||
meta.padding_len = padding_len as u8;
|
||||
|
||||
|
||||
// Extract data (excluding padding)
|
||||
let _ = src.split_to(4);
|
||||
let all_data = src.split_to(len);
|
||||
// Copy only the data portion, excluding padding
|
||||
let data = Bytes::copy_from_slice(&all_data[..data_len]);
|
||||
|
||||
|
||||
Ok(Some(Frame::with_meta(data, meta)))
|
||||
}
|
||||
|
||||
fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::Result<()> {
|
||||
let data = &frame.data;
|
||||
|
||||
|
||||
// Simple ACK: just send data
|
||||
if frame.meta.simple_ack {
|
||||
dst.reserve(data.len());
|
||||
dst.extend_from_slice(data);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
||||
if !is_valid_secure_payload_len(data.len()) {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
@@ -316,23 +319,23 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::R
|
||||
|
||||
// Generate padding that keeps total length non-divisible by 4.
|
||||
let padding_len = secure_padding_len(data.len(), rng);
|
||||
|
||||
|
||||
let total_len = data.len() + padding_len;
|
||||
dst.reserve(4 + total_len);
|
||||
|
||||
|
||||
let mut len = total_len as u32;
|
||||
if frame.meta.quickack {
|
||||
len |= 0x80000000;
|
||||
}
|
||||
|
||||
|
||||
dst.extend_from_slice(&len.to_le_bytes());
|
||||
dst.extend_from_slice(data);
|
||||
|
||||
|
||||
if padding_len > 0 {
|
||||
let padding = rng.bytes(padding_len);
|
||||
dst.extend_from_slice(&padding);
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -360,7 +363,7 @@ impl Default for AbridgedCodec {
|
||||
impl Decoder for AbridgedCodec {
|
||||
type Item = Frame;
|
||||
type Error = io::Error;
|
||||
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
decode_abridged(src, self.max_frame_size)
|
||||
}
|
||||
@@ -368,7 +371,7 @@ impl Decoder for AbridgedCodec {
|
||||
|
||||
impl Encoder<Frame> for AbridgedCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
|
||||
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
encode_abridged(&frame, dst)
|
||||
}
|
||||
@@ -378,17 +381,17 @@ impl FrameCodecTrait for AbridgedCodec {
|
||||
fn proto_tag(&self) -> ProtoTag {
|
||||
ProtoTag::Abridged
|
||||
}
|
||||
|
||||
|
||||
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
|
||||
let before = dst.len();
|
||||
encode_abridged(frame, dst)?;
|
||||
Ok(dst.len() - before)
|
||||
}
|
||||
|
||||
|
||||
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
|
||||
decode_abridged(src, self.max_frame_size)
|
||||
}
|
||||
|
||||
|
||||
fn min_header_size(&self) -> usize {
|
||||
1
|
||||
}
|
||||
@@ -416,7 +419,7 @@ impl Default for IntermediateCodec {
|
||||
impl Decoder for IntermediateCodec {
|
||||
type Item = Frame;
|
||||
type Error = io::Error;
|
||||
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
decode_intermediate(src, self.max_frame_size)
|
||||
}
|
||||
@@ -424,7 +427,7 @@ impl Decoder for IntermediateCodec {
|
||||
|
||||
impl Encoder<Frame> for IntermediateCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
|
||||
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
encode_intermediate(&frame, dst)
|
||||
}
|
||||
@@ -434,17 +437,17 @@ impl FrameCodecTrait for IntermediateCodec {
|
||||
fn proto_tag(&self) -> ProtoTag {
|
||||
ProtoTag::Intermediate
|
||||
}
|
||||
|
||||
|
||||
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
|
||||
let before = dst.len();
|
||||
encode_intermediate(frame, dst)?;
|
||||
Ok(dst.len() - before)
|
||||
}
|
||||
|
||||
|
||||
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
|
||||
decode_intermediate(src, self.max_frame_size)
|
||||
}
|
||||
|
||||
|
||||
fn min_header_size(&self) -> usize {
|
||||
4
|
||||
}
|
||||
@@ -474,7 +477,7 @@ impl Default for SecureCodec {
|
||||
impl Decoder for SecureCodec {
|
||||
type Item = Frame;
|
||||
type Error = io::Error;
|
||||
|
||||
|
||||
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
|
||||
decode_secure(src, self.max_frame_size)
|
||||
}
|
||||
@@ -482,7 +485,7 @@ impl Decoder for SecureCodec {
|
||||
|
||||
impl Encoder<Frame> for SecureCodec {
|
||||
type Error = io::Error;
|
||||
|
||||
|
||||
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
encode_secure(&frame, dst, &self.rng)
|
||||
}
|
||||
@@ -492,17 +495,17 @@ impl FrameCodecTrait for SecureCodec {
|
||||
fn proto_tag(&self) -> ProtoTag {
|
||||
ProtoTag::Secure
|
||||
}
|
||||
|
||||
|
||||
fn encode(&self, frame: &Frame, dst: &mut BytesMut) -> io::Result<usize> {
|
||||
let before = dst.len();
|
||||
encode_secure(frame, dst, &self.rng)?;
|
||||
Ok(dst.len() - before)
|
||||
}
|
||||
|
||||
|
||||
fn decode(&self, src: &mut BytesMut) -> io::Result<Option<Frame>> {
|
||||
decode_secure(src, self.max_frame_size)
|
||||
}
|
||||
|
||||
|
||||
fn min_header_size(&self) -> usize {
|
||||
4
|
||||
}
|
||||
@@ -513,121 +516,127 @@ impl FrameCodecTrait for SecureCodec {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashSet;
|
||||
use tokio_util::codec::{FramedRead, FramedWrite};
|
||||
use tokio::io::duplex;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use crate::crypto::SecureRandom;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::io::duplex;
|
||||
use tokio_util::codec::{FramedRead, FramedWrite};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_framed_abridged() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
|
||||
let mut writer = FramedWrite::new(client, AbridgedCodec::new());
|
||||
let mut reader = FramedRead::new(server, AbridgedCodec::new());
|
||||
|
||||
|
||||
// Write a frame
|
||||
let frame = Frame::new(Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]));
|
||||
writer.send(frame).await.unwrap();
|
||||
|
||||
|
||||
// Read it back
|
||||
let received = reader.next().await.unwrap().unwrap();
|
||||
assert_eq!(&received.data[..], &[1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_framed_intermediate() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
|
||||
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
|
||||
let mut reader = FramedRead::new(server, IntermediateCodec::new());
|
||||
|
||||
|
||||
let frame = Frame::new(Bytes::from_static(b"hello world"));
|
||||
writer.send(frame).await.unwrap();
|
||||
|
||||
|
||||
let received = reader.next().await.unwrap().unwrap();
|
||||
assert_eq!(&received.data[..], b"hello world");
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_framed_secure() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
|
||||
let mut writer = FramedWrite::new(client, SecureCodec::new(Arc::new(SecureRandom::new())));
|
||||
let mut reader = FramedRead::new(server, SecureCodec::new(Arc::new(SecureRandom::new())));
|
||||
|
||||
|
||||
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
let frame = Frame::new(original.clone());
|
||||
writer.send(frame).await.unwrap();
|
||||
|
||||
|
||||
let received = reader.next().await.unwrap().unwrap();
|
||||
assert_eq!(&received.data[..], &original[..]);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_unified_codec() {
|
||||
for proto_tag in [ProtoTag::Abridged, ProtoTag::Intermediate, ProtoTag::Secure] {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
let mut writer = FramedWrite::new(client, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
|
||||
let mut reader = FramedRead::new(server, FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())));
|
||||
|
||||
|
||||
let mut writer = FramedWrite::new(
|
||||
client,
|
||||
FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())),
|
||||
);
|
||||
let mut reader = FramedRead::new(
|
||||
server,
|
||||
FrameCodec::new(proto_tag, Arc::new(SecureRandom::new())),
|
||||
);
|
||||
|
||||
// Use 4-byte aligned data for abridged compatibility
|
||||
let original = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
let frame = Frame::new(original.clone());
|
||||
writer.send(frame).await.unwrap();
|
||||
|
||||
|
||||
let received = reader.next().await.unwrap().unwrap();
|
||||
assert_eq!(received.data.len(), 8);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_frames() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
|
||||
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
|
||||
let mut reader = FramedRead::new(server, IntermediateCodec::new());
|
||||
|
||||
|
||||
// Send multiple frames
|
||||
for i in 0..10 {
|
||||
let data: Vec<u8> = (0..((i + 1) * 10)).map(|j| (j % 256) as u8).collect();
|
||||
let frame = Frame::new(Bytes::from(data));
|
||||
writer.send(frame).await.unwrap();
|
||||
}
|
||||
|
||||
|
||||
// Receive them
|
||||
for i in 0..10 {
|
||||
let received = reader.next().await.unwrap().unwrap();
|
||||
assert_eq!(received.data.len(), (i + 1) * 10);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_quickack_flag() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
|
||||
let mut writer = FramedWrite::new(client, IntermediateCodec::new());
|
||||
let mut reader = FramedRead::new(server, IntermediateCodec::new());
|
||||
|
||||
|
||||
let frame = Frame::quickack(Bytes::from_static(b"urgent"));
|
||||
writer.send(frame).await.unwrap();
|
||||
|
||||
|
||||
let received = reader.next().await.unwrap().unwrap();
|
||||
assert!(received.meta.quickack);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_frame_too_large() {
|
||||
let mut codec = FrameCodec::new(ProtoTag::Intermediate, Arc::new(SecureRandom::new()))
|
||||
.with_max_frame_size(100);
|
||||
|
||||
|
||||
// Create a "frame" that claims to be very large
|
||||
let mut buf = BytesMut::new();
|
||||
buf.extend_from_slice(&1000u32.to_le_bytes()); // length = 1000
|
||||
buf.extend_from_slice(&[0u8; 10]); // partial data
|
||||
|
||||
|
||||
let result = codec.decode(&mut buf);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
@@ -2,13 +2,13 @@
|
||||
|
||||
#![allow(dead_code)]
|
||||
|
||||
use super::traits::{FrameMeta, LayeredStream};
|
||||
use crate::crypto::{SecureRandom, crc32};
|
||||
use crate::protocol::constants::*;
|
||||
use bytes::Bytes;
|
||||
use std::io::{Error, ErrorKind, Result};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
||||
use crate::protocol::constants::*;
|
||||
use crate::crypto::{crc32, SecureRandom};
|
||||
use std::sync::Arc;
|
||||
use super::traits::{FrameMeta, LayeredStream};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
// ============= Abridged (Compact) Frame =============
|
||||
|
||||
@@ -27,41 +27,47 @@ impl<R: AsyncRead + Unpin> AbridgedFrameReader<R> {
|
||||
/// Read a frame and return (data, metadata)
|
||||
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
|
||||
let mut meta = FrameMeta::new();
|
||||
|
||||
|
||||
// Read length byte
|
||||
let mut len_byte = [0u8];
|
||||
self.upstream.read_exact(&mut len_byte).await?;
|
||||
|
||||
|
||||
let mut len = len_byte[0] as usize;
|
||||
|
||||
|
||||
// Check QuickACK flag (high bit)
|
||||
if len >= 0x80 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80;
|
||||
}
|
||||
|
||||
|
||||
// Extended length (3 bytes)
|
||||
if len == 0x7f {
|
||||
let mut len_bytes = [0u8; 3];
|
||||
self.upstream.read_exact(&mut len_bytes).await?;
|
||||
len = u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], 0]) as usize;
|
||||
}
|
||||
|
||||
|
||||
// Length is in 4-byte words
|
||||
let byte_len = len * 4;
|
||||
|
||||
|
||||
// Read data
|
||||
let mut data = vec![0u8; byte_len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
|
||||
Ok((Bytes::from(data), meta))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> LayeredStream<R> for AbridgedFrameReader<R> {
|
||||
fn upstream(&self) -> &R { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
|
||||
fn into_upstream(self) -> R { self.upstream }
|
||||
fn upstream(&self) -> &R {
|
||||
&self.upstream
|
||||
}
|
||||
fn upstream_mut(&mut self) -> &mut R {
|
||||
&mut self.upstream
|
||||
}
|
||||
fn into_upstream(self) -> R {
|
||||
self.upstream
|
||||
}
|
||||
}
|
||||
|
||||
/// Writer for abridged MTProto framing
|
||||
@@ -81,19 +87,22 @@ impl<W: AsyncWrite + Unpin> AbridgedFrameWriter<W> {
|
||||
if !data.len().is_multiple_of(4) {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("Abridged frame must be aligned to 4 bytes, got {}", data.len()),
|
||||
format!(
|
||||
"Abridged frame must be aligned to 4 bytes, got {}",
|
||||
data.len()
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
// Simple ACK: send reversed data
|
||||
if meta.simple_ack {
|
||||
let reversed: Vec<u8> = data.iter().rev().copied().collect();
|
||||
self.upstream.write_all(&reversed).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
||||
let len_div_4 = data.len() / 4;
|
||||
|
||||
|
||||
if len_div_4 < 0x7f {
|
||||
// Short length (1 byte)
|
||||
self.upstream.write_all(&[len_div_4 as u8]).await?;
|
||||
@@ -108,20 +117,26 @@ impl<W: AsyncWrite + Unpin> AbridgedFrameWriter<W> {
|
||||
format!("Frame too large: {} bytes", data.len()),
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
self.upstream.write_all(data).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
pub async fn flush(&mut self) -> Result<()> {
|
||||
self.upstream.flush().await
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> LayeredStream<W> for AbridgedFrameWriter<W> {
|
||||
fn upstream(&self) -> &W { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
|
||||
fn into_upstream(self) -> W { self.upstream }
|
||||
fn upstream(&self) -> &W {
|
||||
&self.upstream
|
||||
}
|
||||
fn upstream_mut(&mut self) -> &mut W {
|
||||
&mut self.upstream
|
||||
}
|
||||
fn into_upstream(self) -> W {
|
||||
self.upstream
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Intermediate Frame =============
|
||||
@@ -140,31 +155,37 @@ impl<R> IntermediateFrameReader<R> {
|
||||
impl<R: AsyncRead + Unpin> IntermediateFrameReader<R> {
|
||||
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
|
||||
let mut meta = FrameMeta::new();
|
||||
|
||||
|
||||
// Read 4-byte length
|
||||
let mut len_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut len_bytes).await?;
|
||||
|
||||
|
||||
let mut len = u32::from_le_bytes(len_bytes) as usize;
|
||||
|
||||
|
||||
// Check QuickACK flag (high bit)
|
||||
if len > 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
|
||||
|
||||
// Read data
|
||||
let mut data = vec![0u8; len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
|
||||
Ok((Bytes::from(data), meta))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> LayeredStream<R> for IntermediateFrameReader<R> {
|
||||
fn upstream(&self) -> &R { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
|
||||
fn into_upstream(self) -> R { self.upstream }
|
||||
fn upstream(&self) -> &R {
|
||||
&self.upstream
|
||||
}
|
||||
fn upstream_mut(&mut self) -> &mut R {
|
||||
&mut self.upstream
|
||||
}
|
||||
fn into_upstream(self) -> R {
|
||||
self.upstream
|
||||
}
|
||||
}
|
||||
|
||||
/// Writer for intermediate MTProto framing
|
||||
@@ -189,16 +210,22 @@ impl<W: AsyncWrite + Unpin> IntermediateFrameWriter<W> {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
pub async fn flush(&mut self) -> Result<()> {
|
||||
self.upstream.flush().await
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> LayeredStream<W> for IntermediateFrameWriter<W> {
|
||||
fn upstream(&self) -> &W { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
|
||||
fn into_upstream(self) -> W { self.upstream }
|
||||
fn upstream(&self) -> &W {
|
||||
&self.upstream
|
||||
}
|
||||
fn upstream_mut(&mut self) -> &mut W {
|
||||
&mut self.upstream
|
||||
}
|
||||
fn into_upstream(self) -> W {
|
||||
self.upstream
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Secure Intermediate Frame =============
|
||||
@@ -217,23 +244,23 @@ impl<R> SecureIntermediateFrameReader<R> {
|
||||
impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
|
||||
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
|
||||
let mut meta = FrameMeta::new();
|
||||
|
||||
|
||||
// Read 4-byte length
|
||||
let mut len_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut len_bytes).await?;
|
||||
|
||||
|
||||
let mut len = u32::from_le_bytes(len_bytes) as usize;
|
||||
|
||||
|
||||
// Check QuickACK flag
|
||||
if len > 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
|
||||
|
||||
// Read data (including padding)
|
||||
let mut data = vec![0u8; len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
|
||||
let payload_len = secure_payload_len_from_wire_len(len).ok_or_else(|| {
|
||||
Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
@@ -241,15 +268,21 @@ impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
|
||||
)
|
||||
})?;
|
||||
data.truncate(payload_len);
|
||||
|
||||
|
||||
Ok((Bytes::from(data), meta))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> LayeredStream<R> for SecureIntermediateFrameReader<R> {
|
||||
fn upstream(&self) -> &R { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
|
||||
fn into_upstream(self) -> R { self.upstream }
|
||||
fn upstream(&self) -> &R {
|
||||
&self.upstream
|
||||
}
|
||||
fn upstream_mut(&mut self) -> &mut R {
|
||||
&mut self.upstream
|
||||
}
|
||||
fn into_upstream(self) -> R {
|
||||
self.upstream
|
||||
}
|
||||
}
|
||||
|
||||
/// Writer for secure intermediate MTProto framing
|
||||
@@ -270,7 +303,7 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
|
||||
self.upstream.write_all(data).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
||||
if !is_valid_secure_payload_len(data.len()) {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
@@ -281,26 +314,32 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
|
||||
// Add padding so total length is never divisible by 4 (MTProto Secure)
|
||||
let padding_len = secure_padding_len(data.len(), &self.rng);
|
||||
let padding = self.rng.bytes(padding_len);
|
||||
|
||||
|
||||
let total_len = data.len() + padding_len;
|
||||
let len_bytes = (total_len as u32).to_le_bytes();
|
||||
|
||||
|
||||
self.upstream.write_all(&len_bytes).await?;
|
||||
self.upstream.write_all(data).await?;
|
||||
self.upstream.write_all(&padding).await?;
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
pub async fn flush(&mut self) -> Result<()> {
|
||||
self.upstream.flush().await
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> LayeredStream<W> for SecureIntermediateFrameWriter<W> {
|
||||
fn upstream(&self) -> &W { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
|
||||
fn into_upstream(self) -> W { self.upstream }
|
||||
fn upstream(&self) -> &W {
|
||||
&self.upstream
|
||||
}
|
||||
fn upstream_mut(&mut self) -> &mut W {
|
||||
&mut self.upstream
|
||||
}
|
||||
fn into_upstream(self) -> W {
|
||||
self.upstream
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Full MTProto Frame (with CRC) =============
|
||||
@@ -313,7 +352,10 @@ pub struct MtprotoFrameReader<R> {
|
||||
|
||||
impl<R> MtprotoFrameReader<R> {
|
||||
pub fn new(upstream: R, start_seq: i32) -> Self {
|
||||
Self { upstream, seq_no: start_seq }
|
||||
Self {
|
||||
upstream,
|
||||
seq_no: start_seq,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -324,57 +366,65 @@ impl<R: AsyncRead + Unpin> MtprotoFrameReader<R> {
|
||||
let mut len_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut len_bytes).await?;
|
||||
let len = u32::from_le_bytes(len_bytes) as usize;
|
||||
|
||||
|
||||
// Skip padding-only messages
|
||||
if len == 4 {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
// Validate length
|
||||
if !(MIN_MSG_LEN..=MAX_MSG_LEN).contains(&len) || !len.is_multiple_of(PADDING_FILLER.len()) {
|
||||
if !(MIN_MSG_LEN..=MAX_MSG_LEN).contains(&len)
|
||||
|| !len.is_multiple_of(PADDING_FILLER.len())
|
||||
{
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("Invalid message length: {}", len),
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
// Read sequence number
|
||||
let mut seq_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut seq_bytes).await?;
|
||||
let msg_seq = i32::from_le_bytes(seq_bytes);
|
||||
|
||||
|
||||
if msg_seq != self.seq_no {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("Sequence mismatch: expected {}, got {}", self.seq_no, msg_seq),
|
||||
format!(
|
||||
"Sequence mismatch: expected {}, got {}",
|
||||
self.seq_no, msg_seq
|
||||
),
|
||||
));
|
||||
}
|
||||
self.seq_no += 1;
|
||||
|
||||
|
||||
// Read data (length - 4 len - 4 seq - 4 crc = len - 12)
|
||||
let data_len = len - 12;
|
||||
let mut data = vec![0u8; data_len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
|
||||
// Read and verify CRC32
|
||||
let mut crc_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut crc_bytes).await?;
|
||||
let expected_crc = u32::from_le_bytes(crc_bytes);
|
||||
|
||||
|
||||
// Compute CRC over len + seq + data
|
||||
let mut crc_input = Vec::with_capacity(8 + data_len);
|
||||
crc_input.extend_from_slice(&len_bytes);
|
||||
crc_input.extend_from_slice(&seq_bytes);
|
||||
crc_input.extend_from_slice(&data);
|
||||
let computed_crc = crc32(&crc_input);
|
||||
|
||||
|
||||
if computed_crc != expected_crc {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("CRC mismatch: expected {:08x}, got {:08x}", expected_crc, computed_crc),
|
||||
format!(
|
||||
"CRC mismatch: expected {:08x}, got {:08x}",
|
||||
expected_crc, computed_crc
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
return Ok(Bytes::from(data));
|
||||
}
|
||||
}
|
||||
@@ -388,7 +438,10 @@ pub struct MtprotoFrameWriter<W> {
|
||||
|
||||
impl<W> MtprotoFrameWriter<W> {
|
||||
pub fn new(upstream: W, start_seq: i32) -> Self {
|
||||
Self { upstream, seq_no: start_seq }
|
||||
Self {
|
||||
upstream,
|
||||
seq_no: start_seq,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -396,11 +449,11 @@ impl<W: AsyncWrite + Unpin> MtprotoFrameWriter<W> {
|
||||
pub async fn write_frame(&mut self, msg: &[u8]) -> Result<()> {
|
||||
// Total length: 4 (len) + 4 (seq) + data + 4 (crc)
|
||||
let len = msg.len() + 12;
|
||||
|
||||
|
||||
let len_bytes = (len as u32).to_le_bytes();
|
||||
let seq_bytes = self.seq_no.to_le_bytes();
|
||||
self.seq_no += 1;
|
||||
|
||||
|
||||
// Compute CRC
|
||||
let mut crc_input = Vec::with_capacity(8 + msg.len());
|
||||
crc_input.extend_from_slice(&len_bytes);
|
||||
@@ -408,25 +461,25 @@ impl<W: AsyncWrite + Unpin> MtprotoFrameWriter<W> {
|
||||
crc_input.extend_from_slice(msg);
|
||||
let checksum = crc32(&crc_input);
|
||||
let crc_bytes = checksum.to_le_bytes();
|
||||
|
||||
|
||||
// Calculate padding for CBC alignment
|
||||
let total_len = len_bytes.len() + seq_bytes.len() + msg.len() + crc_bytes.len();
|
||||
let padding_needed = (CBC_PADDING - (total_len % CBC_PADDING)) % CBC_PADDING;
|
||||
let padding_count = padding_needed / PADDING_FILLER.len();
|
||||
|
||||
|
||||
// Write everything
|
||||
self.upstream.write_all(&len_bytes).await?;
|
||||
self.upstream.write_all(&seq_bytes).await?;
|
||||
self.upstream.write_all(msg).await?;
|
||||
self.upstream.write_all(&crc_bytes).await?;
|
||||
|
||||
|
||||
for _ in 0..padding_count {
|
||||
self.upstream.write_all(&PADDING_FILLER).await?;
|
||||
}
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
pub async fn flush(&mut self) -> Result<()> {
|
||||
self.upstream.flush().await
|
||||
}
|
||||
@@ -445,11 +498,15 @@ impl<R: AsyncRead + Unpin> FrameReaderKind<R> {
|
||||
pub fn new(upstream: R, proto_tag: ProtoTag) -> Self {
|
||||
match proto_tag {
|
||||
ProtoTag::Abridged => FrameReaderKind::Abridged(AbridgedFrameReader::new(upstream)),
|
||||
ProtoTag::Intermediate => FrameReaderKind::Intermediate(IntermediateFrameReader::new(upstream)),
|
||||
ProtoTag::Secure => FrameReaderKind::SecureIntermediate(SecureIntermediateFrameReader::new(upstream)),
|
||||
ProtoTag::Intermediate => {
|
||||
FrameReaderKind::Intermediate(IntermediateFrameReader::new(upstream))
|
||||
}
|
||||
ProtoTag::Secure => {
|
||||
FrameReaderKind::SecureIntermediate(SecureIntermediateFrameReader::new(upstream))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
|
||||
match self {
|
||||
FrameReaderKind::Abridged(r) => r.read_frame().await,
|
||||
@@ -469,11 +526,15 @@ impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
|
||||
pub fn new(upstream: W, proto_tag: ProtoTag, rng: Arc<SecureRandom>) -> Self {
|
||||
match proto_tag {
|
||||
ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)),
|
||||
ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)),
|
||||
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream, rng)),
|
||||
ProtoTag::Intermediate => {
|
||||
FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream))
|
||||
}
|
||||
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(
|
||||
SecureIntermediateFrameWriter::new(upstream, rng),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
|
||||
match self {
|
||||
FrameWriterKind::Abridged(w) => w.write_frame(data, meta).await,
|
||||
@@ -481,7 +542,7 @@ impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
|
||||
FrameWriterKind::SecureIntermediate(w) => w.write_frame(data, meta).await,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub async fn flush(&mut self) -> Result<()> {
|
||||
match self {
|
||||
FrameWriterKind::Abridged(w) => w.flush().await,
|
||||
@@ -494,103 +555,110 @@ impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::duplex;
|
||||
use std::sync::Arc;
|
||||
use crate::crypto::SecureRandom;
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::io::duplex;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_abridged_roundtrip() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
|
||||
let mut writer = AbridgedFrameWriter::new(client);
|
||||
let mut reader = AbridgedFrameReader::new(server);
|
||||
|
||||
|
||||
// Short frame
|
||||
let data = vec![1u8, 2, 3, 4]; // 4 bytes = 1 word
|
||||
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
|
||||
let (received, _meta) = reader.read_frame().await.unwrap();
|
||||
assert_eq!(&received[..], &data[..]);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_abridged_long_frame() {
|
||||
let (client, server) = duplex(65536);
|
||||
|
||||
|
||||
let mut writer = AbridgedFrameWriter::new(client);
|
||||
let mut reader = AbridgedFrameReader::new(server);
|
||||
|
||||
|
||||
// Long frame (> 0x7f words = 508 bytes)
|
||||
let data: Vec<u8> = (0..1000).map(|i| (i % 256) as u8).collect();
|
||||
let padded_len = (data.len() + 3) / 4 * 4;
|
||||
let mut padded = data.clone();
|
||||
padded.resize(padded_len, 0);
|
||||
|
||||
writer.write_frame(&padded, &FrameMeta::new()).await.unwrap();
|
||||
|
||||
writer
|
||||
.write_frame(&padded, &FrameMeta::new())
|
||||
.await
|
||||
.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
|
||||
let (received, _meta) = reader.read_frame().await.unwrap();
|
||||
assert_eq!(&received[..], &padded[..]);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_intermediate_roundtrip() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
|
||||
let mut writer = IntermediateFrameWriter::new(client);
|
||||
let mut reader = IntermediateFrameReader::new(server);
|
||||
|
||||
|
||||
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
|
||||
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
|
||||
let (received, _meta) = reader.read_frame().await.unwrap();
|
||||
assert_eq!(&received[..], &data[..]);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_secure_intermediate_padding() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
|
||||
let mut writer = SecureIntermediateFrameWriter::new(client, Arc::new(SecureRandom::new()));
|
||||
let mut reader = SecureIntermediateFrameReader::new(server);
|
||||
|
||||
|
||||
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
|
||||
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
|
||||
let (received, _meta) = reader.read_frame().await.unwrap();
|
||||
assert_eq!(received.len(), data.len());
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mtproto_frame_roundtrip() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
|
||||
let mut writer = MtprotoFrameWriter::new(client, 0);
|
||||
let mut reader = MtprotoFrameReader::new(server, 0);
|
||||
|
||||
|
||||
// Message must be padded properly
|
||||
let data = vec![0u8; 16]; // Aligned to 4 and CBC_PADDING
|
||||
writer.write_frame(&data).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
|
||||
let received = reader.read_frame().await.unwrap();
|
||||
assert_eq!(&received[..], &data[..]);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_frame_reader_kind() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate, Arc::new(SecureRandom::new()));
|
||||
|
||||
let mut writer = FrameWriterKind::new(
|
||||
client,
|
||||
ProtoTag::Intermediate,
|
||||
Arc::new(SecureRandom::new()),
|
||||
);
|
||||
let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate);
|
||||
|
||||
|
||||
let data = vec![1u8, 2, 3, 4];
|
||||
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
|
||||
let (received, _) = reader.read_frame().await.unwrap();
|
||||
assert_eq!(&received[..], &data[..]);
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
//! Stream wrappers for MTProto protocol layers
|
||||
|
||||
pub mod state;
|
||||
pub mod buffer_pool;
|
||||
pub mod traits;
|
||||
pub mod crypto_stream;
|
||||
pub mod tls_stream;
|
||||
pub mod frame;
|
||||
pub mod frame_codec;
|
||||
pub mod state;
|
||||
pub mod tls_stream;
|
||||
pub mod traits;
|
||||
|
||||
// Legacy compatibility - will be removed later
|
||||
pub mod frame_stream;
|
||||
@@ -14,13 +14,12 @@ pub mod frame_stream;
|
||||
// Re-export state machine types
|
||||
#[allow(unused_imports)]
|
||||
pub use state::{
|
||||
StreamState, Transition, PollResult,
|
||||
ReadBuffer, WriteBuffer, HeaderBuffer, YieldBuffer,
|
||||
HeaderBuffer, PollResult, ReadBuffer, StreamState, Transition, WriteBuffer, YieldBuffer,
|
||||
};
|
||||
|
||||
// Re-export buffer pool
|
||||
#[allow(unused_imports)]
|
||||
pub use buffer_pool::{BufferPool, PooledBuffer, PoolStats};
|
||||
pub use buffer_pool::{BufferPool, PoolStats, PooledBuffer};
|
||||
|
||||
// Re-export stream implementations
|
||||
#[allow(unused_imports)]
|
||||
@@ -29,21 +28,16 @@ pub use tls_stream::{FakeTlsReader, FakeTlsWriter};
|
||||
|
||||
// Re-export frame types
|
||||
#[allow(unused_imports)]
|
||||
pub use frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait, create_codec};
|
||||
pub use frame::{Frame, FrameCodec as FrameCodecTrait, FrameMeta, create_codec};
|
||||
|
||||
// Re-export tokio-util compatible codecs
|
||||
#[allow(unused_imports)]
|
||||
pub use frame_codec::{
|
||||
FrameCodec,
|
||||
AbridgedCodec, IntermediateCodec, SecureCodec,
|
||||
};
|
||||
pub use frame_codec::{AbridgedCodec, FrameCodec, IntermediateCodec, SecureCodec};
|
||||
|
||||
// Legacy re-exports for compatibility
|
||||
#[allow(unused_imports)]
|
||||
pub use frame_stream::{
|
||||
AbridgedFrameReader, AbridgedFrameWriter,
|
||||
IntermediateFrameReader, IntermediateFrameWriter,
|
||||
AbridgedFrameReader, AbridgedFrameWriter, FrameReaderKind, FrameWriterKind,
|
||||
IntermediateFrameReader, IntermediateFrameWriter, MtprotoFrameReader, MtprotoFrameWriter,
|
||||
SecureIntermediateFrameReader, SecureIntermediateFrameWriter,
|
||||
MtprotoFrameReader, MtprotoFrameWriter,
|
||||
FrameReaderKind, FrameWriterKind,
|
||||
};
|
||||
|
||||
@@ -14,10 +14,10 @@ use std::io;
|
||||
pub trait StreamState: Sized {
|
||||
/// Check if this is a terminal state (no more transitions possible)
|
||||
fn is_terminal(&self) -> bool;
|
||||
|
||||
|
||||
/// Check if stream is in poisoned/error state
|
||||
fn is_poisoned(&self) -> bool;
|
||||
|
||||
|
||||
/// Get human-readable state name for debugging
|
||||
fn state_name(&self) -> &'static str;
|
||||
}
|
||||
@@ -44,7 +44,7 @@ impl<S, O> Transition<S, O> {
|
||||
pub fn has_output(&self) -> bool {
|
||||
matches!(self, Transition::Complete(_) | Transition::Yield(_, _))
|
||||
}
|
||||
|
||||
|
||||
/// Map the output value
|
||||
pub fn map_output<U, F: FnOnce(O) -> U>(self, f: F) -> Transition<S, U> {
|
||||
match self {
|
||||
@@ -55,7 +55,7 @@ impl<S, O> Transition<S, O> {
|
||||
Transition::Error(e) => Transition::Error(e),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Map the state value
|
||||
pub fn map_state<T, F: FnOnce(S) -> T>(self, f: F) -> Transition<T, O> {
|
||||
match self {
|
||||
@@ -90,12 +90,12 @@ impl<T> PollResult<T> {
|
||||
pub fn is_ready(&self) -> bool {
|
||||
matches!(self, PollResult::Ready(_))
|
||||
}
|
||||
|
||||
|
||||
/// Check if result indicates EOF
|
||||
pub fn is_eof(&self) -> bool {
|
||||
matches!(self, PollResult::Eof)
|
||||
}
|
||||
|
||||
|
||||
/// Convert to Option, discarding non-ready states
|
||||
pub fn ok(self) -> Option<T> {
|
||||
match self {
|
||||
@@ -103,7 +103,7 @@ impl<T> PollResult<T> {
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Map the value
|
||||
pub fn map<U, F: FnOnce(T) -> U>(self, f: F) -> PollResult<U> {
|
||||
match self {
|
||||
@@ -146,7 +146,7 @@ impl ReadBuffer {
|
||||
target: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Create with specific capacity
|
||||
pub fn with_capacity(capacity: usize) -> Self {
|
||||
Self {
|
||||
@@ -154,7 +154,7 @@ impl ReadBuffer {
|
||||
target: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Create with target size
|
||||
pub fn with_target(target: usize) -> Self {
|
||||
Self {
|
||||
@@ -162,17 +162,17 @@ impl ReadBuffer {
|
||||
target: Some(target),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Get current buffer length
|
||||
pub fn len(&self) -> usize {
|
||||
self.buffer.len()
|
||||
}
|
||||
|
||||
|
||||
/// Check if buffer is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.buffer.is_empty()
|
||||
}
|
||||
|
||||
|
||||
/// Check if target is reached
|
||||
pub fn is_complete(&self) -> bool {
|
||||
match self.target {
|
||||
@@ -180,7 +180,7 @@ impl ReadBuffer {
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Get remaining bytes needed
|
||||
pub fn remaining(&self) -> usize {
|
||||
match self.target {
|
||||
@@ -188,18 +188,18 @@ impl ReadBuffer {
|
||||
None => 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Append data to buffer
|
||||
pub fn extend(&mut self, data: &[u8]) {
|
||||
self.buffer.extend_from_slice(data);
|
||||
}
|
||||
|
||||
|
||||
/// Take all data from buffer
|
||||
pub fn take(&mut self) -> Bytes {
|
||||
self.target = None;
|
||||
self.buffer.split().freeze()
|
||||
}
|
||||
|
||||
|
||||
/// Take exactly n bytes
|
||||
pub fn take_exact(&mut self, n: usize) -> Option<Bytes> {
|
||||
if self.buffer.len() >= n {
|
||||
@@ -208,23 +208,23 @@ impl ReadBuffer {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Get a slice of the buffer
|
||||
pub fn as_slice(&self) -> &[u8] {
|
||||
&self.buffer
|
||||
}
|
||||
|
||||
|
||||
/// Get mutable access to underlying BytesMut
|
||||
pub fn as_bytes_mut(&mut self) -> &mut BytesMut {
|
||||
&mut self.buffer
|
||||
}
|
||||
|
||||
|
||||
/// Clear the buffer
|
||||
pub fn clear(&mut self) {
|
||||
self.buffer.clear();
|
||||
self.target = None;
|
||||
}
|
||||
|
||||
|
||||
/// Set new target
|
||||
pub fn set_target(&mut self, target: usize) {
|
||||
self.target = Some(target);
|
||||
@@ -253,7 +253,7 @@ impl WriteBuffer {
|
||||
pub fn new() -> Self {
|
||||
Self::with_max_size(256 * 1024)
|
||||
}
|
||||
|
||||
|
||||
/// Create with specific max size
|
||||
pub fn with_max_size(max_size: usize) -> Self {
|
||||
Self {
|
||||
@@ -262,27 +262,27 @@ impl WriteBuffer {
|
||||
max_size,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Get pending bytes count
|
||||
pub fn len(&self) -> usize {
|
||||
self.buffer.len() - self.position
|
||||
}
|
||||
|
||||
|
||||
/// Check if buffer is empty (all written)
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.position >= self.buffer.len()
|
||||
}
|
||||
|
||||
|
||||
/// Check if buffer is full
|
||||
pub fn is_full(&self) -> bool {
|
||||
self.buffer.len() >= self.max_size
|
||||
}
|
||||
|
||||
|
||||
/// Get remaining capacity
|
||||
pub fn remaining_capacity(&self) -> usize {
|
||||
self.max_size.saturating_sub(self.buffer.len())
|
||||
}
|
||||
|
||||
|
||||
/// Append data to buffer
|
||||
pub fn extend(&mut self, data: &[u8]) -> Result<(), ()> {
|
||||
if self.buffer.len() + data.len() > self.max_size {
|
||||
@@ -291,23 +291,23 @@ impl WriteBuffer {
|
||||
self.buffer.extend_from_slice(data);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/// Get slice of data to write
|
||||
pub fn pending(&self) -> &[u8] {
|
||||
&self.buffer[self.position..]
|
||||
}
|
||||
|
||||
|
||||
/// Advance position by n bytes (after successful write)
|
||||
pub fn advance(&mut self, n: usize) {
|
||||
self.position += n;
|
||||
|
||||
|
||||
// If all data written, reset buffer
|
||||
if self.position >= self.buffer.len() {
|
||||
self.buffer.clear();
|
||||
self.position = 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Clear the buffer
|
||||
pub fn clear(&mut self) {
|
||||
self.buffer.clear();
|
||||
@@ -340,38 +340,38 @@ impl<const N: usize> HeaderBuffer<N> {
|
||||
filled: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Get slice for reading into
|
||||
pub fn unfilled_mut(&mut self) -> &mut [u8] {
|
||||
&mut self.data[self.filled..]
|
||||
}
|
||||
|
||||
|
||||
/// Advance filled count
|
||||
pub fn advance(&mut self, n: usize) {
|
||||
self.filled = (self.filled + n).min(N);
|
||||
}
|
||||
|
||||
|
||||
/// Check if completely filled
|
||||
pub fn is_complete(&self) -> bool {
|
||||
self.filled >= N
|
||||
}
|
||||
|
||||
|
||||
/// Get remaining bytes needed
|
||||
pub fn remaining(&self) -> usize {
|
||||
N - self.filled
|
||||
}
|
||||
|
||||
|
||||
/// Get filled bytes as slice
|
||||
pub fn as_slice(&self) -> &[u8] {
|
||||
&self.data[..self.filled]
|
||||
}
|
||||
|
||||
|
||||
/// Get complete buffer (panics if not complete)
|
||||
pub fn as_array(&self) -> &[u8; N] {
|
||||
assert!(self.is_complete());
|
||||
&self.data
|
||||
}
|
||||
|
||||
|
||||
/// Take the buffer, resetting state
|
||||
pub fn take(&mut self) -> [u8; N] {
|
||||
let data = self.data;
|
||||
@@ -379,7 +379,7 @@ impl<const N: usize> HeaderBuffer<N> {
|
||||
self.filled = 0;
|
||||
data
|
||||
}
|
||||
|
||||
|
||||
/// Reset to empty state
|
||||
pub fn reset(&mut self) {
|
||||
self.filled = 0;
|
||||
@@ -406,17 +406,17 @@ impl YieldBuffer {
|
||||
pub fn new(data: Bytes) -> Self {
|
||||
Self { data, position: 0 }
|
||||
}
|
||||
|
||||
|
||||
/// Check if all data has been yielded
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.position >= self.data.len()
|
||||
}
|
||||
|
||||
|
||||
/// Get remaining bytes
|
||||
pub fn remaining(&self) -> usize {
|
||||
self.data.len() - self.position
|
||||
}
|
||||
|
||||
|
||||
/// Copy data to output slice, return bytes copied
|
||||
pub fn copy_to(&mut self, dst: &mut [u8]) -> usize {
|
||||
let available = &self.data[self.position..];
|
||||
@@ -425,7 +425,7 @@ impl YieldBuffer {
|
||||
self.position += to_copy;
|
||||
to_copy
|
||||
}
|
||||
|
||||
|
||||
/// Get remaining data as slice
|
||||
pub fn as_slice(&self) -> &[u8] {
|
||||
&self.data[self.position..]
|
||||
@@ -468,106 +468,106 @@ macro_rules! ready_or_pending {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_read_buffer_basic() {
|
||||
let mut buf = ReadBuffer::with_target(10);
|
||||
assert_eq!(buf.remaining(), 10);
|
||||
assert!(!buf.is_complete());
|
||||
|
||||
|
||||
buf.extend(b"hello");
|
||||
assert_eq!(buf.len(), 5);
|
||||
assert_eq!(buf.remaining(), 5);
|
||||
assert!(!buf.is_complete());
|
||||
|
||||
|
||||
buf.extend(b"world");
|
||||
assert_eq!(buf.len(), 10);
|
||||
assert!(buf.is_complete());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_read_buffer_take() {
|
||||
let mut buf = ReadBuffer::new();
|
||||
buf.extend(b"test data");
|
||||
|
||||
|
||||
let data = buf.take();
|
||||
assert_eq!(&data[..], b"test data");
|
||||
assert!(buf.is_empty());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_write_buffer_basic() {
|
||||
let mut buf = WriteBuffer::with_max_size(100);
|
||||
assert!(buf.is_empty());
|
||||
|
||||
|
||||
buf.extend(b"hello").unwrap();
|
||||
assert_eq!(buf.len(), 5);
|
||||
assert!(!buf.is_empty());
|
||||
|
||||
|
||||
buf.advance(3);
|
||||
assert_eq!(buf.len(), 2);
|
||||
assert_eq!(buf.pending(), b"lo");
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_write_buffer_overflow() {
|
||||
let mut buf = WriteBuffer::with_max_size(10);
|
||||
assert!(buf.extend(b"short").is_ok());
|
||||
assert!(buf.extend(b"toolong").is_err());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_header_buffer() {
|
||||
let mut buf = HeaderBuffer::<5>::new();
|
||||
assert!(!buf.is_complete());
|
||||
assert_eq!(buf.remaining(), 5);
|
||||
|
||||
|
||||
buf.unfilled_mut()[..3].copy_from_slice(b"hel");
|
||||
buf.advance(3);
|
||||
assert_eq!(buf.remaining(), 2);
|
||||
|
||||
|
||||
buf.unfilled_mut()[..2].copy_from_slice(b"lo");
|
||||
buf.advance(2);
|
||||
assert!(buf.is_complete());
|
||||
assert_eq!(buf.as_array(), b"hello");
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_yield_buffer() {
|
||||
let mut buf = YieldBuffer::new(Bytes::from_static(b"hello world"));
|
||||
|
||||
|
||||
let mut dst = [0u8; 5];
|
||||
assert_eq!(buf.copy_to(&mut dst), 5);
|
||||
assert_eq!(&dst, b"hello");
|
||||
|
||||
|
||||
assert_eq!(buf.remaining(), 6);
|
||||
|
||||
|
||||
let mut dst = [0u8; 10];
|
||||
assert_eq!(buf.copy_to(&mut dst), 6);
|
||||
assert_eq!(&dst[..6], b" world");
|
||||
|
||||
|
||||
assert!(buf.is_empty());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_transition_map() {
|
||||
let t: Transition<i32, String> = Transition::Complete("hello".to_string());
|
||||
let t = t.map_output(|s| s.len());
|
||||
|
||||
|
||||
match t {
|
||||
Transition::Complete(5) => {}
|
||||
_ => panic!("Expected Complete(5)"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_poll_result() {
|
||||
let r: PollResult<i32> = PollResult::Ready(42);
|
||||
assert!(r.is_ready());
|
||||
assert_eq!(r.ok(), Some(42));
|
||||
|
||||
|
||||
let r: PollResult<i32> = PollResult::Eof;
|
||||
assert!(r.is_eof());
|
||||
assert_eq!(r.ok(), None);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,16 +38,13 @@ use bytes::{Bytes, BytesMut};
|
||||
use std::io::{self, Error, ErrorKind, Result};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
|
||||
|
||||
use super::state::{HeaderBuffer, StreamState, WriteBuffer, YieldBuffer};
|
||||
use crate::protocol::constants::{
|
||||
MAX_TLS_PLAINTEXT_SIZE,
|
||||
TLS_VERSION,
|
||||
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER,
|
||||
TLS_RECORD_HANDSHAKE, TLS_RECORD_ALERT,
|
||||
MAX_TLS_CIPHERTEXT_SIZE,
|
||||
MAX_TLS_CIPHERTEXT_SIZE, MAX_TLS_PLAINTEXT_SIZE, TLS_RECORD_ALERT, TLS_RECORD_APPLICATION,
|
||||
TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, TLS_VERSION,
|
||||
};
|
||||
use super::state::{StreamState, HeaderBuffer, YieldBuffer, WriteBuffer};
|
||||
|
||||
// ============= Constants =============
|
||||
|
||||
@@ -84,7 +81,11 @@ impl TlsRecordHeader {
|
||||
let record_type = header[0];
|
||||
let version = [header[1], header[2]];
|
||||
let length = u16::from_be_bytes([header[3], header[4]]);
|
||||
Some(Self { record_type, version, length })
|
||||
Some(Self {
|
||||
record_type,
|
||||
version,
|
||||
length,
|
||||
})
|
||||
}
|
||||
|
||||
/// Validate the header.
|
||||
@@ -111,8 +112,7 @@ impl TlsRecordHeader {
|
||||
ErrorKind::InvalidData,
|
||||
format!(
|
||||
"invalid TLS version for record type 0x{:02x}: {:02x?}",
|
||||
self.record_type,
|
||||
self.version
|
||||
self.record_type, self.version
|
||||
),
|
||||
));
|
||||
}
|
||||
@@ -129,8 +129,7 @@ impl TlsRecordHeader {
|
||||
ErrorKind::InvalidData,
|
||||
format!(
|
||||
"invalid TLS application data length: {} (min 1, max {})",
|
||||
len,
|
||||
MAX_TLS_CIPHERTEXT_SIZE
|
||||
len, MAX_TLS_CIPHERTEXT_SIZE
|
||||
),
|
||||
));
|
||||
}
|
||||
@@ -160,8 +159,7 @@ impl TlsRecordHeader {
|
||||
ErrorKind::InvalidData,
|
||||
format!(
|
||||
"invalid TLS handshake length: {} (min 4, max {})",
|
||||
len,
|
||||
MAX_TLS_PLAINTEXT_SIZE
|
||||
len, MAX_TLS_PLAINTEXT_SIZE
|
||||
),
|
||||
));
|
||||
}
|
||||
@@ -212,14 +210,10 @@ enum TlsReaderState {
|
||||
},
|
||||
|
||||
/// Have buffered data ready to yield to caller
|
||||
Yielding {
|
||||
buffer: YieldBuffer,
|
||||
},
|
||||
Yielding { buffer: YieldBuffer },
|
||||
|
||||
/// Stream encountered an error and cannot be used
|
||||
Poisoned {
|
||||
error: Option<io::Error>,
|
||||
},
|
||||
Poisoned { error: Option<io::Error> },
|
||||
}
|
||||
|
||||
impl StreamState for TlsReaderState {
|
||||
@@ -287,28 +281,41 @@ pub struct FakeTlsReader<R> {
|
||||
|
||||
impl<R> FakeTlsReader<R> {
|
||||
pub fn new(upstream: R) -> Self {
|
||||
Self { upstream, state: TlsReaderState::Idle }
|
||||
Self {
|
||||
upstream,
|
||||
state: TlsReaderState::Idle,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_ref(&self) -> &R { &self.upstream }
|
||||
pub fn get_mut(&mut self) -> &mut R { &mut self.upstream }
|
||||
pub fn into_inner(self) -> R { self.upstream }
|
||||
pub fn get_ref(&self) -> &R {
|
||||
&self.upstream
|
||||
}
|
||||
pub fn get_mut(&mut self) -> &mut R {
|
||||
&mut self.upstream
|
||||
}
|
||||
pub fn into_inner(self) -> R {
|
||||
self.upstream
|
||||
}
|
||||
|
||||
pub fn into_inner_with_pending_plaintext(mut self) -> (R, Vec<u8>) {
|
||||
let pending = match std::mem::replace(&mut self.state, TlsReaderState::Idle) {
|
||||
TlsReaderState::Yielding { buffer } => buffer.as_slice().to_vec(),
|
||||
TlsReaderState::ReadingBody { record_type, buffer, .. }
|
||||
if record_type == TLS_RECORD_APPLICATION =>
|
||||
{
|
||||
buffer.to_vec()
|
||||
}
|
||||
TlsReaderState::ReadingBody {
|
||||
record_type,
|
||||
buffer,
|
||||
..
|
||||
} if record_type == TLS_RECORD_APPLICATION => buffer.to_vec(),
|
||||
_ => Vec::new(),
|
||||
};
|
||||
(self.upstream, pending)
|
||||
}
|
||||
|
||||
pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() }
|
||||
pub fn state_name(&self) -> &'static str { self.state.state_name() }
|
||||
pub fn is_poisoned(&self) -> bool {
|
||||
self.state.is_poisoned()
|
||||
}
|
||||
pub fn state_name(&self) -> &'static str {
|
||||
self.state.state_name()
|
||||
}
|
||||
|
||||
fn poison(&mut self, error: io::Error) {
|
||||
self.state = TlsReaderState::Poisoned { error: Some(error) };
|
||||
@@ -316,9 +323,9 @@ impl<R> FakeTlsReader<R> {
|
||||
|
||||
fn take_poison_error(&mut self) -> io::Error {
|
||||
match &mut self.state {
|
||||
TlsReaderState::Poisoned { error } => error.take().unwrap_or_else(|| {
|
||||
io::Error::other("stream previously poisoned")
|
||||
}),
|
||||
TlsReaderState::Poisoned { error } => error
|
||||
.take()
|
||||
.unwrap_or_else(|| io::Error::other("stream previously poisoned")),
|
||||
_ => io::Error::other("stream not poisoned"),
|
||||
}
|
||||
}
|
||||
@@ -353,9 +360,8 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
||||
// Poisoned state: always return the stored error
|
||||
TlsReaderState::Poisoned { error } => {
|
||||
this.state = TlsReaderState::Poisoned { error: None };
|
||||
let err = error.unwrap_or_else(|| {
|
||||
io::Error::other("stream previously poisoned")
|
||||
});
|
||||
let err =
|
||||
error.unwrap_or_else(|| io::Error::other("stream previously poisoned"));
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
|
||||
@@ -428,12 +434,20 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
||||
}
|
||||
|
||||
// Read TLS payload
|
||||
TlsReaderState::ReadingBody { record_type, length, mut buffer } => {
|
||||
TlsReaderState::ReadingBody {
|
||||
record_type,
|
||||
length,
|
||||
mut buffer,
|
||||
} => {
|
||||
let result = poll_read_body(&mut this.upstream, cx, &mut buffer, length);
|
||||
|
||||
match result {
|
||||
BodyPollResult::Pending => {
|
||||
this.state = TlsReaderState::ReadingBody { record_type, length, buffer };
|
||||
this.state = TlsReaderState::ReadingBody {
|
||||
record_type,
|
||||
length,
|
||||
buffer,
|
||||
};
|
||||
return Poll::Pending;
|
||||
}
|
||||
BodyPollResult::Error(e) => {
|
||||
@@ -469,7 +483,10 @@ impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
||||
|
||||
TLS_RECORD_HANDSHAKE => {
|
||||
// After FakeTLS handshake is done, we do not expect any Handshake records.
|
||||
let err = Error::new(ErrorKind::InvalidData, "unexpected TLS handshake record");
|
||||
let err = Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"unexpected TLS handshake record",
|
||||
);
|
||||
this.poison(Error::new(err.kind(), err.to_string()));
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
@@ -528,7 +545,10 @@ fn poll_read_header<R: AsyncRead + Unpin>(
|
||||
let header_bytes = *header.as_array();
|
||||
match TlsRecordHeader::parse(&header_bytes) {
|
||||
Some(h) => HeaderPollResult::Complete(h),
|
||||
None => HeaderPollResult::Error(Error::new(ErrorKind::InvalidData, "failed to parse TLS header")),
|
||||
None => HeaderPollResult::Error(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"failed to parse TLS header",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -614,9 +634,7 @@ enum TlsWriterState {
|
||||
},
|
||||
|
||||
/// Stream encountered an error and cannot be used
|
||||
Poisoned {
|
||||
error: Option<io::Error>,
|
||||
},
|
||||
Poisoned { error: Option<io::Error> },
|
||||
}
|
||||
|
||||
impl StreamState for TlsWriterState {
|
||||
@@ -652,15 +670,28 @@ pub struct FakeTlsWriter<W> {
|
||||
|
||||
impl<W> FakeTlsWriter<W> {
|
||||
pub fn new(upstream: W) -> Self {
|
||||
Self { upstream, state: TlsWriterState::Idle }
|
||||
Self {
|
||||
upstream,
|
||||
state: TlsWriterState::Idle,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_ref(&self) -> &W { &self.upstream }
|
||||
pub fn get_mut(&mut self) -> &mut W { &mut self.upstream }
|
||||
pub fn into_inner(self) -> W { self.upstream }
|
||||
pub fn get_ref(&self) -> &W {
|
||||
&self.upstream
|
||||
}
|
||||
pub fn get_mut(&mut self) -> &mut W {
|
||||
&mut self.upstream
|
||||
}
|
||||
pub fn into_inner(self) -> W {
|
||||
self.upstream
|
||||
}
|
||||
|
||||
pub fn is_poisoned(&self) -> bool { self.state.is_poisoned() }
|
||||
pub fn state_name(&self) -> &'static str { self.state.state_name() }
|
||||
pub fn is_poisoned(&self) -> bool {
|
||||
self.state.is_poisoned()
|
||||
}
|
||||
pub fn state_name(&self) -> &'static str {
|
||||
self.state.state_name()
|
||||
}
|
||||
|
||||
pub fn has_pending(&self) -> bool {
|
||||
matches!(&self.state, TlsWriterState::WritingRecord { record, .. } if !record.is_empty())
|
||||
@@ -672,9 +703,9 @@ impl<W> FakeTlsWriter<W> {
|
||||
|
||||
fn take_poison_error(&mut self) -> io::Error {
|
||||
match &mut self.state {
|
||||
TlsWriterState::Poisoned { error } => error.take().unwrap_or_else(|| {
|
||||
io::Error::other("stream previously poisoned")
|
||||
}),
|
||||
TlsWriterState::Poisoned { error } => error
|
||||
.take()
|
||||
.unwrap_or_else(|| io::Error::other("stream previously poisoned")),
|
||||
_ => io::Error::other("stream not poisoned"),
|
||||
}
|
||||
}
|
||||
@@ -725,11 +756,7 @@ impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
// Take ownership of state to avoid borrow conflicts.
|
||||
@@ -738,17 +765,21 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
match state {
|
||||
TlsWriterState::Poisoned { error } => {
|
||||
this.state = TlsWriterState::Poisoned { error: None };
|
||||
let err = error.unwrap_or_else(|| {
|
||||
Error::other("stream previously poisoned")
|
||||
});
|
||||
let err = error.unwrap_or_else(|| Error::other("stream previously poisoned"));
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
|
||||
TlsWriterState::WritingRecord { mut record, payload_size } => {
|
||||
TlsWriterState::WritingRecord {
|
||||
mut record,
|
||||
payload_size,
|
||||
} => {
|
||||
// Finish writing previous record before accepting new bytes.
|
||||
match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
|
||||
FlushResult::Pending => {
|
||||
this.state = TlsWriterState::WritingRecord { record, payload_size };
|
||||
this.state = TlsWriterState::WritingRecord {
|
||||
record,
|
||||
payload_size,
|
||||
};
|
||||
return Poll::Pending;
|
||||
}
|
||||
FlushResult::Error(e) => {
|
||||
@@ -780,9 +811,7 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
let record_data = Self::build_record(chunk);
|
||||
|
||||
match Pin::new(&mut this.upstream).poll_write(cx, &record_data) {
|
||||
Poll::Ready(Ok(n)) if n == record_data.len() => {
|
||||
Poll::Ready(Ok(chunk_size))
|
||||
}
|
||||
Poll::Ready(Ok(n)) if n == record_data.len() => Poll::Ready(Ok(chunk_size)),
|
||||
|
||||
Poll::Ready(Ok(n)) => {
|
||||
// Partial write of the record: store remainder.
|
||||
@@ -827,27 +856,29 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
match state {
|
||||
TlsWriterState::Poisoned { error } => {
|
||||
this.state = TlsWriterState::Poisoned { error: None };
|
||||
let err = error.unwrap_or_else(|| {
|
||||
Error::other("stream previously poisoned")
|
||||
});
|
||||
let err = error.unwrap_or_else(|| Error::other("stream previously poisoned"));
|
||||
return Poll::Ready(Err(err));
|
||||
}
|
||||
|
||||
TlsWriterState::WritingRecord { mut record, payload_size } => {
|
||||
match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
|
||||
FlushResult::Pending => {
|
||||
this.state = TlsWriterState::WritingRecord { record, payload_size };
|
||||
return Poll::Pending;
|
||||
}
|
||||
FlushResult::Error(e) => {
|
||||
this.poison(Error::new(e.kind(), e.to_string()));
|
||||
return Poll::Ready(Err(e));
|
||||
}
|
||||
FlushResult::Complete(_) => {
|
||||
this.state = TlsWriterState::Idle;
|
||||
}
|
||||
TlsWriterState::WritingRecord {
|
||||
mut record,
|
||||
payload_size,
|
||||
} => match Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record) {
|
||||
FlushResult::Pending => {
|
||||
this.state = TlsWriterState::WritingRecord {
|
||||
record,
|
||||
payload_size,
|
||||
};
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
FlushResult::Error(e) => {
|
||||
this.poison(Error::new(e.kind(), e.to_string()));
|
||||
return Poll::Ready(Err(e));
|
||||
}
|
||||
FlushResult::Complete(_) => {
|
||||
this.state = TlsWriterState::Idle;
|
||||
}
|
||||
},
|
||||
|
||||
TlsWriterState::Idle => {
|
||||
this.state = TlsWriterState::Idle;
|
||||
@@ -863,7 +894,10 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
let state = std::mem::replace(&mut this.state, TlsWriterState::Idle);
|
||||
|
||||
match state {
|
||||
TlsWriterState::WritingRecord { mut record, payload_size: _ } => {
|
||||
TlsWriterState::WritingRecord {
|
||||
mut record,
|
||||
payload_size: _,
|
||||
} => {
|
||||
// Best-effort flush (do not block shutdown forever).
|
||||
let _ = Self::poll_flush_record_inner(&mut this.upstream, cx, &mut record);
|
||||
this.state = TlsWriterState::Idle;
|
||||
@@ -905,10 +939,10 @@ mod size_adversarial_tests;
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::VecDeque;
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
|
||||
// ============= Test Helpers =============
|
||||
|
||||
|
||||
/// Build a valid TLS Application Data record
|
||||
fn build_tls_record(data: &[u8]) -> Vec<u8> {
|
||||
let mut record = vec![
|
||||
@@ -921,24 +955,25 @@ mod tests {
|
||||
record.extend_from_slice(data);
|
||||
record
|
||||
}
|
||||
|
||||
|
||||
/// Build a Change Cipher Spec record
|
||||
fn build_ccs_record() -> Vec<u8> {
|
||||
vec![
|
||||
TLS_RECORD_CHANGE_CIPHER,
|
||||
TLS_VERSION[0],
|
||||
TLS_VERSION[1],
|
||||
0x00, 0x01, // length = 1
|
||||
0x01, // CCS byte
|
||||
0x00,
|
||||
0x01, // length = 1
|
||||
0x01, // CCS byte
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
/// Mock reader that returns data in chunks
|
||||
struct ChunkedReader {
|
||||
data: VecDeque<u8>,
|
||||
chunk_size: usize,
|
||||
}
|
||||
|
||||
|
||||
impl ChunkedReader {
|
||||
fn new(data: &[u8], chunk_size: usize) -> Self {
|
||||
Self {
|
||||
@@ -947,7 +982,7 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl AsyncRead for ChunkedReader {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
@@ -957,92 +992,92 @@ mod tests {
|
||||
if self.data.is_empty() {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
|
||||
let to_read = self.chunk_size.min(self.data.len()).min(buf.remaining());
|
||||
for _ in 0..to_read {
|
||||
if let Some(byte) = self.data.pop_front() {
|
||||
buf.put_slice(&[byte]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ============= FakeTlsReader Tests =============
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_single_record() {
|
||||
let payload = b"Hello, TLS!";
|
||||
let record = build_tls_record(payload);
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&record, 100);
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
|
||||
assert_eq!(&buf[..], payload);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_multiple_records() {
|
||||
let payload1 = b"First record";
|
||||
let payload2 = b"Second record";
|
||||
|
||||
|
||||
let mut data = build_tls_record(payload1);
|
||||
data.extend_from_slice(&build_tls_record(payload2));
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&data, 100);
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let buf1 = tls_reader.read_exact(payload1.len()).await.unwrap();
|
||||
assert_eq!(&buf1[..], payload1);
|
||||
|
||||
|
||||
let buf2 = tls_reader.read_exact(payload2.len()).await.unwrap();
|
||||
assert_eq!(&buf2[..], payload2);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_partial_header() {
|
||||
// Read header byte by byte
|
||||
let payload = b"Test";
|
||||
let record = build_tls_record(payload);
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&record, 1); // 1 byte at a time!
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
|
||||
|
||||
|
||||
assert_eq!(&buf[..], payload);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_partial_body() {
|
||||
let payload = b"This is a longer payload that will be read in parts";
|
||||
let record = build_tls_record(payload);
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&record, 7); // Awkward chunk size
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
|
||||
|
||||
|
||||
assert_eq!(&buf[..], payload);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_skip_ccs() {
|
||||
// CCS record followed by application data
|
||||
let mut data = build_ccs_record();
|
||||
let payload = b"After CCS";
|
||||
data.extend_from_slice(&build_tls_record(payload));
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&data, 100);
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
|
||||
|
||||
|
||||
assert_eq!(&buf[..], payload);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_multiple_ccs() {
|
||||
// Multiple CCS records
|
||||
@@ -1050,127 +1085,127 @@ mod tests {
|
||||
data.extend_from_slice(&build_ccs_record());
|
||||
let payload = b"After multiple CCS";
|
||||
data.extend_from_slice(&build_tls_record(payload));
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&data, 3); // Small chunks
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let buf = tls_reader.read_exact(payload.len()).await.unwrap();
|
||||
|
||||
|
||||
assert_eq!(&buf[..], payload);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_eof() {
|
||||
let reader = ChunkedReader::new(&[], 100);
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let mut buf = vec![0u8; 10];
|
||||
let read = tls_reader.read(&mut buf).await.unwrap();
|
||||
|
||||
|
||||
assert_eq!(read, 0);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_state_names() {
|
||||
let reader = ChunkedReader::new(&[], 100);
|
||||
let tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
assert_eq!(tls_reader.state_name(), "Idle");
|
||||
assert!(!tls_reader.is_poisoned());
|
||||
}
|
||||
|
||||
|
||||
// ============= FakeTlsWriter Tests =============
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_writer_single_write() {
|
||||
let (client, mut server) = duplex(4096);
|
||||
let mut writer = FakeTlsWriter::new(client);
|
||||
|
||||
|
||||
let payload = b"Hello, TLS!";
|
||||
writer.write_all(payload).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
|
||||
// Read the TLS record from server
|
||||
let mut header = [0u8; 5];
|
||||
server.read_exact(&mut header).await.unwrap();
|
||||
|
||||
|
||||
assert_eq!(header[0], TLS_RECORD_APPLICATION);
|
||||
assert_eq!(&header[1..3], &TLS_VERSION);
|
||||
|
||||
|
||||
let length = u16::from_be_bytes([header[3], header[4]]) as usize;
|
||||
assert_eq!(length, payload.len());
|
||||
|
||||
|
||||
let mut body = vec![0u8; length];
|
||||
server.read_exact(&mut body).await.unwrap();
|
||||
assert_eq!(&body, payload);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_writer_large_data_chunking() {
|
||||
let (client, mut server) = duplex(65536);
|
||||
let mut writer = FakeTlsWriter::new(client);
|
||||
|
||||
|
||||
// Write data larger than MAX_TLS_PAYLOAD
|
||||
let payload: Vec<u8> = (0..20000).map(|i| (i % 256) as u8).collect();
|
||||
writer.write_all(&payload).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
|
||||
// Read back - should be multiple records
|
||||
let mut received = Vec::new();
|
||||
let mut records_count = 0;
|
||||
|
||||
|
||||
while received.len() < payload.len() {
|
||||
let mut header = [0u8; 5];
|
||||
if server.read_exact(&mut header).await.is_err() {
|
||||
break;
|
||||
}
|
||||
|
||||
|
||||
assert_eq!(header[0], TLS_RECORD_APPLICATION);
|
||||
records_count += 1;
|
||||
|
||||
|
||||
let length = u16::from_be_bytes([header[3], header[4]]) as usize;
|
||||
assert!(length <= MAX_TLS_PAYLOAD);
|
||||
|
||||
|
||||
let mut body = vec![0u8; length];
|
||||
server.read_exact(&mut body).await.unwrap();
|
||||
received.extend_from_slice(&body);
|
||||
}
|
||||
|
||||
|
||||
assert_eq!(received, payload);
|
||||
assert!(records_count >= 2); // Should have multiple records
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_stream_roundtrip() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
|
||||
let mut writer = FakeTlsWriter::new(client);
|
||||
let mut reader = FakeTlsReader::new(server);
|
||||
|
||||
|
||||
let original = b"Hello, fake TLS!";
|
||||
writer.write_all_tls(original).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
|
||||
let received = reader.read_exact(original.len()).await.unwrap();
|
||||
assert_eq!(&received[..], original);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_stream_roundtrip_large() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
|
||||
let mut writer = FakeTlsWriter::new(client);
|
||||
let mut reader = FakeTlsReader::new(server);
|
||||
|
||||
|
||||
let original: Vec<u8> = (0..50000).map(|i| (i % 256) as u8).collect();
|
||||
|
||||
|
||||
// Write in background
|
||||
let write_data = original.clone();
|
||||
let write_handle = tokio::spawn(async move {
|
||||
writer.write_all_tls(&write_data).await.unwrap();
|
||||
writer.shutdown().await.unwrap();
|
||||
});
|
||||
|
||||
|
||||
// Read
|
||||
let mut received = Vec::new();
|
||||
let mut buf = vec![0u8; 1024];
|
||||
@@ -1181,87 +1216,95 @@ mod tests {
|
||||
}
|
||||
received.extend_from_slice(&buf[..n]);
|
||||
}
|
||||
|
||||
|
||||
write_handle.await.unwrap();
|
||||
assert_eq!(received, original);
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_writer_state_names() {
|
||||
let (client, _server) = duplex(4096);
|
||||
let writer = FakeTlsWriter::new(client);
|
||||
|
||||
|
||||
assert_eq!(writer.state_name(), "Idle");
|
||||
assert!(!writer.is_poisoned());
|
||||
assert!(!writer.has_pending());
|
||||
}
|
||||
|
||||
|
||||
// ============= Error Handling Tests =============
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_invalid_version() {
|
||||
let invalid_record = vec![
|
||||
TLS_RECORD_APPLICATION,
|
||||
0x04, 0x00, // Invalid version
|
||||
0x00, 0x05, // length = 5
|
||||
0x01, 0x02, 0x03, 0x04, 0x05,
|
||||
0x04,
|
||||
0x00, // Invalid version
|
||||
0x00,
|
||||
0x05, // length = 5
|
||||
0x01,
|
||||
0x02,
|
||||
0x03,
|
||||
0x04,
|
||||
0x05,
|
||||
];
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&invalid_record, 100);
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let mut buf = vec![0u8; 5];
|
||||
let result = tls_reader.read(&mut buf).await;
|
||||
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(tls_reader.is_poisoned());
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_unexpected_eof_header() {
|
||||
// Partial header
|
||||
let partial = vec![TLS_RECORD_APPLICATION, 0x03];
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&partial, 100);
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let mut buf = vec![0u8; 10];
|
||||
let result = tls_reader.read(&mut buf).await;
|
||||
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_reader_unexpected_eof_body() {
|
||||
// Valid header but truncated body
|
||||
let mut record = vec![
|
||||
TLS_RECORD_APPLICATION,
|
||||
TLS_VERSION[0], TLS_VERSION[1],
|
||||
0x00, 0x10, // length = 16
|
||||
TLS_VERSION[0],
|
||||
TLS_VERSION[1],
|
||||
0x00,
|
||||
0x10, // length = 16
|
||||
];
|
||||
record.extend_from_slice(&[0u8; 8]); // Only 8 bytes of body
|
||||
|
||||
|
||||
let reader = ChunkedReader::new(&record, 100);
|
||||
let mut tls_reader = FakeTlsReader::new(reader);
|
||||
|
||||
|
||||
let mut buf = vec![0u8; 16];
|
||||
let result = tls_reader.read(&mut buf).await;
|
||||
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
|
||||
// ============= Header Parsing Tests =============
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_tls_record_header_parse() {
|
||||
let header = [0x17, 0x03, 0x03, 0x01, 0x00];
|
||||
let parsed = TlsRecordHeader::parse(&header).unwrap();
|
||||
|
||||
|
||||
assert_eq!(parsed.record_type, TLS_RECORD_APPLICATION);
|
||||
assert_eq!(parsed.version, TLS_VERSION);
|
||||
assert_eq!(parsed.length, 256);
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_tls_record_header_validate() {
|
||||
let valid = TlsRecordHeader {
|
||||
@@ -1270,14 +1313,14 @@ mod tests {
|
||||
length: 100,
|
||||
};
|
||||
assert!(valid.validate().is_ok());
|
||||
|
||||
|
||||
let invalid_version = TlsRecordHeader {
|
||||
record_type: TLS_RECORD_APPLICATION,
|
||||
version: [0x04, 0x00],
|
||||
length: 100,
|
||||
};
|
||||
assert!(invalid_version.validate().is_err());
|
||||
|
||||
|
||||
let too_large = TlsRecordHeader {
|
||||
record_type: TLS_RECORD_APPLICATION,
|
||||
version: TLS_VERSION,
|
||||
@@ -1285,7 +1328,7 @@ mod tests {
|
||||
};
|
||||
assert!(too_large.validate().is_err());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_tls_record_header_to_bytes() {
|
||||
let header = TlsRecordHeader {
|
||||
@@ -1293,7 +1336,7 @@ mod tests {
|
||||
version: TLS_VERSION,
|
||||
length: 0x1234,
|
||||
};
|
||||
|
||||
|
||||
let bytes = header.to_bytes();
|
||||
assert_eq!(bytes, [0x17, 0x03, 0x03, 0x12, 0x34]);
|
||||
}
|
||||
|
||||
@@ -13,8 +13,7 @@ fn reading_body_pending_application_plaintext_is_preserved_on_into_inner() {
|
||||
|
||||
let (_inner, pending) = reader.into_inner_with_pending_plaintext();
|
||||
assert_eq!(
|
||||
pending,
|
||||
sample,
|
||||
pending, sample,
|
||||
"partial application-data body must survive into fallback path"
|
||||
);
|
||||
}
|
||||
@@ -59,7 +58,10 @@ fn partial_header_state_does_not_produce_plaintext() {
|
||||
reader.state = TlsReaderState::ReadingHeader { header };
|
||||
|
||||
let (_inner, pending) = reader.into_inner_with_pending_plaintext();
|
||||
assert!(pending.is_empty(), "partial header bytes are not plaintext payload");
|
||||
assert!(
|
||||
pending.is_empty(),
|
||||
"partial header bytes are not plaintext payload"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -83,7 +85,10 @@ fn adversarial_poisoned_state_never_leaks_pending_bytes() {
|
||||
};
|
||||
|
||||
let (_inner, pending) = reader.into_inner_with_pending_plaintext();
|
||||
assert!(pending.is_empty(), "poisoned state must fail-closed for fallback payload");
|
||||
assert!(
|
||||
pending.is_empty(),
|
||||
"poisoned state must fail-closed for fallback payload"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -101,7 +106,10 @@ fn stress_large_application_fragment_survives_state_extraction() {
|
||||
};
|
||||
|
||||
let (_inner, pending) = reader.into_inner_with_pending_plaintext();
|
||||
assert_eq!(pending, payload, "large pending application plaintext must be preserved exactly");
|
||||
assert_eq!(
|
||||
pending, payload,
|
||||
"large pending application plaintext must be preserved exactly"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -308,52 +308,228 @@ macro_rules! expect_accept {
|
||||
};
|
||||
}
|
||||
|
||||
expect_reject!(appdata_zero_len_must_be_rejected, TLS_RECORD_APPLICATION, TLS_VERSION, 0);
|
||||
expect_accept!(appdata_one_len_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, 1);
|
||||
expect_accept!(appdata_small_len_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, 32);
|
||||
expect_accept!(appdata_medium_len_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, 1024);
|
||||
expect_accept!(appdata_plaintext_limit_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, MAX_TLS_PLAINTEXT_SIZE as u16);
|
||||
expect_accept!(appdata_ciphertext_limit_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, MAX_TLS_CIPHERTEXT_SIZE as u16);
|
||||
expect_reject!(appdata_ciphertext_plus_one_must_be_rejected, TLS_RECORD_APPLICATION, TLS_VERSION, (MAX_TLS_CIPHERTEXT_SIZE as u16) + 1);
|
||||
expect_reject!(
|
||||
appdata_zero_len_must_be_rejected,
|
||||
TLS_RECORD_APPLICATION,
|
||||
TLS_VERSION,
|
||||
0
|
||||
);
|
||||
expect_accept!(
|
||||
appdata_one_len_is_accepted,
|
||||
TLS_RECORD_APPLICATION,
|
||||
TLS_VERSION,
|
||||
1
|
||||
);
|
||||
expect_accept!(
|
||||
appdata_small_len_is_accepted,
|
||||
TLS_RECORD_APPLICATION,
|
||||
TLS_VERSION,
|
||||
32
|
||||
);
|
||||
expect_accept!(
|
||||
appdata_medium_len_is_accepted,
|
||||
TLS_RECORD_APPLICATION,
|
||||
TLS_VERSION,
|
||||
1024
|
||||
);
|
||||
expect_accept!(
|
||||
appdata_plaintext_limit_is_accepted,
|
||||
TLS_RECORD_APPLICATION,
|
||||
TLS_VERSION,
|
||||
MAX_TLS_PLAINTEXT_SIZE as u16
|
||||
);
|
||||
expect_accept!(
|
||||
appdata_ciphertext_limit_is_accepted,
|
||||
TLS_RECORD_APPLICATION,
|
||||
TLS_VERSION,
|
||||
MAX_TLS_CIPHERTEXT_SIZE as u16
|
||||
);
|
||||
expect_reject!(
|
||||
appdata_ciphertext_plus_one_must_be_rejected,
|
||||
TLS_RECORD_APPLICATION,
|
||||
TLS_VERSION,
|
||||
(MAX_TLS_CIPHERTEXT_SIZE as u16) + 1
|
||||
);
|
||||
|
||||
expect_reject!(appdata_tls10_header_len_one_must_be_rejected, TLS_RECORD_APPLICATION, [0x03, 0x01], 1);
|
||||
expect_reject!(appdata_tls10_header_medium_must_be_rejected, TLS_RECORD_APPLICATION, [0x03, 0x01], 1024);
|
||||
expect_reject!(appdata_tls10_header_ciphertext_limit_must_be_rejected, TLS_RECORD_APPLICATION, [0x03, 0x01], MAX_TLS_CIPHERTEXT_SIZE as u16);
|
||||
expect_reject!(
|
||||
appdata_tls10_header_len_one_must_be_rejected,
|
||||
TLS_RECORD_APPLICATION,
|
||||
[0x03, 0x01],
|
||||
1
|
||||
);
|
||||
expect_reject!(
|
||||
appdata_tls10_header_medium_must_be_rejected,
|
||||
TLS_RECORD_APPLICATION,
|
||||
[0x03, 0x01],
|
||||
1024
|
||||
);
|
||||
expect_reject!(
|
||||
appdata_tls10_header_ciphertext_limit_must_be_rejected,
|
||||
TLS_RECORD_APPLICATION,
|
||||
[0x03, 0x01],
|
||||
MAX_TLS_CIPHERTEXT_SIZE as u16
|
||||
);
|
||||
|
||||
expect_reject!(ccs_tls10_header_len_one_must_be_rejected, TLS_RECORD_CHANGE_CIPHER, [0x03, 0x01], 1);
|
||||
expect_reject!(ccs_tls10_header_len_zero_must_be_rejected, TLS_RECORD_CHANGE_CIPHER, [0x03, 0x01], 0);
|
||||
expect_reject!(ccs_tls10_header_len_two_must_be_rejected, TLS_RECORD_CHANGE_CIPHER, [0x03, 0x01], 2);
|
||||
expect_reject!(
|
||||
ccs_tls10_header_len_one_must_be_rejected,
|
||||
TLS_RECORD_CHANGE_CIPHER,
|
||||
[0x03, 0x01],
|
||||
1
|
||||
);
|
||||
expect_reject!(
|
||||
ccs_tls10_header_len_zero_must_be_rejected,
|
||||
TLS_RECORD_CHANGE_CIPHER,
|
||||
[0x03, 0x01],
|
||||
0
|
||||
);
|
||||
expect_reject!(
|
||||
ccs_tls10_header_len_two_must_be_rejected,
|
||||
TLS_RECORD_CHANGE_CIPHER,
|
||||
[0x03, 0x01],
|
||||
2
|
||||
);
|
||||
|
||||
expect_reject!(alert_tls10_header_len_two_must_be_rejected, TLS_RECORD_ALERT, [0x03, 0x01], 2);
|
||||
expect_reject!(alert_tls10_header_len_one_must_be_rejected, TLS_RECORD_ALERT, [0x03, 0x01], 1);
|
||||
expect_reject!(alert_tls10_header_len_three_must_be_rejected, TLS_RECORD_ALERT, [0x03, 0x01], 3);
|
||||
expect_reject!(
|
||||
alert_tls10_header_len_two_must_be_rejected,
|
||||
TLS_RECORD_ALERT,
|
||||
[0x03, 0x01],
|
||||
2
|
||||
);
|
||||
expect_reject!(
|
||||
alert_tls10_header_len_one_must_be_rejected,
|
||||
TLS_RECORD_ALERT,
|
||||
[0x03, 0x01],
|
||||
1
|
||||
);
|
||||
expect_reject!(
|
||||
alert_tls10_header_len_three_must_be_rejected,
|
||||
TLS_RECORD_ALERT,
|
||||
[0x03, 0x01],
|
||||
3
|
||||
);
|
||||
|
||||
expect_accept!(handshake_tls10_header_min_len_is_accepted, TLS_RECORD_HANDSHAKE, [0x03, 0x01], 4);
|
||||
expect_accept!(handshake_tls10_header_plaintext_limit_is_accepted, TLS_RECORD_HANDSHAKE, [0x03, 0x01], MAX_TLS_PLAINTEXT_SIZE as u16);
|
||||
expect_reject!(handshake_tls10_header_too_small_must_be_rejected, TLS_RECORD_HANDSHAKE, [0x03, 0x01], 3);
|
||||
expect_reject!(handshake_tls10_header_too_large_must_be_rejected, TLS_RECORD_HANDSHAKE, [0x03, 0x01], (MAX_TLS_PLAINTEXT_SIZE as u16) + 1);
|
||||
expect_accept!(
|
||||
handshake_tls10_header_min_len_is_accepted,
|
||||
TLS_RECORD_HANDSHAKE,
|
||||
[0x03, 0x01],
|
||||
4
|
||||
);
|
||||
expect_accept!(
|
||||
handshake_tls10_header_plaintext_limit_is_accepted,
|
||||
TLS_RECORD_HANDSHAKE,
|
||||
[0x03, 0x01],
|
||||
MAX_TLS_PLAINTEXT_SIZE as u16
|
||||
);
|
||||
expect_reject!(
|
||||
handshake_tls10_header_too_small_must_be_rejected,
|
||||
TLS_RECORD_HANDSHAKE,
|
||||
[0x03, 0x01],
|
||||
3
|
||||
);
|
||||
expect_reject!(
|
||||
handshake_tls10_header_too_large_must_be_rejected,
|
||||
TLS_RECORD_HANDSHAKE,
|
||||
[0x03, 0x01],
|
||||
(MAX_TLS_PLAINTEXT_SIZE as u16) + 1
|
||||
);
|
||||
|
||||
expect_reject!(unknown_type_tls13_zero_must_be_rejected, 0x00, TLS_VERSION, 0);
|
||||
expect_reject!(unknown_type_tls13_small_must_be_rejected, 0x13, TLS_VERSION, 32);
|
||||
expect_reject!(unknown_type_tls13_large_must_be_rejected, 0xfe, TLS_VERSION, MAX_TLS_CIPHERTEXT_SIZE as u16);
|
||||
expect_reject!(unknown_type_tls10_small_must_be_rejected, 0x13, [0x03, 0x01], 32);
|
||||
expect_reject!(
|
||||
unknown_type_tls13_zero_must_be_rejected,
|
||||
0x00,
|
||||
TLS_VERSION,
|
||||
0
|
||||
);
|
||||
expect_reject!(
|
||||
unknown_type_tls13_small_must_be_rejected,
|
||||
0x13,
|
||||
TLS_VERSION,
|
||||
32
|
||||
);
|
||||
expect_reject!(
|
||||
unknown_type_tls13_large_must_be_rejected,
|
||||
0xfe,
|
||||
TLS_VERSION,
|
||||
MAX_TLS_CIPHERTEXT_SIZE as u16
|
||||
);
|
||||
expect_reject!(
|
||||
unknown_type_tls10_small_must_be_rejected,
|
||||
0x13,
|
||||
[0x03, 0x01],
|
||||
32
|
||||
);
|
||||
|
||||
expect_reject!(appdata_invalid_version_0302_must_be_rejected, TLS_RECORD_APPLICATION, [0x03, 0x02], 128);
|
||||
expect_reject!(handshake_invalid_version_0302_must_be_rejected, TLS_RECORD_HANDSHAKE, [0x03, 0x02], 128);
|
||||
expect_reject!(alert_invalid_version_0302_must_be_rejected, TLS_RECORD_ALERT, [0x03, 0x02], 2);
|
||||
expect_reject!(ccs_invalid_version_0302_must_be_rejected, TLS_RECORD_CHANGE_CIPHER, [0x03, 0x02], 1);
|
||||
expect_reject!(
|
||||
appdata_invalid_version_0302_must_be_rejected,
|
||||
TLS_RECORD_APPLICATION,
|
||||
[0x03, 0x02],
|
||||
128
|
||||
);
|
||||
expect_reject!(
|
||||
handshake_invalid_version_0302_must_be_rejected,
|
||||
TLS_RECORD_HANDSHAKE,
|
||||
[0x03, 0x02],
|
||||
128
|
||||
);
|
||||
expect_reject!(
|
||||
alert_invalid_version_0302_must_be_rejected,
|
||||
TLS_RECORD_ALERT,
|
||||
[0x03, 0x02],
|
||||
2
|
||||
);
|
||||
expect_reject!(
|
||||
ccs_invalid_version_0302_must_be_rejected,
|
||||
TLS_RECORD_CHANGE_CIPHER,
|
||||
[0x03, 0x02],
|
||||
1
|
||||
);
|
||||
|
||||
expect_reject!(appdata_invalid_version_0304_must_be_rejected, TLS_RECORD_APPLICATION, [0x03, 0x04], 128);
|
||||
expect_reject!(handshake_invalid_version_0304_must_be_rejected, TLS_RECORD_HANDSHAKE, [0x03, 0x04], 128);
|
||||
expect_reject!(alert_invalid_version_0304_must_be_rejected, TLS_RECORD_ALERT, [0x03, 0x04], 2);
|
||||
expect_reject!(ccs_invalid_version_0304_must_be_rejected, TLS_RECORD_CHANGE_CIPHER, [0x03, 0x04], 1);
|
||||
expect_reject!(
|
||||
appdata_invalid_version_0304_must_be_rejected,
|
||||
TLS_RECORD_APPLICATION,
|
||||
[0x03, 0x04],
|
||||
128
|
||||
);
|
||||
expect_reject!(
|
||||
handshake_invalid_version_0304_must_be_rejected,
|
||||
TLS_RECORD_HANDSHAKE,
|
||||
[0x03, 0x04],
|
||||
128
|
||||
);
|
||||
expect_reject!(
|
||||
alert_invalid_version_0304_must_be_rejected,
|
||||
TLS_RECORD_ALERT,
|
||||
[0x03, 0x04],
|
||||
2
|
||||
);
|
||||
expect_reject!(
|
||||
ccs_invalid_version_0304_must_be_rejected,
|
||||
TLS_RECORD_CHANGE_CIPHER,
|
||||
[0x03, 0x04],
|
||||
1
|
||||
);
|
||||
|
||||
expect_accept!(handshake_tls13_len_5_is_accepted, TLS_RECORD_HANDSHAKE, TLS_VERSION, 5);
|
||||
expect_accept!(appdata_tls13_len_16385_is_accepted, TLS_RECORD_APPLICATION, TLS_VERSION, (MAX_TLS_PLAINTEXT_SIZE as u16) + 1);
|
||||
expect_accept!(
|
||||
handshake_tls13_len_5_is_accepted,
|
||||
TLS_RECORD_HANDSHAKE,
|
||||
TLS_VERSION,
|
||||
5
|
||||
);
|
||||
expect_accept!(
|
||||
appdata_tls13_len_16385_is_accepted,
|
||||
TLS_RECORD_APPLICATION,
|
||||
TLS_VERSION,
|
||||
(MAX_TLS_PLAINTEXT_SIZE as u16) + 1
|
||||
);
|
||||
|
||||
#[test]
|
||||
fn matrix_version_policy_is_strict_and_deterministic() {
|
||||
let versions = [[0x03, 0x01], TLS_VERSION, [0x03, 0x02], [0x03, 0x04], [0x00, 0x00]];
|
||||
let versions = [
|
||||
[0x03, 0x01],
|
||||
TLS_VERSION,
|
||||
[0x03, 0x02],
|
||||
[0x03, 0x04],
|
||||
[0x00, 0x00],
|
||||
];
|
||||
let record_types = [
|
||||
TLS_RECORD_APPLICATION,
|
||||
TLS_RECORD_CHANGE_CIPHER,
|
||||
@@ -389,22 +565,62 @@ fn matrix_version_policy_is_strict_and_deterministic() {
|
||||
|
||||
#[test]
|
||||
fn appdata_partition_property_holds_for_all_u16_edges() {
|
||||
for len in [0u16, 1, 2, 3, 64, 255, 1024, 4096, 8192, 16_384, 16_385, 16_640, 16_641, u16::MAX] {
|
||||
for len in [
|
||||
0u16,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
64,
|
||||
255,
|
||||
1024,
|
||||
4096,
|
||||
8192,
|
||||
16_384,
|
||||
16_385,
|
||||
16_640,
|
||||
16_641,
|
||||
u16::MAX,
|
||||
] {
|
||||
let accepted = validates(TLS_RECORD_APPLICATION, TLS_VERSION, len);
|
||||
let expected = len >= 1 && usize::from(len) <= MAX_TLS_CIPHERTEXT_SIZE;
|
||||
assert_eq!(accepted, expected, "unexpected appdata decision for len={len}");
|
||||
assert_eq!(
|
||||
accepted, expected,
|
||||
"unexpected appdata decision for len={len}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handshake_partition_property_holds_for_all_u16_edges() {
|
||||
for len in [0u16, 1, 2, 3, 4, 5, 64, 255, 1024, 4096, 8192, 16_383, 16_384, 16_385, u16::MAX] {
|
||||
for len in [
|
||||
0u16,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
64,
|
||||
255,
|
||||
1024,
|
||||
4096,
|
||||
8192,
|
||||
16_383,
|
||||
16_384,
|
||||
16_385,
|
||||
u16::MAX,
|
||||
] {
|
||||
let accepted_tls13 = validates(TLS_RECORD_HANDSHAKE, TLS_VERSION, len);
|
||||
let accepted_tls10 = validates(TLS_RECORD_HANDSHAKE, [0x03, 0x01], len);
|
||||
let expected = (4..=MAX_TLS_PLAINTEXT_SIZE).contains(&usize::from(len));
|
||||
|
||||
assert_eq!(accepted_tls13, expected, "TLS1.3 handshake mismatch for len={len}");
|
||||
assert_eq!(accepted_tls10, expected, "TLS1.0 compat handshake mismatch for len={len}");
|
||||
assert_eq!(
|
||||
accepted_tls13, expected,
|
||||
"TLS1.3 handshake mismatch for len={len}"
|
||||
);
|
||||
assert_eq!(
|
||||
accepted_tls10, expected,
|
||||
"TLS1.0 compat handshake mismatch for len={len}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -419,14 +635,24 @@ fn control_record_exact_lengths_are_enforced_under_fuzzed_lengths() {
|
||||
let alert_ok = validates(TLS_RECORD_ALERT, TLS_VERSION, len);
|
||||
|
||||
assert_eq!(ccs_ok, len == 1, "ccs length gate mismatch for len={len}");
|
||||
assert_eq!(alert_ok, len == 2, "alert length gate mismatch for len={len}");
|
||||
assert_eq!(
|
||||
alert_ok,
|
||||
len == 2,
|
||||
"alert length gate mismatch for len={len}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_record_types_never_validate_under_supported_versions() {
|
||||
for record_type in 0u8..=255 {
|
||||
if matches!(record_type, TLS_RECORD_APPLICATION | TLS_RECORD_CHANGE_CIPHER | TLS_RECORD_ALERT | TLS_RECORD_HANDSHAKE) {
|
||||
if matches!(
|
||||
record_type,
|
||||
TLS_RECORD_APPLICATION
|
||||
| TLS_RECORD_CHANGE_CIPHER
|
||||
| TLS_RECORD_ALERT
|
||||
| TLS_RECORD_HANDSHAKE
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -458,9 +684,15 @@ async fn reader_rejects_tls10_appdata_header_before_payload_processing() {
|
||||
#[tokio::test]
|
||||
async fn reader_rejects_zero_len_appdata_record() {
|
||||
let (mut tx, rx) = tokio::io::duplex(128);
|
||||
tx.write_all(&[TLS_RECORD_APPLICATION, TLS_VERSION[0], TLS_VERSION[1], 0x00, 0x00])
|
||||
.await
|
||||
.unwrap();
|
||||
tx.write_all(&[
|
||||
TLS_RECORD_APPLICATION,
|
||||
TLS_VERSION[0],
|
||||
TLS_VERSION[1],
|
||||
0x00,
|
||||
0x00,
|
||||
])
|
||||
.await
|
||||
.unwrap();
|
||||
tx.shutdown().await.unwrap();
|
||||
|
||||
let mut reader = FakeTlsReader::new(rx);
|
||||
@@ -472,9 +704,16 @@ async fn reader_rejects_zero_len_appdata_record() {
|
||||
#[tokio::test]
|
||||
async fn reader_accepts_single_byte_tls13_appdata_and_yields_payload() {
|
||||
let (mut tx, rx) = tokio::io::duplex(128);
|
||||
tx.write_all(&[TLS_RECORD_APPLICATION, TLS_VERSION[0], TLS_VERSION[1], 0x00, 0x01, 0x5A])
|
||||
.await
|
||||
.unwrap();
|
||||
tx.write_all(&[
|
||||
TLS_RECORD_APPLICATION,
|
||||
TLS_VERSION[0],
|
||||
TLS_VERSION[1],
|
||||
0x00,
|
||||
0x01,
|
||||
0x5A,
|
||||
])
|
||||
.await
|
||||
.unwrap();
|
||||
tx.shutdown().await.unwrap();
|
||||
|
||||
let mut reader = FakeTlsReader::new(rx);
|
||||
|
||||
@@ -23,12 +23,12 @@ impl FrameMeta {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
|
||||
pub fn with_quickack(mut self) -> Self {
|
||||
self.quickack = true;
|
||||
self
|
||||
}
|
||||
|
||||
|
||||
pub fn with_simple_ack(mut self) -> Self {
|
||||
self.simple_ack = true;
|
||||
self
|
||||
@@ -48,10 +48,10 @@ pub enum ReadFrameResult {
|
||||
pub trait LayeredStream<U> {
|
||||
/// Get reference to upstream
|
||||
fn upstream(&self) -> &U;
|
||||
|
||||
|
||||
/// Get mutable reference to upstream
|
||||
fn upstream_mut(&mut self) -> &mut U;
|
||||
|
||||
|
||||
/// Consume self and return upstream
|
||||
fn into_upstream(self) -> U;
|
||||
}
|
||||
@@ -65,7 +65,7 @@ impl<R> ReadHalf<R> {
|
||||
pub fn new(inner: R) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
|
||||
pub fn into_inner(self) -> R {
|
||||
self.inner
|
||||
}
|
||||
@@ -90,7 +90,7 @@ impl<W> WriteHalf<W> {
|
||||
pub fn new(inner: W) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
|
||||
pub fn into_inner(self) -> W {
|
||||
self.inner
|
||||
}
|
||||
@@ -104,12 +104,12 @@ impl<W: AsyncWrite + Unpin> AsyncWrite for WriteHalf<W> {
|
||||
) -> Poll<Result<usize>> {
|
||||
Pin::new(&mut self.inner).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_flush(cx)
|
||||
}
|
||||
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user