Harden masking fallback and frame readers after flow sync

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
Alexey
2026-06-17 21:48:57 +03:00
parent 49742d38a7
commit 72800e4aa7
13 changed files with 401 additions and 88 deletions
+134 -21
View File
@@ -11,16 +11,41 @@ use std::io::{Error, ErrorKind, Result};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
const DEFAULT_MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;
fn reject_oversize_frame(len: usize, max_frame_size: usize, protocol: &str) -> Result<()> {
if len > max_frame_size {
return Err(Error::new(
ErrorKind::InvalidData,
format!("{protocol} frame too large: {len} bytes (max {max_frame_size})"),
));
}
Ok(())
}
// ============= Abridged (Compact) Frame =============
/// Reader for abridged MTProto framing
pub struct AbridgedFrameReader<R> {
upstream: R,
max_frame_size: usize,
}
impl<R> AbridgedFrameReader<R> {
/// Creates a reader with the default maximum frame size.
pub fn new(upstream: R) -> Self {
Self { upstream }
Self {
upstream,
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
}
}
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
Self {
upstream,
max_frame_size,
}
}
}
@@ -48,10 +73,12 @@ impl<R: AsyncRead + Unpin> AbridgedFrameReader<R> {
len = u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], 0]) as usize;
}
// Length is in 4-byte words
let byte_len = len * 4;
// Length is in 4-byte words.
let byte_len = len
.checked_mul(4)
.ok_or_else(|| Error::new(ErrorKind::InvalidData, "abridged frame length overflow"))?;
reject_oversize_frame(byte_len, self.max_frame_size, "abridged")?;
// Read data
let mut data = vec![0u8; byte_len];
self.upstream.read_exact(&mut data).await?;
@@ -152,11 +179,23 @@ impl<W> LayeredStream<W> for AbridgedFrameWriter<W> {
/// Reader for intermediate MTProto framing
pub struct IntermediateFrameReader<R> {
upstream: R,
max_frame_size: usize,
}
impl<R> IntermediateFrameReader<R> {
/// Creates a reader with the default maximum frame size.
pub fn new(upstream: R) -> Self {
Self { upstream }
Self {
upstream,
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
}
}
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
Self {
upstream,
max_frame_size,
}
}
}
@@ -171,8 +210,8 @@ impl<R: AsyncRead + Unpin> IntermediateFrameReader<R> {
let header = parse_intermediate_header(len_bytes);
let len = header.wire_len;
meta.quickack = header.quickack;
reject_oversize_frame(len, self.max_frame_size, "intermediate")?;
// Read data
let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?;
@@ -243,11 +282,23 @@ impl<W> LayeredStream<W> for IntermediateFrameWriter<W> {
/// Reader for secure intermediate MTProto framing (with padding)
pub struct SecureIntermediateFrameReader<R> {
upstream: R,
max_frame_size: usize,
}
impl<R> SecureIntermediateFrameReader<R> {
/// Creates a reader with the default maximum frame size.
pub fn new(upstream: R) -> Self {
Self { upstream }
Self {
upstream,
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
}
}
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
Self {
upstream,
max_frame_size,
}
}
}
@@ -262,17 +313,16 @@ impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
let header = parse_intermediate_header(len_bytes);
let len = header.wire_len;
meta.quickack = header.quickack;
// Read data (including padding)
let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?;
reject_oversize_frame(len, self.max_frame_size, "secure intermediate")?;
let payload_len = secure_payload_len_from_wire_len(len).ok_or_else(|| {
Error::new(
ErrorKind::InvalidData,
format!("Invalid secure frame length: {len}"),
)
})?;
let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?;
data.truncate(payload_len);
Ok((Bytes::from(data), meta))
@@ -321,7 +371,9 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
let padding_len = secure_padding_len(data.len(), &self.rng);
let padding = self.rng.bytes(padding_len);
let total_len = data.len() + padding_len;
let total_len = data.len().checked_add(padding_len).ok_or_else(|| {
Error::new(ErrorKind::InvalidInput, "secure frame length overflow")
})?;
let len = encode_intermediate_header(total_len, meta.quickack).ok_or_else(|| {
Error::new(
ErrorKind::InvalidInput,
@@ -507,15 +559,26 @@ pub enum FrameReaderKind<R> {
}
impl<R: AsyncRead + Unpin> FrameReaderKind<R> {
/// Creates a frame reader with the default maximum frame size.
pub fn new(upstream: R, proto_tag: ProtoTag) -> Self {
Self::with_max_frame_size(upstream, proto_tag, DEFAULT_MAX_FRAME_SIZE)
}
fn with_max_frame_size(
upstream: R,
proto_tag: ProtoTag,
max_frame_size: usize,
) -> Self {
match proto_tag {
ProtoTag::Abridged => FrameReaderKind::Abridged(AbridgedFrameReader::new(upstream)),
ProtoTag::Intermediate => {
FrameReaderKind::Intermediate(IntermediateFrameReader::new(upstream))
}
ProtoTag::Secure => {
FrameReaderKind::SecureIntermediate(SecureIntermediateFrameReader::new(upstream))
}
ProtoTag::Abridged => FrameReaderKind::Abridged(
AbridgedFrameReader::with_max_frame_size(upstream, max_frame_size),
),
ProtoTag::Intermediate => FrameReaderKind::Intermediate(
IntermediateFrameReader::with_max_frame_size(upstream, max_frame_size),
),
ProtoTag::Secure => FrameReaderKind::SecureIntermediate(
SecureIntermediateFrameReader::with_max_frame_size(upstream, max_frame_size),
),
}
}
@@ -569,7 +632,8 @@ mod tests {
use super::*;
use crate::crypto::SecureRandom;
use std::sync::Arc;
use tokio::io::duplex;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, timeout};
fn assert_secure_decoded_payload(decoded: &[u8], original: &[u8]) {
assert!(decoded.starts_with(original));
@@ -672,6 +736,55 @@ mod tests {
assert!(meta.quickack);
}
#[tokio::test]
async fn abridged_reader_rejects_oversize_frame_before_body_read() {
let (mut client, server) = duplex(1024);
let mut reader = AbridgedFrameReader::new(server);
let len_words = (DEFAULT_MAX_FRAME_SIZE / 4) + 1;
let encoded = (len_words as u32).to_le_bytes();
client
.write_all(&[0x7f, encoded[0], encoded[1], encoded[2]])
.await
.unwrap();
let err = timeout(Duration::from_millis(50), reader.read_frame())
.await
.unwrap()
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::InvalidData);
}
#[tokio::test]
async fn intermediate_reader_rejects_oversize_frame_before_body_read() {
let (mut client, server) = duplex(1024);
let mut reader = IntermediateFrameReader::new(server);
let len = encode_intermediate_header(DEFAULT_MAX_FRAME_SIZE + 1, false).unwrap();
client.write_all(&len.to_le_bytes()).await.unwrap();
let err = timeout(Duration::from_millis(50), reader.read_frame())
.await
.unwrap()
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::InvalidData);
}
#[tokio::test]
async fn secure_reader_rejects_oversize_frame_before_body_read() {
let (mut client, server) = duplex(1024);
let mut reader = SecureIntermediateFrameReader::new(server);
let len = encode_intermediate_header(DEFAULT_MAX_FRAME_SIZE + 4, false).unwrap();
client.write_all(&len.to_le_bytes()).await.unwrap();
let err = timeout(Duration::from_millis(50), reader.read_frame())
.await
.unwrap()
.unwrap_err();
assert_eq!(err.kind(), ErrorKind::InvalidData);
}
#[tokio::test]
async fn test_secure_intermediate_padding() {
let (client, server) = duplex(1024);