Tschuss Status Quo - Hallo, Zukunft!
This commit is contained in:
Alexey
2025-12-30 05:08:05 +03:00
parent 44169441b4
commit 3d9150a074
33 changed files with 6079 additions and 0 deletions

474
src/stream/crypto_stream.rs Normal file
View File

@@ -0,0 +1,474 @@
//! Encrypted stream wrappers using AES-CTR
use bytes::{Bytes, BytesMut, BufMut};
use std::io::{Error, ErrorKind, Result};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf};
use crate::crypto::AesCtr;
use parking_lot::Mutex;
/// Reader that decrypts data using AES-CTR
pub struct CryptoReader<R> {
upstream: R,
decryptor: AesCtr,
buffer: BytesMut,
}
impl<R> CryptoReader<R> {
/// Create new crypto reader
pub fn new(upstream: R, decryptor: AesCtr) -> Self {
Self {
upstream,
decryptor,
buffer: BytesMut::with_capacity(8192),
}
}
/// Get reference to upstream
pub fn get_ref(&self) -> &R {
&self.upstream
}
/// Get mutable reference to upstream
pub fn get_mut(&mut self) -> &mut R {
&mut self.upstream
}
/// Consume and return upstream
pub fn into_inner(self) -> R {
self.upstream
}
}
impl<R: AsyncRead + Unpin> AsyncRead for CryptoReader<R> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<()>> {
let this = self.get_mut();
if !this.buffer.is_empty() {
let to_copy = this.buffer.len().min(buf.remaining());
buf.put_slice(&this.buffer.split_to(to_copy));
return Poll::Ready(Ok(()));
}
// Zero-copy Reader
let before = buf.filled().len();
match Pin::new(&mut this.upstream).poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let after = buf.filled().len();
let bytes_read = after - before;
if bytes_read > 0 {
// Decrypt in-place
let filled = buf.filled_mut();
this.decryptor.apply(&mut filled[before..after]);
}
Poll::Ready(Ok(()))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
}
impl<R: AsyncRead + Unpin> CryptoReader<R> {
/// Read and decrypt exactly n bytes with Async
pub async fn read_exact_decrypt(&mut self, n: usize) -> Result<Bytes> {
let mut result = BytesMut::with_capacity(n);
if !self.buffer.is_empty() {
let to_take = self.buffer.len().min(n);
result.extend_from_slice(&self.buffer.split_to(to_take));
}
// Reread
while result.len() < n {
let mut temp = vec![0u8; n - result.len()];
let read = self.upstream.read(&mut temp).await?;
if read == 0 {
return Err(Error::new(ErrorKind::UnexpectedEof, "Connection closed"));
}
// Decrypt
self.decryptor.apply(&mut temp[..read]);
result.extend_from_slice(&temp[..read]);
}
Ok(result.freeze())
}
}
/// Writer that encrypts data using AES-CTR
pub struct CryptoWriter<W> {
upstream: W,
encryptor: AesCtr,
pending: BytesMut,
}
impl<W> CryptoWriter<W> {
pub fn new(upstream: W, encryptor: AesCtr) -> Self {
Self {
upstream,
encryptor,
pending: BytesMut::with_capacity(8192),
}
}
pub fn get_ref(&self) -> &W {
&self.upstream
}
pub fn get_mut(&mut self) -> &mut W {
&mut self.upstream
}
pub fn into_inner(self) -> W {
self.upstream
}
}
impl<W: AsyncWrite + Unpin> AsyncWrite for CryptoWriter<W> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
let this = self.get_mut();
if !this.pending.is_empty() {
match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) {
Poll::Ready(Ok(written)) => {
let _ = this.pending.split_to(written);
if !this.pending.is_empty() {
cx.waker().wake_by_ref();
return Poll::Pending;
}
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
}
// Pending Null
if buf.is_empty() {
return Poll::Ready(Ok(0));
}
// Encrypt
let mut encrypted = buf.to_vec();
this.encryptor.apply(&mut encrypted);
// Write Try
match Pin::new(&mut this.upstream).poll_write(cx, &encrypted) {
Poll::Ready(Ok(written)) => {
if written < encrypted.len() {
// Partial write — сохраняем остаток в pending
this.pending.extend_from_slice(&encrypted[written..]);
}
Poll::Ready(Ok(buf.len()))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => {
this.pending.extend_from_slice(&encrypted);
Poll::Ready(Ok(buf.len()))
}
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
let this = self.get_mut();
while !this.pending.is_empty() {
match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) {
Poll::Ready(Ok(0)) => {
return Poll::Ready(Err(Error::new(
ErrorKind::WriteZero,
"Failed to write pending data during flush",
)));
}
Poll::Ready(Ok(written)) => {
let _ = this.pending.split_to(written);
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
}
Pin::new(&mut this.upstream).poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
let this = self.get_mut();
while !this.pending.is_empty() {
match Pin::new(&mut this.upstream).poll_write(cx, &this.pending) {
Poll::Ready(Ok(0)) => {
break;
}
Poll::Ready(Ok(written)) => {
let _ = this.pending.split_to(written);
}
Poll::Ready(Err(_)) => {
break;
}
Poll::Pending => return Poll::Pending,
}
}
Pin::new(&mut this.upstream).poll_shutdown(cx)
}
}
/// Passthrough stream for fast mode - no encryption/decryption
pub struct PassthroughStream<S> {
inner: S,
}
impl<S> PassthroughStream<S> {
pub fn new(inner: S) -> Self {
Self { inner }
}
pub fn into_inner(self) -> S {
self.inner
}
}
impl<S: AsyncRead + Unpin> AsyncRead for PassthroughStream<S> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<()>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}
impl<S: AsyncWrite + Unpin> AsyncWrite for PassthroughStream<S> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::{Context, Poll, Waker, RawWaker, RawWakerVTable};
use tokio::io::duplex;
/// Mock writer
struct PartialWriter {
chunk_size: usize,
data: Vec<u8>,
write_count: usize,
}
impl PartialWriter {
fn new(chunk_size: usize) -> Self {
Self {
chunk_size,
data: Vec::new(),
write_count: 0,
}
}
}
impl AsyncWrite for PartialWriter {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
self.write_count += 1;
let to_write = buf.len().min(self.chunk_size);
self.data.extend_from_slice(&buf[..to_write]);
Poll::Ready(Ok(to_write))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
Poll::Ready(Ok(()))
}
}
fn noop_waker() -> Waker {
const VTABLE: RawWakerVTable = RawWakerVTable::new(
|_| RawWaker::new(std::ptr::null(), &VTABLE),
|_| {},
|_| {},
|_| {},
);
unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
}
#[test]
fn test_crypto_writer_partial_write_correctness() {
let key = [0x42u8; 32];
let iv = 12345u128;
// 10-byte Writer
let mock_writer = PartialWriter::new(10);
let encryptor = AesCtr::new(&key, iv);
let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
// 25 byte
let original = b"Hello, this is test data!";
// First Write
let result = Pin::new(&mut crypto_writer).poll_write(&mut cx, original);
assert!(matches!(result, Poll::Ready(Ok(25))));
// Flush before continue Pending
loop {
match Pin::new(&mut crypto_writer).poll_flush(&mut cx) {
Poll::Ready(Ok(())) => break,
Poll::Ready(Err(e)) => panic!("Flush error: {}", e),
Poll::Pending => continue,
}
}
// Write Check
let encrypted = &crypto_writer.upstream.data;
assert_eq!(encrypted.len(), 25);
// Decrypt + Verify
let mut decryptor = AesCtr::new(&key, iv);
let mut decrypted = encrypted.clone();
decryptor.apply(&mut decrypted);
assert_eq!(&decrypted, original);
}
#[test]
fn test_crypto_writer_multiple_partial_writes() {
let key = [0xAB; 32];
let iv = 9999u128;
let mock_writer = PartialWriter::new(3);
let encryptor = AesCtr::new(&key, iv);
let mut crypto_writer = CryptoWriter::new(mock_writer, encryptor);
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
let data1 = b"First";
let data2 = b"Second";
let data3 = b"Third";
Pin::new(&mut crypto_writer).poll_write(&mut cx, data1).unwrap();
// Flush
while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {}
Pin::new(&mut crypto_writer).poll_write(&mut cx, data2).unwrap();
while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {}
Pin::new(&mut crypto_writer).poll_write(&mut cx, data3).unwrap();
while Pin::new(&mut crypto_writer).poll_flush(&mut cx).is_pending() {}
// Assemble
let mut expected = Vec::new();
expected.extend_from_slice(data1);
expected.extend_from_slice(data2);
expected.extend_from_slice(data3);
// Decrypt
let mut decryptor = AesCtr::new(&key, iv);
let mut decrypted = crypto_writer.upstream.data.clone();
decryptor.apply(&mut decrypted);
assert_eq!(decrypted, expected);
}
#[tokio::test]
async fn test_crypto_stream_roundtrip() {
let key = [0u8; 32];
let iv = 12345u128;
let (client, server) = duplex(4096);
let encryptor = AesCtr::new(&key, iv);
let decryptor = AesCtr::new(&key, iv);
let mut writer = CryptoWriter::new(client, encryptor);
let mut reader = CryptoReader::new(server, decryptor);
// Write
let original = b"Hello, encrypted world!";
writer.write_all(original).await.unwrap();
writer.flush().await.unwrap();
// Read
let mut buf = vec![0u8; original.len()];
reader.read_exact(&mut buf).await.unwrap();
assert_eq!(&buf, original);
}
#[tokio::test]
async fn test_crypto_stream_large_data() {
let key = [0x55u8; 32];
let iv = 777u128;
let (client, server) = duplex(1024);
let encryptor = AesCtr::new(&key, iv);
let decryptor = AesCtr::new(&key, iv);
let mut writer = CryptoWriter::new(client, encryptor);
let mut reader = CryptoReader::new(server, decryptor);
// Hugeload
let original: Vec<u8> = (0..10000).map(|i| (i % 256) as u8).collect();
// Write
let write_data = original.clone();
let write_handle = tokio::spawn(async move {
writer.write_all(&write_data).await.unwrap();
writer.flush().await.unwrap();
writer.shutdown().await.unwrap();
});
// Read
let mut received = Vec::new();
let mut buf = vec![0u8; 1024];
loop {
match reader.read(&mut buf).await {
Ok(0) => break,
Ok(n) => received.extend_from_slice(&buf[..n]),
Err(e) => panic!("Read error: {}", e),
}
}
write_handle.await.unwrap();
assert_eq!(received, original);
}
}

