mirror of
https://github.com/telemt/telemt.git
synced 2026-06-25 12:21:10 +03:00
Harden masking fallback and frame readers after flow sync
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
+134
-21
@@ -11,16 +11,41 @@ use std::io::{Error, ErrorKind, Result};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
const DEFAULT_MAX_FRAME_SIZE: usize = 16 * 1024 * 1024;
|
||||
|
||||
fn reject_oversize_frame(len: usize, max_frame_size: usize, protocol: &str) -> Result<()> {
|
||||
if len > max_frame_size {
|
||||
return Err(Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("{protocol} frame too large: {len} bytes (max {max_frame_size})"),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============= Abridged (Compact) Frame =============
|
||||
|
||||
/// Reader for abridged MTProto framing
|
||||
pub struct AbridgedFrameReader<R> {
|
||||
upstream: R,
|
||||
max_frame_size: usize,
|
||||
}
|
||||
|
||||
impl<R> AbridgedFrameReader<R> {
|
||||
/// Creates a reader with the default maximum frame size.
|
||||
pub fn new(upstream: R) -> Self {
|
||||
Self { upstream }
|
||||
Self {
|
||||
upstream,
|
||||
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
|
||||
}
|
||||
}
|
||||
|
||||
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
|
||||
Self {
|
||||
upstream,
|
||||
max_frame_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,10 +73,12 @@ impl<R: AsyncRead + Unpin> AbridgedFrameReader<R> {
|
||||
len = u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], 0]) as usize;
|
||||
}
|
||||
|
||||
// Length is in 4-byte words
|
||||
let byte_len = len * 4;
|
||||
// Length is in 4-byte words.
|
||||
let byte_len = len
|
||||
.checked_mul(4)
|
||||
.ok_or_else(|| Error::new(ErrorKind::InvalidData, "abridged frame length overflow"))?;
|
||||
reject_oversize_frame(byte_len, self.max_frame_size, "abridged")?;
|
||||
|
||||
// Read data
|
||||
let mut data = vec![0u8; byte_len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
@@ -152,11 +179,23 @@ impl<W> LayeredStream<W> for AbridgedFrameWriter<W> {
|
||||
/// Reader for intermediate MTProto framing
|
||||
pub struct IntermediateFrameReader<R> {
|
||||
upstream: R,
|
||||
max_frame_size: usize,
|
||||
}
|
||||
|
||||
impl<R> IntermediateFrameReader<R> {
|
||||
/// Creates a reader with the default maximum frame size.
|
||||
pub fn new(upstream: R) -> Self {
|
||||
Self { upstream }
|
||||
Self {
|
||||
upstream,
|
||||
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
|
||||
}
|
||||
}
|
||||
|
||||
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
|
||||
Self {
|
||||
upstream,
|
||||
max_frame_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,8 +210,8 @@ impl<R: AsyncRead + Unpin> IntermediateFrameReader<R> {
|
||||
let header = parse_intermediate_header(len_bytes);
|
||||
let len = header.wire_len;
|
||||
meta.quickack = header.quickack;
|
||||
reject_oversize_frame(len, self.max_frame_size, "intermediate")?;
|
||||
|
||||
// Read data
|
||||
let mut data = vec![0u8; len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
@@ -243,11 +282,23 @@ impl<W> LayeredStream<W> for IntermediateFrameWriter<W> {
|
||||
/// Reader for secure intermediate MTProto framing (with padding)
|
||||
pub struct SecureIntermediateFrameReader<R> {
|
||||
upstream: R,
|
||||
max_frame_size: usize,
|
||||
}
|
||||
|
||||
impl<R> SecureIntermediateFrameReader<R> {
|
||||
/// Creates a reader with the default maximum frame size.
|
||||
pub fn new(upstream: R) -> Self {
|
||||
Self { upstream }
|
||||
Self {
|
||||
upstream,
|
||||
max_frame_size: DEFAULT_MAX_FRAME_SIZE,
|
||||
}
|
||||
}
|
||||
|
||||
fn with_max_frame_size(upstream: R, max_frame_size: usize) -> Self {
|
||||
Self {
|
||||
upstream,
|
||||
max_frame_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -262,17 +313,16 @@ impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
|
||||
let header = parse_intermediate_header(len_bytes);
|
||||
let len = header.wire_len;
|
||||
meta.quickack = header.quickack;
|
||||
|
||||
// Read data (including padding)
|
||||
let mut data = vec![0u8; len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
|
||||
reject_oversize_frame(len, self.max_frame_size, "secure intermediate")?;
|
||||
let payload_len = secure_payload_len_from_wire_len(len).ok_or_else(|| {
|
||||
Error::new(
|
||||
ErrorKind::InvalidData,
|
||||
format!("Invalid secure frame length: {len}"),
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut data = vec![0u8; len];
|
||||
self.upstream.read_exact(&mut data).await?;
|
||||
data.truncate(payload_len);
|
||||
|
||||
Ok((Bytes::from(data), meta))
|
||||
@@ -321,7 +371,9 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
|
||||
let padding_len = secure_padding_len(data.len(), &self.rng);
|
||||
let padding = self.rng.bytes(padding_len);
|
||||
|
||||
let total_len = data.len() + padding_len;
|
||||
let total_len = data.len().checked_add(padding_len).ok_or_else(|| {
|
||||
Error::new(ErrorKind::InvalidInput, "secure frame length overflow")
|
||||
})?;
|
||||
let len = encode_intermediate_header(total_len, meta.quickack).ok_or_else(|| {
|
||||
Error::new(
|
||||
ErrorKind::InvalidInput,
|
||||
@@ -507,15 +559,26 @@ pub enum FrameReaderKind<R> {
|
||||
}
|
||||
|
||||
impl<R: AsyncRead + Unpin> FrameReaderKind<R> {
|
||||
/// Creates a frame reader with the default maximum frame size.
|
||||
pub fn new(upstream: R, proto_tag: ProtoTag) -> Self {
|
||||
Self::with_max_frame_size(upstream, proto_tag, DEFAULT_MAX_FRAME_SIZE)
|
||||
}
|
||||
|
||||
fn with_max_frame_size(
|
||||
upstream: R,
|
||||
proto_tag: ProtoTag,
|
||||
max_frame_size: usize,
|
||||
) -> Self {
|
||||
match proto_tag {
|
||||
ProtoTag::Abridged => FrameReaderKind::Abridged(AbridgedFrameReader::new(upstream)),
|
||||
ProtoTag::Intermediate => {
|
||||
FrameReaderKind::Intermediate(IntermediateFrameReader::new(upstream))
|
||||
}
|
||||
ProtoTag::Secure => {
|
||||
FrameReaderKind::SecureIntermediate(SecureIntermediateFrameReader::new(upstream))
|
||||
}
|
||||
ProtoTag::Abridged => FrameReaderKind::Abridged(
|
||||
AbridgedFrameReader::with_max_frame_size(upstream, max_frame_size),
|
||||
),
|
||||
ProtoTag::Intermediate => FrameReaderKind::Intermediate(
|
||||
IntermediateFrameReader::with_max_frame_size(upstream, max_frame_size),
|
||||
),
|
||||
ProtoTag::Secure => FrameReaderKind::SecureIntermediate(
|
||||
SecureIntermediateFrameReader::with_max_frame_size(upstream, max_frame_size),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -569,7 +632,8 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::crypto::SecureRandom;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::duplex;
|
||||
use tokio::io::{AsyncWriteExt, duplex};
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn assert_secure_decoded_payload(decoded: &[u8], original: &[u8]) {
|
||||
assert!(decoded.starts_with(original));
|
||||
@@ -672,6 +736,55 @@ mod tests {
|
||||
assert!(meta.quickack);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn abridged_reader_rejects_oversize_frame_before_body_read() {
|
||||
let (mut client, server) = duplex(1024);
|
||||
let mut reader = AbridgedFrameReader::new(server);
|
||||
let len_words = (DEFAULT_MAX_FRAME_SIZE / 4) + 1;
|
||||
let encoded = (len_words as u32).to_le_bytes();
|
||||
|
||||
client
|
||||
.write_all(&[0x7f, encoded[0], encoded[1], encoded[2]])
|
||||
.await
|
||||
.unwrap();
|
||||
let err = timeout(Duration::from_millis(50), reader.read_frame())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap_err();
|
||||
|
||||
assert_eq!(err.kind(), ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn intermediate_reader_rejects_oversize_frame_before_body_read() {
|
||||
let (mut client, server) = duplex(1024);
|
||||
let mut reader = IntermediateFrameReader::new(server);
|
||||
let len = encode_intermediate_header(DEFAULT_MAX_FRAME_SIZE + 1, false).unwrap();
|
||||
|
||||
client.write_all(&len.to_le_bytes()).await.unwrap();
|
||||
let err = timeout(Duration::from_millis(50), reader.read_frame())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap_err();
|
||||
|
||||
assert_eq!(err.kind(), ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn secure_reader_rejects_oversize_frame_before_body_read() {
|
||||
let (mut client, server) = duplex(1024);
|
||||
let mut reader = SecureIntermediateFrameReader::new(server);
|
||||
let len = encode_intermediate_header(DEFAULT_MAX_FRAME_SIZE + 4, false).unwrap();
|
||||
|
||||
client.write_all(&len.to_le_bytes()).await.unwrap();
|
||||
let err = timeout(Duration::from_millis(50), reader.read_frame())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap_err();
|
||||
|
||||
assert_eq!(err.kind(), ErrorKind::InvalidData);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_secure_intermediate_padding() {
|
||||
let (client, server) = duplex(1024);
|
||||
|
||||
Reference in New Issue
Block a user