mirror of
https://github.com/telemt/telemt.git
synced 2026-04-18 19:14:09 +03:00
1.0.0
Tschuss Status Quo - Hallo, Zukunft!
This commit is contained in:
474
src/stream/crypto_stream.rs
Normal file
474
src/stream/crypto_stream.rs
Normal file
@@ -0,0 +1,474 @@
|
||||
//! Encrypted stream wrappers using AES-CTR
|
||||
|
||||
use bytes::{Bytes, BytesMut, BufMut};
|
||||
use std::io::{Error, ErrorKind, Result};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf};
|
||||
use crate::crypto::AesCtr;
|
||||
use parking_lot::Mutex;
|
||||
|
||||
/// Reader that decrypts data using AES-CTR
|
||||
pub struct CryptoReader<R> {
|
||||
upstream: R,
|
||||
decryptor: AesCtr,
|
||||
buffer: BytesMut,
|
||||
}
|
||||
|
||||
impl<R> CryptoReader<R> {
|
||||
/// Create new crypto reader
|
||||
pub fn new(upstream: R, decryptor: AesCtr) -> Self {
|
||||
Self {
|
||||
upstream,
|
||||
decryptor,
|
||||
buffer: BytesMut::with_capacity(8192),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get reference to upstream
|
||||
pub fn get_ref(&self) -> &R {
|
||||
&self.upstream
|
||||
}
|
||||
|
||||
/// Get mutable reference to upstream
|
||||
pub fn get_mut(&mut self) -> &mut R {
|
||||
&mut self.upstream
|
||||
}
|
||||
|
||||
/// Consume and return upstream
|
||||
pub fn into_inner(self) -> R {
|
||||
self.upstream
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> AsyncRead for CryptoReader<R> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<Result<()>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
if !this.buffer.is_empty() {
|
||||
let to_copy = this.buffer.len().min(buf.remaining());
|
||||
buf.put_slice(&this.buffer.split_to(to_copy));
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
// Zero-copy Reader
|
||||
let before = buf.filled().len();
|
||||
|
||||
match Pin::new(&mut this.upstream).poll_read(cx, buf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
let after = buf.filled().len();
|
||||
let bytes_read = after - before;
|
||||
|
||||
if bytes_read > 0 {
|
||||
// Decrypt in-place
|
||||
let filled = buf.filled_mut();
|
||||
this.decryptor.apply(&mut filled[before..after]);
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> CryptoReader<R> {
|
||||
/// Read and decrypt exactly n bytes with Async
|
||||
pub async fn read_exact_decrypt(&mut self, n: usize) -> Result<Bytes> {
|
||||
let mut result = BytesMut::with_capacity(n);
|
||||
|
||||
if !self.buffer.is_empty() {
|
||||
let to_take = self.buffer.len().min(n);
|
||||
result.extend_from_slice(&self.buffer.split_to(to_take));
|
||||
}
|
||||
|
||||
// Reread
|
||||
while result.len() < n {
|
||||
let mut temp = vec![0u8; n - result.len()];
|
||||
let read = self.upstream.read(&mut temp).await?;
|
||||
|
||||
if read == 0 {
|
||||
return Err(Error::new(ErrorKind::UnexpectedEof, "Connection closed"));
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
self.decryptor.apply(&mut temp[..read]);
|
||||
result.extend_from_slice(&temp[..read]);
|
||||
}
|
||||
|
||||
Ok(result.freeze())
|
||||
}
|
||||
}
|
||||
|
||||
/// Writer that encrypts data using AES-CTR
|
||||
pub struct CryptoWriter<W> {
|
||||
upstream: W,
|
||||
encryptor: AesCtr,
|
||||
pending: BytesMut,
|
||||
}
|
||||
|
||||
impl<W> CryptoWriter<W> {
|
||||
pub fn new(upstream: W, encryptor: AesCtr) -> Self {
|
||||
Self {
|
||||
upstream,
|
||||
encryptor,
|
||||
pending: BytesMut::with_capacity(8192),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_ref(&self) -> &W {
|
||||
&self.upstream
|
||||
}
|
||||
|
||||
pub fn get_mut(&mut self) -> &mut W {
|
||||
&mut self.upstream
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> W {
|
||||
self.upstream
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
if !this.pending.is_empty() {
|
||||
match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) {
|
||||
Poll::Ready(Ok(written)) => {
|
||||
let _ = this.pending.split_to(written);
|
||||
|
||||
if !this.pending.is_empty() {
|
||||
cx.waker().wake_by_ref();
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
// Pending Null
|
||||
if buf.is_empty() {
|
||||
return Poll::Ready(Ok(0));
|
||||
}
|
||||
|
||||
// Encrypt
|
||||
let mut encrypted = buf.to_vec();
|
||||
this.encryptor.apply(&mut encrypted);
|
||||
|
||||
// Write Try
|
||||
match Pin::new(&mut this.upstream).poll_write(cx, &encrypted) {
|
||||
Poll::Ready(Ok(written)) => {
|
||||
if written < encrypted.len() {
|
||||
// Partial write — сохраняем остаток в pending
|
||||
this.pending.extend_from_slice(&encrypted[written..]);
|
||||
}
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => {
|
||||
this.pending.extend_from_slice(&encrypted);
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
while !this.pending.is_empty() {
|
||||
match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) {
|
||||
Poll::Ready(Ok(0)) => {
|
||||
return Poll::Ready(Err(Error::new(
|
||||
ErrorKind::WriteZero,
|
||||
"Failed to write pending data during flush",
|
||||
)));
|
||||
}
|
||||
Poll::Ready(Ok(written)) => {
|
||||
let _ = this.pending.split_to(written);
|
||||
}
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
Pin::new(&mut this.upstream).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
while !this.pending.is_empty() {
|
||||
match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) {
|
||||
Poll::Ready(Ok(0)) => {
|
||||
break;
|
||||
}
|
||||
Poll::Ready(Ok(written)) => {
|
||||
let _ = this.pending.split_to(written);
|
||||
}
|
||||
Poll::Ready(Err(_)) => {
|
||||
break;
|
||||
}
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
Pin::new(&mut this.upstream).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
/// Passthrough stream for fast mode - no encryption/decryption
|
||||
pub struct PassthroughStream<S> {
|
||||
inner: S,
|
||||
}
|
||||
|
||||
impl<S> PassthroughStream<S> {
|
||||
pub fn new(inner: S) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> S {
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> AsyncRead for PassthroughStream<S> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: AsyncWrite + Unpin> AsyncWrite for PassthroughStream<S> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
Pin::new(&mut self.inner).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::VecDeque;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll, Waker, RawWaker, RawWakerVTable};
|
||||
use tokio::io::duplex;
|
||||
|
||||
/// Mock writer
|
||||
struct PartialWriter {
|
||||
chunk_size: usize,
|
||||
data: Vec<u8>,
|
||||
write_count: usize,
|
||||
}
|
||||
|
||||
impl PartialWriter {
|
||||
fn new(chunk_size: usize) -> Self {
|
||||
Self {
|
||||
chunk_size,
|
||||
data: Vec::new(),
|
||||
write_count: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for PartialWriter {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
self.write_count += 1;
|
||||
let to_write = buf.len().min(self.chunk_size);
|
||||
self.data.extend_from_slice(&buf[..to_write]);
|
||||
Poll::Ready(Ok(to_write))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
fn noop_waker() -> Waker {
|
||||
const VTABLE: RawWakerVTable = RawWakerVTable::new(
|
||||
|_| RawWaker::new(std::ptr::null(), &VTABLE),
|
||||
|_| {},
|
||||
|_| {},
|
||||
|_| {},
|
||||
);
|
||||
unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_crypto_writer_partial_write_correctness() {
|
||||
let key = [0x42u8; 32];
|
||||
let iv = 12345u128;
|
||||
|
||||
// 10-byte Writer
|
||||
let mock_writer = PartialWriter::new(10);
|
||||
let encryptor = AesCtr::new(&key, iv);
|
||||
let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor);
|
||||
|
||||
let waker = noop_waker();
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
// 25 byte
|
||||
let original = b"Hello, this is test data!";
|
||||
|
||||
// First Write
|
||||
let result = Pin::new(&mut crypto_writer).poll_write(&mut cx, original);
|
||||
assert!(matches!(result, Poll::Ready(Ok(25))));
|
||||
|
||||
// Flush before continue Pending
|
||||
loop {
|
||||
match Pin::new(&mut crypto_writer).poll_flush(&mut cx) {
|
||||
Poll::Ready(Ok(())) => break,
|
||||
Poll::Ready(Err(e)) => panic!("Flush error: {}", e),
|
||||
Poll::Pending => continue,
|
||||
}
|
||||
}
|
||||
|
||||
// Write Check
|
||||
let encrypted = &crypto_writer.upstream.data;
|
||||
assert_eq!(encrypted.len(), 25);
|
||||
|
||||
// Decrypt + Verify
|
||||
let mut decryptor = AesCtr::new(&key, iv);
|
||||
let mut decrypted = encrypted.clone();
|
||||
decryptor.apply(&mut decrypted);
|
||||
|
||||
assert_eq!(&decrypted, original);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_crypto_writer_multiple_partial_writes() {
|
||||
let key = [0xAB; 32];
|
||||
let iv = 9999u128;
|
||||
|
||||
let mock_writer = PartialWriter::new(3);
|
||||
let encryptor = AesCtr::new(&key, iv);
|
||||
let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor);
|
||||
|
||||
let waker = noop_waker();
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let data1 = b"First";
|
||||
let data2 = b"Second";
|
||||
let data3 = b"Third";
|
||||
|
||||
Pin::new(&mut crypto_writer).poll_write(&mut cx, data1).unwrap();
|
||||
// Flush
|
||||
while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {}
|
||||
|
||||
Pin::new(&mut crypto_writer).poll_write(&mut cx, data2).unwrap();
|
||||
while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {}
|
||||
|
||||
Pin::new(&mut crypto_writer).poll_write(&mut cx, data3).unwrap();
|
||||
while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {}
|
||||
|
||||
// Assemble
|
||||
let mut expected = Vec::new();
|
||||
expected.extend_from_slice(data1);
|
||||
expected.extend_from_slice(data2);
|
||||
expected.extend_from_slice(data3);
|
||||
|
||||
// Decrypt
|
||||
let mut decryptor = AesCtr::new(&key, iv);
|
||||
let mut decrypted = crypto_writer.upstream.data.clone();
|
||||
decryptor.apply(&mut decrypted);
|
||||
|
||||
assert_eq!(decrypted, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_crypto_stream_roundtrip() {
|
||||
let key = [0u8; 32];
|
||||
let iv = 12345u128;
|
||||
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
let encryptor = AesCtr::new(&key, iv);
|
||||
let decryptor = AesCtr::new(&key, iv);
|
||||
|
||||
let mut writer = CryptoWriter::new(client, encryptor);
|
||||
let mut reader = CryptoReader::new(server, decryptor);
|
||||
|
||||
// Write
|
||||
let original = b"Hello, encrypted world!";
|
||||
writer.write_all(original).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
// Read
|
||||
let mut buf = vec![0u8; original.len()];
|
||||
reader.read_exact(&mut buf).await.unwrap();
|
||||
|
||||
assert_eq!(&buf, original);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_crypto_stream_large_data() {
|
||||
let key = [0x55u8; 32];
|
||||
let iv = 777u128;
|
||||
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let encryptor = AesCtr::new(&key, iv);
|
||||
let decryptor = AesCtr::new(&key, iv);
|
||||
|
||||
let mut writer = CryptoWriter::new(client, encryptor);
|
||||
let mut reader = CryptoReader::new(server, decryptor);
|
||||
|
||||
// Hugeload
|
||||
let original: Vec<u8> = (0..10000).map(|i| (i % 256) as u8).collect();
|
||||
|
||||
// Write
|
||||
let write_data = original.clone();
|
||||
let write_handle = tokio::spawn(async move {
|
||||
writer.write_all(&write_data).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
writer.shutdown().await.unwrap();
|
||||
});
|
||||
|
||||
// Read
|
||||
let mut received = Vec::new();
|
||||
let mut buf = vec![0u8; 1024];
|
||||
loop {
|
||||
match reader.read(&mut buf).await {
|
||||
Ok(0) => break,
|
||||
Ok(n) => received.extend_from_slice(&buf[..n]),
|
||||
Err(e) => panic!("Read error: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
write_handle.await.unwrap();
|
||||
|
||||
assert_eq!(received, original);
|
||||
}
|
||||
}
|
||||
585
src/stream/frame_stream.rs
Normal file
585
src/stream/frame_stream.rs
Normal file
@@ -0,0 +1,585 @@
|
||||
//! MTProto frame stream wrappers
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use std::io::{Error, ErrorKind, Result};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
|
||||
use crate::protocol::constants::*;
|
||||
use crate::crypto::crc32;
|
||||
use crate::crypto::random::SECURE_RANDOM;
|
||||
use super::traits::{FrameMeta, LayeredStream};
|
||||
|
||||
// ============= Abridged (Compact) Frame =============
|
||||
|
||||
/// Reader for abridged MTProto framing
|
||||
pub struct AbridgedFrameReader<R> {
|
||||
upstream: R,
|
||||
}
|
||||
|
||||
impl<R> AbridgedFrameReader<R> {
|
||||
pub fn new(upstream: R) -> Self {
|
||||
Self { upstream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> AbridgedFrameReader<R> {
|
||||
/// Read a frame and return (data, metadata)
|
||||
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
|
||||
let mut meta = FrameMeta::new();
|
||||
|
||||
// Read length byte
|
||||
let mut len_byte = [0u8];
|
||||
self.upstream.read_exact(&mut len_byte).await?;
|
||||
|
||||
let mut len = len_byte[0] as usize;
|
||||
|
||||
// Check QuickACK flag (high bit)
|
||||
if len >= 0x80 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80;
|
||||
}
|
||||
|
||||
// Extended length (3 bytes)
|
||||
if len == 0x7f {
|
||||
let mut len_bytes = [0u8; 3];
|
||||
self.upstream.read_exact(&mut len_bytes).await?;
|
||||
len = u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], 0]) as usize;
|
||||
}
|
||||
|
||||
// Length is in 4-byte words
|
||||
let byte_len = len * 4;
|
||||
|
||||
// Read data
|
||||
let mut data = vec![0u8; byte_len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
Ok((Bytes::from(data), meta))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> LayeredStream<R> for AbridgedFrameReader<R> {
|
||||
fn upstream(&self) -> &R { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
|
||||
fn into_upstream(self) -> R { self.upstream }
|
||||
}
|
||||
|
||||
/// Writer for abridged MTProto framing
|
||||
pub struct AbridgedFrameWriter<W> {
|
||||
upstream: W,
|
||||
}
|
||||
|
||||
impl<W> AbridgedFrameWriter<W> {
|
||||
pub fn new(upstream: W) -> Self {
|
||||
Self { upstream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> AbridgedFrameWriter<W> {
|
||||
/// Write a frame
|
||||
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
|
||||
if data.len() % 4 != 0 {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("Abridged frame must be aligned to 4 bytes, got {}", data.len()),
|
||||
));
|
||||
}
|
||||
|
||||
// Simple ACK: send reversed data
|
||||
if meta.simple_ack {
|
||||
let reversed: Vec<u8> = data.iter().rev().copied().collect();
|
||||
self.upstream.write_all(&reversed).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let len_div_4 = data.len() / 4;
|
||||
|
||||
if len_div_4 < 0x7f {
|
||||
// Short length (1 byte)
|
||||
self.upstream.write_all(&[len_div_4 as u8]).await?;
|
||||
} else if len_div_4 < (1 << 24) {
|
||||
// Long length (4 bytes: 0x7f + 3 bytes)
|
||||
let mut header = [0x7f, 0, 0, 0];
|
||||
header[1..4].copy_from_slice(&(len_div_4 as u32).to_le_bytes()[..3]);
|
||||
self.upstream.write_all(&header).await?;
|
||||
} else {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
format!("Frame too large: {} bytes", data.len()),
|
||||
));
|
||||
}
|
||||
|
||||
self.upstream.write_all(data).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn flush(&mut self) -> Result<()> {
|
||||
self.upstream.flush().await
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> LayeredStream<W> for AbridgedFrameWriter<W> {
|
||||
fn upstream(&self) -> &W { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
|
||||
fn into_upstream(self) -> W { self.upstream }
|
||||
}
|
||||
|
||||
// ============= Intermediate Frame =============
|
||||
|
||||
/// Reader for intermediate MTProto framing
|
||||
pub struct IntermediateFrameReader<R> {
|
||||
upstream: R,
|
||||
}
|
||||
|
||||
impl<R> IntermediateFrameReader<R> {
|
||||
pub fn new(upstream: R) -> Self {
|
||||
Self { upstream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> IntermediateFrameReader<R> {
|
||||
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
|
||||
let mut meta = FrameMeta::new();
|
||||
|
||||
// Read 4-byte length
|
||||
let mut len_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut len_bytes).await?;
|
||||
|
||||
let mut len = u32::from_le_bytes(len_bytes) as usize;
|
||||
|
||||
// Check QuickACK flag (high bit)
|
||||
if len > 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
|
||||
// Read data
|
||||
let mut data = vec![0u8; len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
Ok((Bytes::from(data), meta))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> LayeredStream<R> for IntermediateFrameReader<R> {
|
||||
fn upstream(&self) -> &R { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
|
||||
fn into_upstream(self) -> R { self.upstream }
|
||||
}
|
||||
|
||||
/// Writer for intermediate MTProto framing
|
||||
pub struct IntermediateFrameWriter<W> {
|
||||
upstream: W,
|
||||
}
|
||||
|
||||
impl<W> IntermediateFrameWriter<W> {
|
||||
pub fn new(upstream: W) -> Self {
|
||||
Self { upstream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> IntermediateFrameWriter<W> {
|
||||
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
|
||||
if meta.simple_ack {
|
||||
self.upstream.write_all(data).await?;
|
||||
} else {
|
||||
let len_bytes = (data.len() as u32).to_le_bytes();
|
||||
self.upstream.write_all(&len_bytes).await?;
|
||||
self.upstream.write_all(data).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn flush(&mut self) -> Result<()> {
|
||||
self.upstream.flush().await
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> LayeredStream<W> for IntermediateFrameWriter<W> {
|
||||
fn upstream(&self) -> &W { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
|
||||
fn into_upstream(self) -> W { self.upstream }
|
||||
}
|
||||
|
||||
// ============= Secure Intermediate Frame =============
|
||||
|
||||
/// Reader for secure intermediate MTProto framing (with padding)
|
||||
pub struct SecureIntermediateFrameReader<R> {
|
||||
upstream: R,
|
||||
}
|
||||
|
||||
impl<R> SecureIntermediateFrameReader<R> {
|
||||
pub fn new(upstream: R) -> Self {
|
||||
Self { upstream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
|
||||
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
|
||||
let mut meta = FrameMeta::new();
|
||||
|
||||
// Read 4-byte length
|
||||
let mut len_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut len_bytes).await?;
|
||||
|
||||
let mut len = u32::from_le_bytes(len_bytes) as usize;
|
||||
|
||||
// Check QuickACK flag
|
||||
if len > 0x80000000 {
|
||||
meta.quickack = true;
|
||||
len -= 0x80000000;
|
||||
}
|
||||
|
||||
// Read data (including padding)
|
||||
let mut data = vec![0u8; len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
// Strip padding (not aligned to 4)
|
||||
if len % 4 != 0 {
|
||||
let actual_len = len - (len % 4);
|
||||
data.truncate(actual_len);
|
||||
}
|
||||
|
||||
Ok((Bytes::from(data), meta))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> LayeredStream<R> for SecureIntermediateFrameReader<R> {
|
||||
fn upstream(&self) -> &R { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
|
||||
fn into_upstream(self) -> R { self.upstream }
|
||||
}
|
||||
|
||||
/// Writer for secure intermediate MTProto framing
|
||||
pub struct SecureIntermediateFrameWriter<W> {
|
||||
upstream: W,
|
||||
}
|
||||
|
||||
impl<W> SecureIntermediateFrameWriter<W> {
|
||||
pub fn new(upstream: W) -> Self {
|
||||
Self { upstream }
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
|
||||
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
|
||||
if meta.simple_ack {
|
||||
self.upstream.write_all(data).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Add random padding (0-3 bytes)
|
||||
let padding_len = SECURE_RANDOM.range(4);
|
||||
let padding = SECURE_RANDOM.bytes(padding_len);
|
||||
|
||||
let total_len = data.len() + padding_len;
|
||||
let len_bytes = (total_len as u32).to_le_bytes();
|
||||
|
||||
self.upstream.write_all(&len_bytes).await?;
|
||||
self.upstream.write_all(data).await?;
|
||||
self.upstream.write_all(&padding).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn flush(&mut self) -> Result<()> {
|
||||
self.upstream.flush().await
|
||||
}
|
||||
}
|
||||
|
||||
impl<W> LayeredStream<W> for SecureIntermediateFrameWriter<W> {
|
||||
fn upstream(&self) -> &W { &self.upstream }
|
||||
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
|
||||
fn into_upstream(self) -> W { self.upstream }
|
||||
}
|
||||
|
||||
// ============= Full MTProto Frame (with CRC) =============
|
||||
|
||||
/// Reader for full MTProto framing with sequence numbers and CRC32
|
||||
pub struct MtprotoFrameReader<R> {
|
||||
upstream: R,
|
||||
seq_no: i32,
|
||||
}
|
||||
|
||||
impl<R> MtprotoFrameReader<R> {
|
||||
pub fn new(upstream: R, start_seq: i32) -> Self {
|
||||
Self { upstream, seq_no: start_seq }
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> MtprotoFrameReader<R> {
|
||||
pub async fn read_frame(&mut self) -> Result<Bytes> {
|
||||
loop {
|
||||
// Read length (4 bytes)
|
||||
let mut len_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut len_bytes).await?;
|
||||
let len = u32::from_le_bytes(len_bytes) as usize;
|
||||
|
||||
// Skip padding-only messages
|
||||
if len == 4 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Validate length
|
||||
if len < MIN_MSG_LEN || len > MAX_MSG_LEN || len % PADDING_FILLER.len() != 0 {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("Invalid message length: {}", len),
|
||||
));
|
||||
}
|
||||
|
||||
// Read sequence number
|
||||
let mut seq_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut seq_bytes).await?;
|
||||
let msg_seq = i32::from_le_bytes(seq_bytes);
|
||||
|
||||
if msg_seq != self.seq_no {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("Sequence mismatch: expected {}, got {}", self.seq_no, msg_seq),
|
||||
));
|
||||
}
|
||||
self.seq_no += 1;
|
||||
|
||||
// Read data (length - 4 len - 4 seq - 4 crc = len - 12)
|
||||
let data_len = len - 12;
|
||||
let mut data = vec![0u8; data_len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
// Read and verify CRC32
|
||||
let mut crc_bytes = [0u8; 4];
|
||||
self.upstream.read_exact(&mut crc_bytes).await?;
|
||||
let expected_crc = u32::from_le_bytes(crc_bytes);
|
||||
|
||||
// Compute CRC over len + seq + data
|
||||
let mut crc_input = Vec::with_capacity(8 + data_len);
|
||||
crc_input.extend_from_slice(&len_bytes);
|
||||
crc_input.extend_from_slice(&seq_bytes);
|
||||
crc_input.extend_from_slice(&data);
|
||||
let computed_crc = crc32(&crc_input);
|
||||
|
||||
if computed_crc != expected_crc {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("CRC mismatch: expected {:08x}, got {:08x}", expected_crc, computed_crc),
|
||||
));
|
||||
}
|
||||
|
||||
return Ok(Bytes::from(data));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Writer for full MTProto framing
|
||||
pub struct MtprotoFrameWriter<W> {
|
||||
upstream: W,
|
||||
seq_no: i32,
|
||||
}
|
||||
|
||||
impl<W> MtprotoFrameWriter<W> {
|
||||
pub fn new(upstream: W, start_seq: i32) -> Self {
|
||||
Self { upstream, seq_no: start_seq }
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> MtprotoFrameWriter<W> {
|
||||
pub async fn write_frame(&mut self, msg: &[u8]) -> Result<()> {
|
||||
// Total length: 4 (len) + 4 (seq) + data + 4 (crc)
|
||||
let len = msg.len() + 12;
|
||||
|
||||
let len_bytes = (len as u32).to_le_bytes();
|
||||
let seq_bytes = self.seq_no.to_le_bytes();
|
||||
self.seq_no += 1;
|
||||
|
||||
// Compute CRC
|
||||
let mut crc_input = Vec::with_capacity(8 + msg.len());
|
||||
crc_input.extend_from_slice(&len_bytes);
|
||||
crc_input.extend_from_slice(&seq_bytes);
|
||||
crc_input.extend_from_slice(msg);
|
||||
let checksum = crc32(&crc_input);
|
||||
let crc_bytes = checksum.to_le_bytes();
|
||||
|
||||
// Calculate padding for CBC alignment
|
||||
let total_len = len_bytes.len() + seq_bytes.len() + msg.len() + crc_bytes.len();
|
||||
let padding_needed = (CBC_PADDING - (total_len % CBC_PADDING)) % CBC_PADDING;
|
||||
let padding_count = padding_needed / PADDING_FILLER.len();
|
||||
|
||||
// Write everything
|
||||
self.upstream.write_all(&len_bytes).await?;
|
||||
self.upstream.write_all(&seq_bytes).await?;
|
||||
self.upstream.write_all(msg).await?;
|
||||
self.upstream.write_all(&crc_bytes).await?;
|
||||
|
||||
for _ in 0..padding_count {
|
||||
self.upstream.write_all(&PADDING_FILLER).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn flush(&mut self) -> Result<()> {
|
||||
self.upstream.flush().await
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Frame Type Enum =============
|
||||
|
||||
/// Enum for different frame stream types
|
||||
pub enum FrameReaderKind<R> {
|
||||
Abridged(AbridgedFrameReader<R>),
|
||||
Intermediate(IntermediateFrameReader<R>),
|
||||
SecureIntermediate(SecureIntermediateFrameReader<R>),
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> FrameReaderKind<R> {
|
||||
pub fn new(upstream: R, proto_tag: ProtoTag) -> Self {
|
||||
match proto_tag {
|
||||
ProtoTag::Abridged => FrameReaderKind::Abridged(AbridgedFrameReader::new(upstream)),
|
||||
ProtoTag::Intermediate => FrameReaderKind::Intermediate(IntermediateFrameReader::new(upstream)),
|
||||
ProtoTag::Secure => FrameReaderKind::SecureIntermediate(SecureIntermediateFrameReader::new(upstream)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
|
||||
match self {
|
||||
FrameReaderKind::Abridged(r) => r.read_frame().await,
|
||||
FrameReaderKind::Intermediate(r) => r.read_frame().await,
|
||||
FrameReaderKind::SecureIntermediate(r) => r.read_frame().await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub enum FrameWriterKind<W> {
|
||||
Abridged(AbridgedFrameWriter<W>),
|
||||
Intermediate(IntermediateFrameWriter<W>),
|
||||
SecureIntermediate(SecureIntermediateFrameWriter<W>),
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
|
||||
pub fn new(upstream: W, proto_tag: ProtoTag) -> Self {
|
||||
match proto_tag {
|
||||
ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)),
|
||||
ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)),
|
||||
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
|
||||
match self {
|
||||
FrameWriterKind::Abridged(w) => w.write_frame(data, meta).await,
|
||||
FrameWriterKind::Intermediate(w) => w.write_frame(data, meta).await,
|
||||
FrameWriterKind::SecureIntermediate(w) => w.write_frame(data, meta).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn flush(&mut self) -> Result<()> {
|
||||
match self {
|
||||
FrameWriterKind::Abridged(w) => w.flush().await,
|
||||
FrameWriterKind::Intermediate(w) => w.flush().await,
|
||||
FrameWriterKind::SecureIntermediate(w) => w.flush().await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::duplex;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_abridged_roundtrip() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let mut writer = AbridgedFrameWriter::new(client);
|
||||
let mut reader = AbridgedFrameReader::new(server);
|
||||
|
||||
// Short frame
|
||||
let data = vec![1u8, 2, 3, 4]; // 4 bytes = 1 word
|
||||
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let (received, _meta) = reader.read_frame().await.unwrap();
|
||||
assert_eq!(&received[..], &data[..]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_abridged_long_frame() {
|
||||
let (client, server) = duplex(65536);
|
||||
|
||||
let mut writer = AbridgedFrameWriter::new(client);
|
||||
let mut reader = AbridgedFrameReader::new(server);
|
||||
|
||||
// Long frame (> 0x7f words = 508 bytes)
|
||||
let data: Vec<u8> = (0..1000).map(|i| (i % 256) as u8).collect();
|
||||
let padded_len = (data.len() + 3) / 4 * 4;
|
||||
let mut padded = data.clone();
|
||||
padded.resize(padded_len, 0);
|
||||
|
||||
writer.write_frame(&padded, &FrameMeta::new()).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let (received, _meta) = reader.read_frame().await.unwrap();
|
||||
assert_eq!(&received[..], &padded[..]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_intermediate_roundtrip() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let mut writer = IntermediateFrameWriter::new(client);
|
||||
let mut reader = IntermediateFrameReader::new(server);
|
||||
|
||||
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
|
||||
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let (received, _meta) = reader.read_frame().await.unwrap();
|
||||
assert_eq!(&received[..], &data[..]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_secure_intermediate_padding() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let mut writer = SecureIntermediateFrameWriter::new(client);
|
||||
let mut reader = SecureIntermediateFrameReader::new(server);
|
||||
|
||||
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
|
||||
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let (received, _meta) = reader.read_frame().await.unwrap();
|
||||
// Received should have padding stripped to align to 4
|
||||
let expected_len = (data.len() / 4) * 4;
|
||||
assert_eq!(received.len(), expected_len);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mtproto_frame_roundtrip() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let mut writer = MtprotoFrameWriter::new(client, 0);
|
||||
let mut reader = MtprotoFrameReader::new(server, 0);
|
||||
|
||||
// Message must be padded properly
|
||||
let data = vec![0u8; 16]; // Aligned to 4 and CBC_PADDING
|
||||
writer.write_frame(&data).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let received = reader.read_frame().await.unwrap();
|
||||
assert_eq!(&received[..], &data[..]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_frame_reader_kind() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate);
|
||||
let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate);
|
||||
|
||||
let data = vec![1u8, 2, 3, 4];
|
||||
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let (received, _) = reader.read_frame().await.unwrap();
|
||||
assert_eq!(&received[..], &data[..]);
|
||||
}
|
||||
}
|
||||
10
src/stream/mod.rs
Normal file
10
src/stream/mod.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
//! Stream wrappers for MTProto protocol layers
|
||||
|
||||
pub mod traits;
|
||||
pub mod crypto_stream;
|
||||
pub mod tls_stream;
|
||||
pub mod frame_stream;
|
||||
|
||||
pub use crypto_stream::{CryptoReader, CryptoWriter, PassthroughStream};
|
||||
pub use tls_stream::{FakeTlsReader, FakeTlsWriter};
|
||||
pub use frame_stream::*;
|
||||
277
src/stream/tls_stream.rs
Normal file
277
src/stream/tls_stream.rs
Normal file
@@ -0,0 +1,277 @@
|
||||
//! Fake TLS 1.3 stream wrappers
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use std::io::{Error, ErrorKind, Result};
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf};
|
||||
use crate::protocol::constants::{
|
||||
TLS_VERSION, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER,
|
||||
MAX_TLS_CHUNK_SIZE,
|
||||
};
|
||||
use parking_lot::Mutex;
|
||||
|
||||
/// Reader that unwraps TLS 1.3 records
|
||||
pub struct FakeTlsReader<R> {
|
||||
upstream: R,
|
||||
buffer: BytesMut,
|
||||
pending_read: Option<PendingTlsRead>,
|
||||
}
|
||||
|
||||
struct PendingTlsRead {
|
||||
record_type: u8,
|
||||
remaining: usize,
|
||||
}
|
||||
|
||||
impl<R> FakeTlsReader<R> {
|
||||
/// Create new fake TLS reader
|
||||
pub fn new(upstream: R) -> Self {
|
||||
Self {
|
||||
upstream,
|
||||
buffer: BytesMut::with_capacity(16384),
|
||||
pending_read: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get reference to upstream
|
||||
pub fn get_ref(&self) -> &R {
|
||||
&self.upstream
|
||||
}
|
||||
|
||||
/// Get mutable reference to upstream
|
||||
pub fn get_mut(&mut self) -> &mut R {
|
||||
&mut self.upstream
|
||||
}
|
||||
|
||||
/// Consume and return upstream
|
||||
pub fn into_inner(self) -> R {
|
||||
self.upstream
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> FakeTlsReader<R> {
|
||||
/// Read exactly n bytes through TLS layer
|
||||
pub async fn read_exact(&mut self, n: usize) -> Result<Bytes> {
|
||||
while self.buffer.len() < n {
|
||||
let data = self.read_tls_record().await?;
|
||||
if data.is_empty() {
|
||||
return Err(Error::new(ErrorKind::UnexpectedEof, "Connection closed"));
|
||||
}
|
||||
self.buffer.extend_from_slice(&data);
|
||||
}
|
||||
|
||||
Ok(self.buffer.split_to(n).freeze())
|
||||
}
|
||||
|
||||
/// Read a single TLS record
|
||||
async fn read_tls_record(&mut self) -> Result<Vec<u8>> {
|
||||
loop {
|
||||
// Read TLS record header (5 bytes)
|
||||
let mut header = [0u8; 5];
|
||||
self.upstream.read_exact(&mut header).await?;
|
||||
|
||||
let record_type = header[0];
|
||||
let version = [header[1], header[2]];
|
||||
let length = u16::from_be_bytes([header[3], header[4]]) as usize;
|
||||
|
||||
// Validate version
|
||||
if version != TLS_VERSION {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("Invalid TLS version: {:02x?}", version),
|
||||
));
|
||||
}
|
||||
|
||||
// Read record body
|
||||
let mut data = vec![0u8; length];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
match record_type {
|
||||
TLS_RECORD_CHANGE_CIPHER => continue, // Skip
|
||||
TLS_RECORD_APPLICATION => return Ok(data),
|
||||
_ => {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("Unexpected TLS record type: 0x{:02x}", record_type),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<Result<()>> {
|
||||
// Drain buffer first
|
||||
if !self.buffer.is_empty() {
|
||||
let to_copy = self.buffer.len().min(buf.remaining());
|
||||
buf.put_slice(&self.buffer.split_to(to_copy));
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
// We need to read a TLS record, but poll_read doesn't support async/await
|
||||
// So we'll do a simplified version that reads header synchronously
|
||||
|
||||
// Read header
|
||||
let mut header = [0u8; 5];
|
||||
let mut header_buf = ReadBuf::new(&mut header);
|
||||
|
||||
match Pin::new(&mut self.upstream).poll_read(cx, &mut header_buf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
if header_buf.filled().len() < 5 {
|
||||
// Need more data - store what we have and return pending
|
||||
// For simplicity, we'll just return empty
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
|
||||
let record_type = header[0];
|
||||
let length = u16::from_be_bytes([header[3], header[4]]) as usize;
|
||||
|
||||
if record_type == TLS_RECORD_CHANGE_CIPHER {
|
||||
// Skip this record, try again
|
||||
cx.waker().wake_by_ref();
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
if record_type != TLS_RECORD_APPLICATION {
|
||||
return Poll::Ready(Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
"Invalid TLS record type",
|
||||
)));
|
||||
}
|
||||
|
||||
// Read body
|
||||
let mut body = vec![0u8; length];
|
||||
let mut body_buf = ReadBuf::new(&mut body);
|
||||
|
||||
match Pin::new(&mut self.upstream).poll_read(cx, &mut body_buf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
let filled = body_buf.filled();
|
||||
let to_copy = filled.len().min(buf.remaining());
|
||||
buf.put_slice(&filled[..to_copy]);
|
||||
|
||||
if filled.len() > to_copy {
|
||||
self.buffer.extend_from_slice(&filled[to_copy..]);
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Writer that wraps data in TLS 1.3 records
|
||||
pub struct FakeTlsWriter<W> {
|
||||
upstream: W,
|
||||
}
|
||||
|
||||
impl<W> FakeTlsWriter<W> {
|
||||
/// Create new fake TLS writer
|
||||
pub fn new(upstream: W) -> Self {
|
||||
Self { upstream }
|
||||
}
|
||||
|
||||
/// Get reference to upstream
|
||||
pub fn get_ref(&self) -> &W {
|
||||
&self.upstream
|
||||
}
|
||||
|
||||
/// Get mutable reference to upstream
|
||||
pub fn get_mut(&mut self) -> &mut W {
|
||||
&mut self.upstream
|
||||
}
|
||||
|
||||
/// Consume and return upstream
|
||||
pub fn into_inner(self) -> W {
|
||||
self.upstream
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
// Build TLS record
|
||||
let chunk_size = buf.len().min(MAX_TLS_CHUNK_SIZE);
|
||||
let chunk = &buf[..chunk_size];
|
||||
|
||||
let mut record = Vec::with_capacity(5 + chunk_size);
|
||||
record.push(TLS_RECORD_APPLICATION);
|
||||
record.extend_from_slice(&TLS_VERSION);
|
||||
record.push((chunk_size >> 8) as u8);
|
||||
record.push(chunk_size as u8);
|
||||
record.extend_from_slice(chunk);
|
||||
|
||||
match Pin::new(&mut self.upstream).poll_write(cx, &record) {
|
||||
Poll::Ready(Ok(written)) => {
|
||||
if written >= 5 {
|
||||
Poll::Ready(Ok(written - 5))
|
||||
} else {
|
||||
Poll::Ready(Ok(0))
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.upstream).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.upstream).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
|
||||
/// Write all data wrapped in TLS records (async method)
|
||||
pub async fn write_all_tls(&mut self, data: &[u8]) -> Result<()> {
|
||||
for chunk in data.chunks(MAX_TLS_CHUNK_SIZE) {
|
||||
let header = [
|
||||
TLS_RECORD_APPLICATION,
|
||||
TLS_VERSION[0],
|
||||
TLS_VERSION[1],
|
||||
(chunk.len() >> 8) as u8,
|
||||
chunk.len() as u8,
|
||||
];
|
||||
|
||||
self.upstream.write_all(&header).await?;
|
||||
self.upstream.write_all(chunk).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::duplex;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tls_stream_roundtrip() {
|
||||
let (client, server) = duplex(4096);
|
||||
|
||||
let mut writer = FakeTlsWriter::new(client);
|
||||
let mut reader = FakeTlsReader::new(server);
|
||||
|
||||
let original = b"Hello, fake TLS!";
|
||||
writer.write_all_tls(original).await.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
let received = reader.read_exact(original.len()).await.unwrap();
|
||||
assert_eq!(&received[..], original);
|
||||
}
|
||||
}
|
||||
113
src/stream/traits.rs
Normal file
113
src/stream/traits.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
//! Stream traits and common types
|
||||
|
||||
use bytes::Bytes;
|
||||
use std::io::Result;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
/// Extra metadata for frames
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct FrameMeta {
|
||||
/// Quick ACK requested
|
||||
pub quickack: bool,
|
||||
/// This is a simple ACK message
|
||||
pub simple_ack: bool,
|
||||
/// Skip sending this frame
|
||||
pub skip_send: bool,
|
||||
}
|
||||
|
||||
impl FrameMeta {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn with_quickack(mut self) -> Self {
|
||||
self.quickack = true;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_simple_ack(mut self) -> Self {
|
||||
self.simple_ack = true;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of reading a frame
|
||||
#[derive(Debug)]
|
||||
pub enum ReadFrameResult {
|
||||
/// Frame data with metadata
|
||||
Frame(Bytes, FrameMeta),
|
||||
/// Connection closed
|
||||
Closed,
|
||||
}
|
||||
|
||||
/// Trait for streams that wrap another stream
|
||||
pub trait LayeredStream<U> {
|
||||
/// Get reference to upstream
|
||||
fn upstream(&self) -> &U;
|
||||
|
||||
/// Get mutable reference to upstream
|
||||
fn upstream_mut(&mut self) -> &mut U;
|
||||
|
||||
/// Consume self and return upstream
|
||||
fn into_upstream(self) -> U;
|
||||
}
|
||||
|
||||
/// A split read half of a stream
|
||||
pub struct ReadHalf<R> {
|
||||
inner: R,
|
||||
}
|
||||
|
||||
impl<R> ReadHalf<R> {
|
||||
pub fn new(inner: R) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> R {
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> AsyncRead for ReadHalf<R> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
/// A split write half of a stream
|
||||
pub struct WriteHalf<W> {
|
||||
inner: W,
|
||||
}
|
||||
|
||||
impl<W> WriteHalf<W> {
|
||||
pub fn new(inner: W) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> W {
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> AsyncWrite for WriteHalf<W> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize>> {
|
||||
Pin::new(&mut self.inner).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user