585
src/stream/frame_stream.rs Normal file
View File

@@ -0,0 +1,585 @@
//! MTProto frame stream wrappers
use bytes::{Bytes, BytesMut};
use std::io::{Error, ErrorKind, Result};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt};
use crate::protocol::constants::*;
use crate::crypto::crc32;
use crate::crypto::random::SECURE_RANDOM;
use super::traits::{FrameMeta, LayeredStream};
// ============= Abridged (Compact) Frame =============
/// Reader for abridged MTProto framing
pub struct AbridgedFrameReader<R> {
upstream: R,
}
impl<R> AbridgedFrameReader<R> {
pub fn new(upstream: R) -> Self {
Self { upstream }
}
}
impl<R: AsyncRead + Unpin> AbridgedFrameReader<R> {
/// Read a frame and return (data, metadata)
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
let mut meta = FrameMeta::new();
// Read length byte
let mut len_byte = [0u8];
self.upstream.read_exact(&mut len_byte).await?;
let mut len = len_byte[0] as usize;
// Check QuickACK flag (high bit)
if len >= 0x80 {
meta.quickack = true;
len -= 0x80;
}
// Extended length (3 bytes)
if len == 0x7f {
let mut len_bytes = [0u8; 3];
self.upstream.read_exact(&mut len_bytes).await?;
len = u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], 0]) as usize;
}
// Length is in 4-byte words
let byte_len = len * 4;
// Read data
let mut data = vec![0u8; byte_len];
self.upstream.read_exact(&mut data).await?;
Ok((Bytes::from(data), meta))
}
}
impl<R> LayeredStream<R> for AbridgedFrameReader<R> {
fn upstream(&self) -> &R { &self.upstream }
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
fn into_upstream(self) -> R { self.upstream }
}
/// Writer for abridged MTProto framing
pub struct AbridgedFrameWriter<W> {
upstream: W,
}
impl<W> AbridgedFrameWriter<W> {
pub fn new(upstream: W) -> Self {
Self { upstream }
}
}
impl<W: AsyncWrite + Unpin> AbridgedFrameWriter<W> {
/// Write a frame
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
if data.len() % 4 != 0 {
return Err(Error::new(
ErrorKind::InvalidInput,
format!("Abridged frame must be aligned to 4 bytes, got {}", data.len()),
));
}
// Simple ACK: send reversed data
if meta.simple_ack {
let reversed: Vec<u8> = data.iter().rev().copied().collect();
self.upstream.write_all(&reversed).await?;
return Ok(());
}
let len_div_4 = data.len() / 4;
if len_div_4 < 0x7f {
// Short length (1 byte)
self.upstream.write_all(&[len_div_4 as u8]).await?;
} else if len_div_4 < (1 << 24) {
// Long length (4 bytes: 0x7f + 3 bytes)
let mut header = [0x7f, 0, 0, 0];
header[1..4].copy_from_slice(&(len_div_4 as u32).to_le_bytes()[..3]);
self.upstream.write_all(&header).await?;
} else {
return Err(Error::new(
ErrorKind::InvalidInput,
format!("Frame too large: {} bytes", data.len()),
));
}
self.upstream.write_all(data).await?;
Ok(())
}
pub async fn flush(&mut self) -> Result<()> {
self.upstream.flush().await
}
}
impl<W> LayeredStream<W> for AbridgedFrameWriter<W> {
fn upstream(&self) -> &W { &self.upstream }
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
fn into_upstream(self) -> W { self.upstream }
}
// ============= Intermediate Frame =============
/// Reader for intermediate MTProto framing
pub struct IntermediateFrameReader<R> {
upstream: R,
}
impl<R> IntermediateFrameReader<R> {
pub fn new(upstream: R) -> Self {
Self { upstream }
}
}
impl<R: AsyncRead + Unpin> IntermediateFrameReader<R> {
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
let mut meta = FrameMeta::new();
// Read 4-byte length
let mut len_bytes = [0u8; 4];
self.upstream.read_exact(&mut len_bytes).await?;
let mut len = u32::from_le_bytes(len_bytes) as usize;
// Check QuickACK flag (high bit)
if len > 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
// Read data
let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?;
Ok((Bytes::from(data), meta))
}
}
impl<R> LayeredStream<R> for IntermediateFrameReader<R> {
fn upstream(&self) -> &R { &self.upstream }
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
fn into_upstream(self) -> R { self.upstream }
}
/// Writer for intermediate MTProto framing
pub struct IntermediateFrameWriter<W> {
upstream: W,
}
impl<W> IntermediateFrameWriter<W> {
pub fn new(upstream: W) -> Self {
Self { upstream }
}
}
impl<W: AsyncWrite + Unpin> IntermediateFrameWriter<W> {
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
if meta.simple_ack {
self.upstream.write_all(data).await?;
} else {
let len_bytes = (data.len() as u32).to_le_bytes();
self.upstream.write_all(&len_bytes).await?;
self.upstream.write_all(data).await?;
}
Ok(())
}
pub async fn flush(&mut self) -> Result<()> {
self.upstream.flush().await
}
}
impl<W> LayeredStream<W> for IntermediateFrameWriter<W> {
fn upstream(&self) -> &W { &self.upstream }
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
fn into_upstream(self) -> W { self.upstream }
}
// ============= Secure Intermediate Frame =============
/// Reader for secure intermediate MTProto framing (with padding)
pub struct SecureIntermediateFrameReader<R> {
upstream: R,
}
impl<R> SecureIntermediateFrameReader<R> {
pub fn new(upstream: R) -> Self {
Self { upstream }
}
}
impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
let mut meta = FrameMeta::new();
// Read 4-byte length
let mut len_bytes = [0u8; 4];
self.upstream.read_exact(&mut len_bytes).await?;
let mut len = u32::from_le_bytes(len_bytes) as usize;
// Check QuickACK flag
if len > 0x80000000 {
meta.quickack = true;
len -= 0x80000000;
}
// Read data (including padding)
let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?;
// Strip padding (not aligned to 4)
if len % 4 != 0 {
let actual_len = len - (len % 4);
data.truncate(actual_len);
}
Ok((Bytes::from(data), meta))
}
}
impl<R> LayeredStream<R> for SecureIntermediateFrameReader<R> {
fn upstream(&self) -> &R { &self.upstream }
fn upstream_mut(&mut self) -> &mut R { &mut self.upstream }
fn into_upstream(self) -> R { self.upstream }
}
/// Writer for secure intermediate MTProto framing
pub struct SecureIntermediateFrameWriter<W> {
upstream: W,
}
impl<W> SecureIntermediateFrameWriter<W> {
pub fn new(upstream: W) -> Self {
Self { upstream }
}
}
impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
if meta.simple_ack {
self.upstream.write_all(data).await?;
return Ok(());
}
// Add random padding (0-3 bytes)
let padding_len = SECURE_RANDOM.range(4);
let padding = SECURE_RANDOM.bytes(padding_len);
let total_len = data.len() + padding_len;
let len_bytes = (total_len as u32).to_le_bytes();
self.upstream.write_all(&len_bytes).await?;
self.upstream.write_all(data).await?;
self.upstream.write_all(&padding).await?;
Ok(())
}
pub async fn flush(&mut self) -> Result<()> {
self.upstream.flush().await
}
}
impl<W> LayeredStream<W> for SecureIntermediateFrameWriter<W> {
fn upstream(&self) -> &W { &self.upstream }
fn upstream_mut(&mut self) -> &mut W { &mut self.upstream }
fn into_upstream(self) -> W { self.upstream }
}
// ============= Full MTProto Frame (with CRC) =============
/// Reader for full MTProto framing with sequence numbers and CRC32
pub struct MtprotoFrameReader<R> {
upstream: R,
seq_no: i32,
}
impl<R> MtprotoFrameReader<R> {
pub fn new(upstream: R, start_seq: i32) -> Self {
Self { upstream, seq_no: start_seq }
}
}
impl<R: AsyncRead + Unpin> MtprotoFrameReader<R> {
pub async fn read_frame(&mut self) -> Result<Bytes> {
loop {
// Read length (4 bytes)
let mut len_bytes = [0u8; 4];
self.upstream.read_exact(&mut len_bytes).await?;
let len = u32::from_le_bytes(len_bytes) as usize;
// Skip padding-only messages
if len == 4 {
continue;
}
// Validate length
if len < MIN_MSG_LEN || len > MAX_MSG_LEN || len % PADDING_FILLER.len() != 0 {
return Err(Error::new(
ErrorKind::InvalidData,
format!("Invalid message length: {}", len),
));
}
// Read sequence number
let mut seq_bytes = [0u8; 4];
self.upstream.read_exact(&mut seq_bytes).await?;
let msg_seq = i32::from_le_bytes(seq_bytes);
if msg_seq != self.seq_no {
return Err(Error::new(
ErrorKind::InvalidData,
format!("Sequence mismatch: expected {}, got {}", self.seq_no, msg_seq),
));
}
self.seq_no += 1;
// Read data (length - 4 len - 4 seq - 4 crc = len - 12)
let data_len = len - 12;
let mut data = vec![0u8; data_len];
self.upstream.read_exact(&mut data).await?;
// Read and verify CRC32
let mut crc_bytes = [0u8; 4];
self.upstream.read_exact(&mut crc_bytes).await?;
let expected_crc = u32::from_le_bytes(crc_bytes);
// Compute CRC over len + seq + data
let mut crc_input = Vec::with_capacity(8 + data_len);
crc_input.extend_from_slice(&len_bytes);
crc_input.extend_from_slice(&seq_bytes);
crc_input.extend_from_slice(&data);
let computed_crc = crc32(&crc_input);
if computed_crc != expected_crc {
return Err(Error::new(
ErrorKind::InvalidData,
format!("CRC mismatch: expected {:08x}, got {:08x}", expected_crc, computed_crc),
));
}
return Ok(Bytes::from(data));
}
}
}
/// Writer for full MTProto framing
pub struct MtprotoFrameWriter<W> {
upstream: W,
seq_no: i32,
}
impl<W> MtprotoFrameWriter<W> {
pub fn new(upstream: W, start_seq: i32) -> Self {
Self { upstream, seq_no: start_seq }
}
}
impl<W: AsyncWrite + Unpin> MtprotoFrameWriter<W> {
pub async fn write_frame(&mut self, msg: &[u8]) -> Result<()> {
// Total length: 4 (len) + 4 (seq) + data + 4 (crc)
let len = msg.len() + 12;
let len_bytes = (len as u32).to_le_bytes();
let seq_bytes = self.seq_no.to_le_bytes();
self.seq_no += 1;
// Compute CRC
let mut crc_input = Vec::with_capacity(8 + msg.len());
crc_input.extend_from_slice(&len_bytes);
crc_input.extend_from_slice(&seq_bytes);
crc_input.extend_from_slice(msg);
let checksum = crc32(&crc_input);
let crc_bytes = checksum.to_le_bytes();
// Calculate padding for CBC alignment
let total_len = len_bytes.len() + seq_bytes.len() + msg.len() + crc_bytes.len();
let padding_needed = (CBC_PADDING - (total_len % CBC_PADDING)) % CBC_PADDING;
let padding_count = padding_needed / PADDING_FILLER.len();
// Write everything
self.upstream.write_all(&len_bytes).await?;
self.upstream.write_all(&seq_bytes).await?;
self.upstream.write_all(msg).await?;
self.upstream.write_all(&crc_bytes).await?;
for _ in 0..padding_count {
self.upstream.write_all(&PADDING_FILLER).await?;
}
Ok(())
}
pub async fn flush(&mut self) -> Result<()> {
self.upstream.flush().await
}
}
// ============= Frame Type Enum =============
/// Enum for different frame stream types
pub enum FrameReaderKind<R> {
Abridged(AbridgedFrameReader<R>),
Intermediate(IntermediateFrameReader<R>),
SecureIntermediate(SecureIntermediateFrameReader<R>),
}
impl<R: AsyncRead + Unpin> FrameReaderKind<R> {
pub fn new(upstream: R, proto_tag: ProtoTag) -> Self {
match proto_tag {
ProtoTag::Abridged => FrameReaderKind::Abridged(AbridgedFrameReader::new(upstream)),
ProtoTag::Intermediate => FrameReaderKind::Intermediate(IntermediateFrameReader::new(upstream)),
ProtoTag::Secure => FrameReaderKind::SecureIntermediate(SecureIntermediateFrameReader::new(upstream)),
}
}
pub async fn read_frame(&mut self) -> Result<(Bytes, FrameMeta)> {
match self {
FrameReaderKind::Abridged(r) => r.read_frame().await,
FrameReaderKind::Intermediate(r) => r.read_frame().await,
FrameReaderKind::SecureIntermediate(r) => r.read_frame().await,
}
}
}
pub enum FrameWriterKind<W> {
Abridged(AbridgedFrameWriter<W>),
Intermediate(IntermediateFrameWriter<W>),
SecureIntermediate(SecureIntermediateFrameWriter<W>),
}
impl<W: AsyncWrite + Unpin> FrameWriterKind<W> {
pub fn new(upstream: W, proto_tag: ProtoTag) -> Self {
match proto_tag {
ProtoTag::Abridged => FrameWriterKind::Abridged(AbridgedFrameWriter::new(upstream)),
ProtoTag::Intermediate => FrameWriterKind::Intermediate(IntermediateFrameWriter::new(upstream)),
ProtoTag::Secure => FrameWriterKind::SecureIntermediate(SecureIntermediateFrameWriter::new(upstream)),
}
}
pub async fn write_frame(&mut self, data: &[u8], meta: &FrameMeta) -> Result<()> {
match self {
FrameWriterKind::Abridged(w) => w.write_frame(data, meta).await,
FrameWriterKind::Intermediate(w) => w.write_frame(data, meta).await,
FrameWriterKind::SecureIntermediate(w) => w.write_frame(data, meta).await,
}
}
pub async fn flush(&mut self) -> Result<()> {
match self {
FrameWriterKind::Abridged(w) => w.flush().await,
FrameWriterKind::Intermediate(w) => w.flush().await,
FrameWriterKind::SecureIntermediate(w) => w.flush().await,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::duplex;
#[tokio::test]
async fn test_abridged_roundtrip() {
let (client, server) = duplex(1024);
let mut writer = AbridgedFrameWriter::new(client);
let mut reader = AbridgedFrameReader::new(server);
// Short frame
let data = vec![1u8, 2, 3, 4]; // 4 bytes = 1 word
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
writer.flush().await.unwrap();
let (received, _meta) = reader.read_frame().await.unwrap();
assert_eq!(&received[..], &data[..]);
}
#[tokio::test]
async fn test_abridged_long_frame() {
let (client, server) = duplex(65536);
let mut writer = AbridgedFrameWriter::new(client);
let mut reader = AbridgedFrameReader::new(server);
// Long frame (> 0x7f words = 508 bytes)
let data: Vec<u8> = (0..1000).map(|i| (i % 256) as u8).collect();
let padded_len = (data.len() + 3) / 4 * 4;
let mut padded = data.clone();
padded.resize(padded_len, 0);
writer.write_frame(&padded, &FrameMeta::new()).await.unwrap();
writer.flush().await.unwrap();
let (received, _meta) = reader.read_frame().await.unwrap();
assert_eq!(&received[..], &padded[..]);
}
#[tokio::test]
async fn test_intermediate_roundtrip() {
let (client, server) = duplex(1024);
let mut writer = IntermediateFrameWriter::new(client);
let mut reader = IntermediateFrameReader::new(server);
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
writer.flush().await.unwrap();
let (received, _meta) = reader.read_frame().await.unwrap();
assert_eq!(&received[..], &data[..]);
}
#[tokio::test]
async fn test_secure_intermediate_padding() {
let (client, server) = duplex(1024);
let mut writer = SecureIntermediateFrameWriter::new(client);
let mut reader = SecureIntermediateFrameReader::new(server);
let data = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
writer.flush().await.unwrap();
let (received, _meta) = reader.read_frame().await.unwrap();
// Received should have padding stripped to align to 4
let expected_len = (data.len() / 4) * 4;
assert_eq!(received.len(), expected_len);
}
#[tokio::test]
async fn test_mtproto_frame_roundtrip() {
let (client, server) = duplex(1024);
let mut writer = MtprotoFrameWriter::new(client, 0);
let mut reader = MtprotoFrameReader::new(server, 0);
// Message must be padded properly
let data = vec![0u8; 16]; // Aligned to 4 and CBC_PADDING
writer.write_frame(&data).await.unwrap();
writer.flush().await.unwrap();
let received = reader.read_frame().await.unwrap();
assert_eq!(&received[..], &data[..]);
}
#[tokio::test]
async fn test_frame_reader_kind() {
let (client, server) = duplex(1024);
let mut writer = FrameWriterKind::new(client, ProtoTag::Intermediate);
let mut reader = FrameReaderKind::new(server, ProtoTag::Intermediate);
let data = vec![1u8, 2, 3, 4];
writer.write_frame(&data, &FrameMeta::new()).await.unwrap();
writer.flush().await.unwrap();
let (received, _) = reader.read_frame().await.unwrap();
assert_eq!(&received[..], &data[..]);
}
}

10
src/stream/mod.rs Normal file
View File

@@ -0,0 +1,10 @@
//! Stream wrappers for MTProto protocol layers
pub mod traits;
pub mod crypto_stream;
pub mod tls_stream;
pub mod frame_stream;
pub use crypto_stream::{CryptoReader, CryptoWriter, PassthroughStream};
pub use tls_stream::{FakeTlsReader, FakeTlsWriter};
pub use frame_stream::*;

277
src/stream/tls_stream.rs Normal file
View File

@@ -0,0 +1,277 @@
//! Fake TLS 1.3 stream wrappers
use bytes::{Bytes, BytesMut};
use std::io::{Error, ErrorKind, Result};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, AsyncReadExt, AsyncWriteExt, ReadBuf};
use crate::protocol::constants::{
TLS_VERSION, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER,
MAX_TLS_CHUNK_SIZE,
};
use parking_lot::Mutex;
/// Reader that unwraps TLS 1.3 records
pub struct FakeTlsReader<R> {
upstream: R,
buffer: BytesMut,
pending_read: Option<PendingTlsRead>,
}
struct PendingTlsRead {
record_type: u8,
remaining: usize,
}
impl<R> FakeTlsReader<R> {
/// Create new fake TLS reader
pub fn new(upstream: R) -> Self {
Self {
upstream,
buffer: BytesMut::with_capacity(16384),
pending_read: None,
}
}
/// Get reference to upstream
pub fn get_ref(&self) -> &R {
&self.upstream
}
/// Get mutable reference to upstream
pub fn get_mut(&mut self) -> &mut R {
&mut self.upstream
}
/// Consume and return upstream
pub fn into_inner(self) -> R {
self.upstream
}
}
impl<R: AsyncRead + Unpin> FakeTlsReader<R> {
/// Read exactly n bytes through TLS layer
pub async fn read_exact(&mut self, n: usize) -> Result<Bytes> {
while self.buffer.len() < n {
let data = self.read_tls_record().await?;
if data.is_empty() {
return Err(Error::new(ErrorKind::UnexpectedEof, "Connection closed"));
}
self.buffer.extend_from_slice(&data);
}
Ok(self.buffer.split_to(n).freeze())
}
/// Read a single TLS record
async fn read_tls_record(&mut self) -> Result<Vec<u8>> {
loop {
// Read TLS record header (5 bytes)
let mut header = [0u8; 5];
self.upstream.read_exact(&mut header).await?;
let record_type = header[0];
let version = [header[1], header[2]];
let length = u16::from_be_bytes([header[3], header[4]]) as usize;
// Validate version
if version != TLS_VERSION {
return Err(Error::new(
ErrorKind::InvalidData,
format!("Invalid TLS version: {:02x?}", version),
));
}
// Read record body
let mut data = vec![0u8; length];
self.upstream.read_exact(&mut data).await?;
match record_type {
TLS_RECORD_CHANGE_CIPHER => continue, // Skip
TLS_RECORD_APPLICATION => return Ok(data),
_ => {
return Err(Error::new(
ErrorKind::InvalidData,
format!("Unexpected TLS record type: 0x{:02x}", record_type),
));
}
}
}
}
}
impl<R: AsyncRead + Unpin> AsyncRead for FakeTlsReader<R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<()>> {
// Drain buffer first
if !self.buffer.is_empty() {
let to_copy = self.buffer.len().min(buf.remaining());
buf.put_slice(&self.buffer.split_to(to_copy));
return Poll::Ready(Ok(()));
}
// We need to read a TLS record, but poll_read doesn't support async/await
// So we'll do a simplified version that reads header synchronously
// Read header
let mut header = [0u8; 5];
let mut header_buf = ReadBuf::new(&mut header);
match Pin::new(&mut self.upstream).poll_read(cx, &mut header_buf) {
Poll::Ready(Ok(())) => {
if header_buf.filled().len() < 5 {
// Need more data - store what we have and return pending
// For simplicity, we'll just return empty
return Poll::Ready(Ok(()));
}
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
let record_type = header[0];
let length = u16::from_be_bytes([header[3], header[4]]) as usize;
if record_type == TLS_RECORD_CHANGE_CIPHER {
// Skip this record, try again
cx.waker().wake_by_ref();
return Poll::Pending;
}
if record_type != TLS_RECORD_APPLICATION {
return Poll::Ready(Err(Error::new(
ErrorKind::InvalidData,
"Invalid TLS record type",
)));
}
// Read body
let mut body = vec![0u8; length];
let mut body_buf = ReadBuf::new(&mut body);
match Pin::new(&mut self.upstream).poll_read(cx, &mut body_buf) {
Poll::Ready(Ok(())) => {
let filled = body_buf.filled();
let to_copy = filled.len().min(buf.remaining());
buf.put_slice(&filled[..to_copy]);
if filled.len() > to_copy {
self.buffer.extend_from_slice(&filled[to_copy..]);
}
Poll::Ready(Ok(()))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
}
/// Writer that wraps data in TLS 1.3 records
pub struct FakeTlsWriter<W> {
upstream: W,
}
impl<W> FakeTlsWriter<W> {
/// Create new fake TLS writer
pub fn new(upstream: W) -> Self {
Self { upstream }
}
/// Get reference to upstream
pub fn get_ref(&self) -> &W {
&self.upstream
}
/// Get mutable reference to upstream
pub fn get_mut(&mut self) -> &mut W {
&mut self.upstream
}
/// Consume and return upstream
pub fn into_inner(self) -> W {
self.upstream
}
}
impl<W: AsyncWrite + Unpin> AsyncWrite for FakeTlsWriter<W> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
// Build TLS record
let chunk_size = buf.len().min(MAX_TLS_CHUNK_SIZE);
let chunk = &buf[..chunk_size];
let mut record = Vec::with_capacity(5 + chunk_size);
record.push(TLS_RECORD_APPLICATION);
record.extend_from_slice(&TLS_VERSION);
record.push((chunk_size >> 8) as u8);
record.push(chunk_size as u8);
record.extend_from_slice(chunk);
match Pin::new(&mut self.upstream).poll_write(cx, &record) {
Poll::Ready(Ok(written)) => {
if written >= 5 {
Poll::Ready(Ok(written - 5))
} else {
Poll::Ready(Ok(0))
}
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.upstream).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.upstream).poll_shutdown(cx)
}
}
impl<W: AsyncWrite + Unpin> FakeTlsWriter<W> {
/// Write all data wrapped in TLS records (async method)
pub async fn write_all_tls(&mut self, data: &[u8]) -> Result<()> {
for chunk in data.chunks(MAX_TLS_CHUNK_SIZE) {
let header = [
TLS_RECORD_APPLICATION,
TLS_VERSION[0],
TLS_VERSION[1],
(chunk.len() >> 8) as u8,
chunk.len() as u8,
];
self.upstream.write_all(&header).await?;
self.upstream.write_all(chunk).await?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::io::duplex;
#[tokio::test]
async fn test_tls_stream_roundtrip() {
let (client, server) = duplex(4096);
let mut writer = FakeTlsWriter::new(client);
let mut reader = FakeTlsReader::new(server);
let original = b"Hello, fake TLS!";
writer.write_all_tls(original).await.unwrap();
writer.flush().await.unwrap();
let received = reader.read_exact(original.len()).await.unwrap();
assert_eq!(&received[..], original);
}
}

113
src/stream/traits.rs Normal file
View File

@@ -0,0 +1,113 @@
//! Stream traits and common types
use bytes::Bytes;
use std::io::Result;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// Extra metadata for frames
#[derive(Debug, Clone, Default)]
pub struct FrameMeta {
/// Quick ACK requested
pub quickack: bool,
/// This is a simple ACK message
pub simple_ack: bool,
/// Skip sending this frame
pub skip_send: bool,
}
impl FrameMeta {
pub fn new() -> Self {
Self::default()
}
pub fn with_quickack(mut self) -> Self {
self.quickack = true;
self
}
pub fn with_simple_ack(mut self) -> Self {
self.simple_ack = true;
self
}
}
/// Result of reading a frame
#[derive(Debug)]
pub enum ReadFrameResult {
/// Frame data with metadata
Frame(Bytes, FrameMeta),
/// Connection closed
Closed,
}
/// Trait for streams that wrap another stream
pub trait LayeredStream<U> {
/// Get reference to upstream
fn upstream(&self) -> &U;
/// Get mutable reference to upstream
fn upstream_mut(&mut self) -> &mut U;
/// Consume self and return upstream
fn into_upstream(self) -> U;
}
/// A split read half of a stream
pub struct ReadHalf<R> {
inner: R,
}
impl<R> ReadHalf<R> {
pub fn new(inner: R) -> Self {
Self { inner }
}
pub fn into_inner(self) -> R {
self.inner
}
}
impl<R: AsyncRead + Unpin> AsyncRead for ReadHalf<R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<()>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}
/// A split write half of a stream
pub struct WriteHalf<W> {
inner: W,
}
impl<W> WriteHalf<W> {
pub fn new(inner: W) -> Self {
Self { inner }
}
pub fn into_inner(self) -> W {
self.inner
}
}
impl<W: AsyncWrite + Unpin> AsyncWrite for WriteHalf<W> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}
}