Flush-response experiments

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
Alexey 2026-02-22 23:53:10 +03:00
parent d552ae84d0
commit 197f9867e0
No known key found for this signature in database
7 changed files with 227 additions and 126 deletions

View File

@ -159,10 +159,13 @@ pub const MAX_TLS_CHUNK_SIZE: usize = 16384 + 256;
/// Generate padding length for Secure Intermediate protocol.
/// Total (data + padding) must not be divisible by 4 per MTProto spec.
pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize {
if data_len % 4 == 0 {
(rng.range(3) + 1) as usize // 1-3
} else {
rng.range(4) as usize // 0-3
let rem = data_len % 4;
match rem {
0 => (rng.range(3) + 1) as usize, // {1, 2, 3}
1 => rng.range(3) as usize, // {0, 1, 2}
2 => [0usize, 1, 3][rng.range(3) as usize], // {0, 1, 3}
3 => [0usize, 2, 3][rng.range(3) as usize], // {0, 2, 3}
_ => unreachable!(),
}
}
@ -332,4 +335,24 @@ mod tests {
assert_eq!(TG_DATACENTERS_V4.len(), 5);
assert_eq!(TG_DATACENTERS_V6.len(), 5);
}
#[test]
fn secure_padding_never_produces_aligned_total() {
let rng = SecureRandom::new();
for data_len in 0..1000 {
for _ in 0..100 {
let padding = secure_padding_len(data_len, &rng);
assert!(
padding <= 3,
"padding out of range: data_len={data_len}, padding={padding}"
);
assert_ne!(
(data_len + padding) % 4,
0,
"invariant violated: data_len={data_len}, padding={padding}, total={}",
data_len + padding
);
}
}
}
}

View File

@ -74,6 +74,34 @@ where
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
stats_clone.add_user_octets_to(&user_clone, data.len() as u64);
write_client_payload(&mut writer, proto_tag, flags, &data, rng_clone.as_ref()).await?;
// Drain all immediately queued ME responses and flush once.
while let Ok(next) = me_rx_task.try_recv() {
match next {
MeResponse::Data { flags, data } => {
trace!(conn_id, bytes = data.len(), flags, "ME->C data (batched)");
stats_clone.add_user_octets_to(&user_clone, data.len() as u64);
write_client_payload(
&mut writer,
proto_tag,
flags,
&data,
rng_clone.as_ref(),
).await?;
}
MeResponse::Ack(confirm) => {
trace!(conn_id, confirm, "ME->C quickack (batched)");
write_client_ack(&mut writer, proto_tag, confirm).await?;
}
MeResponse::Close => {
debug!(conn_id, "ME sent close (batched)");
let _ = writer.flush().await;
return Ok(());
}
}
}
writer.flush().await.map_err(ProxyError::Io)?;
}
Some(MeResponse::Ack(confirm)) => {
trace!(conn_id, confirm, "ME->C quickack");
@ -81,6 +109,7 @@ where
}
Some(MeResponse::Close) => {
debug!(conn_id, "ME sent close");
let _ = writer.flush().await;
return Ok(());
}
None => {
@ -99,8 +128,15 @@ where
let mut main_result: Result<()> = Ok(());
let mut client_closed = false;
let mut frame_counter: u64 = 0;
loop {
match read_client_payload(&mut crypto_reader, proto_tag, frame_limit, &user).await {
match read_client_payload(
&mut crypto_reader,
proto_tag,
frame_limit,
&user,
&mut frame_counter,
).await {
Ok(Some((payload, quickack))) => {
trace!(conn_id, bytes = payload.len(), "C->ME frame");
stats.add_user_octets_from(&user, payload.len() as u64);
@ -168,73 +204,111 @@ async fn read_client_payload<R>(
proto_tag: ProtoTag,
max_frame: usize,
user: &str,
frame_counter: &mut u64,
) -> Result<Option<(Vec<u8>, bool)>>
where
R: AsyncRead + Unpin + Send + 'static,
{
let (len, quickack) = match proto_tag {
ProtoTag::Abridged => {
let mut first = [0u8; 1];
match client_reader.read_exact(&mut first).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ProxyError::Io(e)),
loop {
let (len, quickack, raw_len_bytes) = match proto_tag {
ProtoTag::Abridged => {
let mut first = [0u8; 1];
match client_reader.read_exact(&mut first).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ProxyError::Io(e)),
}
let quickack = (first[0] & 0x80) != 0;
let len_words = if (first[0] & 0x7f) == 0x7f {
let mut ext = [0u8; 3];
client_reader
.read_exact(&mut ext)
.await
.map_err(ProxyError::Io)?;
u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize
} else {
(first[0] & 0x7f) as usize
};
let len = len_words
.checked_mul(4)
.ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?;
(len, quickack, None)
}
let quickack = (first[0] & 0x80) != 0;
let len_words = if (first[0] & 0x7f) == 0x7f {
let mut ext = [0u8; 3];
client_reader
.read_exact(&mut ext)
.await
.map_err(ProxyError::Io)?;
u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize
} else {
(first[0] & 0x7f) as usize
};
let len = len_words
.checked_mul(4)
.ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?;
(len, quickack)
}
ProtoTag::Intermediate | ProtoTag::Secure => {
let mut len_buf = [0u8; 4];
match client_reader.read_exact(&mut len_buf).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ProxyError::Io(e)),
ProtoTag::Intermediate | ProtoTag::Secure => {
let mut len_buf = [0u8; 4];
match client_reader.read_exact(&mut len_buf).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(ProxyError::Io(e)),
}
let quickack = (len_buf[3] & 0x80) != 0;
(
(u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize,
quickack,
Some(len_buf),
)
}
let quickack = (len_buf[3] & 0x80) != 0;
((u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize, quickack)
};
if len == 0 {
continue;
}
};
if len > max_frame {
warn!(
user = %user,
raw_len = len,
raw_len_hex = format_args!("0x{:08x}", len),
proto = ?proto_tag,
"Frame too large — possible crypto desync or TLS record error"
);
return Err(ProxyError::Proxy(format!("Frame too large: {len} (max {max_frame})")));
}
let mut payload = vec![0u8; len];
client_reader
.read_exact(&mut payload)
.await
.map_err(ProxyError::Io)?;
// Secure Intermediate: remove random padding (last len%4 bytes)
if proto_tag == ProtoTag::Secure {
let rem = len % 4;
if rem != 0 && payload.len() >= rem {
payload.truncate(len - rem);
if len < 4 && proto_tag != ProtoTag::Abridged {
warn!(
user = %user,
len,
proto = ?proto_tag,
"Frame too small — corrupt or probe"
);
return Err(ProxyError::Proxy(format!("Frame too small: {len}")));
}
if len > max_frame {
let len_buf = raw_len_bytes.unwrap_or((len as u32).to_le_bytes());
let looks_like_tls = raw_len_bytes
.map(|b| b[0] == 0x16 && b[1] == 0x03)
.unwrap_or(false);
let looks_like_http = raw_len_bytes
.map(|b| matches!(b[0], b'G' | b'P' | b'H' | b'C' | b'D'))
.unwrap_or(false);
warn!(
user = %user,
raw_len = len,
raw_len_hex = format_args!("0x{:08x}", len),
raw_bytes = format_args!(
"{:02x} {:02x} {:02x} {:02x}",
len_buf[0], len_buf[1], len_buf[2], len_buf[3]
),
proto = ?proto_tag,
tls_like = looks_like_tls,
http_like = looks_like_http,
frames_ok = *frame_counter,
"Frame too large — crypto desync forensics"
);
return Err(ProxyError::Proxy(format!(
"Frame too large: {len} (max {max_frame}), frames_ok={}",
*frame_counter
)));
}
let mut payload = vec![0u8; len];
client_reader
.read_exact(&mut payload)
.await
.map_err(ProxyError::Io)?;
// Secure Intermediate: remove random padding (last len%4 bytes)
if proto_tag == ProtoTag::Secure {
let rem = len % 4;
if rem != 0 && payload.len() >= rem {
payload.truncate(len - rem);
}
}
*frame_counter += 1;
return Ok(Some((payload, quickack)));
}
Ok(Some((payload, quickack)))
}
async fn write_client_payload<W>(
@ -264,8 +338,11 @@ where
if quickack {
first |= 0x80;
}
let mut frame_buf = Vec::with_capacity(1 + data.len());
frame_buf.push(first);
frame_buf.extend_from_slice(data);
client_writer
.write_all(&[first])
.write_all(&frame_buf)
.await
.map_err(ProxyError::Io)?;
} else if len_words < (1 << 24) {
@ -274,8 +351,11 @@ where
first |= 0x80;
}
let lw = (len_words as u32).to_le_bytes();
let mut frame_buf = Vec::with_capacity(4 + data.len());
frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]);
frame_buf.extend_from_slice(data);
client_writer
.write_all(&[first, lw[0], lw[1], lw[2]])
.write_all(&frame_buf)
.await
.map_err(ProxyError::Io)?;
} else {
@ -284,11 +364,6 @@ where
data.len()
)));
}
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
}
ProtoTag::Intermediate | ProtoTag::Secure => {
let padding_len = if proto_tag == ProtoTag::Secure {
@ -296,35 +371,24 @@ where
} else {
0
};
let mut len = (data.len() + padding_len) as u32;
let mut len_val = (data.len() + padding_len) as u32;
if quickack {
len |= 0x8000_0000;
len_val |= 0x8000_0000;
}
client_writer
.write_all(&len.to_le_bytes())
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
let total = 4 + data.len() + padding_len;
let mut frame_buf = Vec::with_capacity(total);
frame_buf.extend_from_slice(&len_val.to_le_bytes());
frame_buf.extend_from_slice(data);
if padding_len > 0 {
let pad = rng.bytes(padding_len);
client_writer
.write_all(&pad)
.await
.map_err(ProxyError::Io)?;
frame_buf.extend_from_slice(&rng.bytes(padding_len));
}
client_writer
.write_all(&frame_buf)
.await
.map_err(ProxyError::Io)?;
}
}
// Avoid unconditional per-frame flush (throughput killer on large downloads).
// Flush only when low-latency ack semantics are requested or when
// CryptoWriter has buffered pending ciphertext that must be drained.
if quickack || client_writer.has_pending() {
client_writer.flush().await.map_err(ProxyError::Io)?;
}
Ok(())
}

View File

@ -8,7 +8,7 @@ use std::io::{self, Error, ErrorKind};
use std::sync::Arc;
use tokio_util::codec::{Decoder, Encoder};
use crate::protocol::constants::ProtoTag;
use crate::protocol::constants::{ProtoTag, secure_padding_len};
use crate::crypto::SecureRandom;
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
@ -303,14 +303,8 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::R
return Ok(());
}
// Generate padding to make length not divisible by 4
let padding_len = if data.len() % 4 == 0 {
// Add 1-3 bytes to make it non-aligned
(rng.range(3) + 1) as usize
} else {
// Already non-aligned, can add 0-3
rng.range(4) as usize
};
// Generate padding that keeps total length non-divisible by 4.
let padding_len = secure_padding_len(data.len(), rng);
let total_len = data.len() + padding_len;
dst.reserve(4 + total_len);

View File

@ -190,11 +190,26 @@ impl RpcWriter {
self.writer.flush().await.map_err(ProxyError::Io)
}
pub(crate) async fn send_keepalive(&mut self, payload: [u8; 4]) -> Result<()> {
// Keepalive is a frame with fl == 4 and 4 bytes payload.
let mut frame = Vec::with_capacity(8);
frame.extend_from_slice(&4u32.to_le_bytes());
frame.extend_from_slice(&payload);
self.send(&frame).await
/// Sends a 4-byte keepalive marker directly into the CBC stream.
/// This is not an RPC frame and must not consume sequence numbers.
pub(crate) async fn send_keepalive(&mut self) -> Result<()> {
let mut buf = [0u8; 16];
for i in 0..4 {
let start = i * 4;
let end = start + 4;
buf[start..end].copy_from_slice(&PADDING_FILLER);
}
let cipher = AesCbc::new(self.key, self.iv);
let mut v = buf.to_vec();
cipher
.encrypt_in_place(&mut v)
.map_err(|e| ProxyError::Crypto(format!("{e}")))?;
if v.len() >= 16 {
self.iv.copy_from_slice(&v[v.len() - 16..]);
}
self.writer.write_all(&v).await.map_err(ProxyError::Io)?;
self.writer.flush().await.map_err(ProxyError::Io)
}
}

View File

@ -23,7 +23,6 @@ use super::reader::reader_loop;
use super::MeResponse;
const ME_ACTIVE_PING_SECS: u64 = 25;
const ME_ACTIVE_PING_JITTER_SECS: i64 = 5;
const ME_KEEPALIVE_PAYLOAD_LEN: usize = 4;
#[derive(Clone)]
pub struct MeWriter {
@ -361,7 +360,6 @@ impl MePool {
// Additional connections up to pool_size total (round-robin across DCs), staggered to de-phase lifecycles.
if self.me_warmup_stagger_enabled {
let mut delay_ms = 0u64;
for (dc, addrs) in dc_addrs.iter() {
for (ip, port) in addrs {
if self.connection_count() >= pool_size {
@ -369,7 +367,7 @@ impl MePool {
}
let addr = SocketAddr::new(*ip, *port);
let jitter = rand::rng().random_range(0..=self.me_warmup_step_jitter.as_millis() as u64);
delay_ms = delay_ms.saturating_add(self.me_warmup_step_delay.as_millis() as u64 + jitter);
let delay_ms = self.me_warmup_step_delay.as_millis() as u64 + jitter;
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
if let Err(e) = self.connect_one(addr, rng.as_ref()).await {
debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed (staggered)");
@ -419,7 +417,6 @@ impl MePool {
let draining = Arc::new(AtomicBool::new(false));
let (tx, mut rx) = mpsc::channel::<WriterCommand>(4096);
let tx_for_keepalive = tx.clone();
let keepalive_random = self.me_keepalive_payload_random;
let stats = self.stats.clone();
let mut rpc_writer = RpcWriter {
writer: hs.wr,
@ -440,11 +437,7 @@ impl MePool {
if rpc_writer.send_and_flush(&payload).await.is_err() { break; }
}
Some(WriterCommand::Keepalive) => {
let mut payload = [0u8; ME_KEEPALIVE_PAYLOAD_LEN];
if keepalive_random {
rand::rng().fill(&mut payload);
}
match rpc_writer.send_keepalive(payload).await {
match rpc_writer.send_keepalive().await {
Ok(()) => {
stats.increment_me_keepalive_sent();
}

View File

@ -33,7 +33,6 @@ pub(crate) async fn reader_loop(
) -> Result<()> {
let mut raw = enc_leftover;
let mut expected_seq: i32 = 0;
let mut crc_errors = 0u32;
let mut seq_mismatch = 0u32;
loop {
@ -80,13 +79,15 @@ pub(crate) async fn reader_loop(
let frame = dec.split_to(fl);
let pe = fl - 4;
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
if crc32(&frame[..pe]) != ec {
warn!("CRC mismatch in data frame");
crc_errors += 1;
if crc_errors > 3 {
return Err(ProxyError::Proxy("Too many CRC mismatches".into()));
}
continue;
let actual_crc = crc32(&frame[..pe]);
if actual_crc != ec {
warn!(
frame_len = fl,
expected_crc = format_args!("0x{ec:08x}"),
actual_crc = format_args!("0x{actual_crc:08x}"),
"CRC mismatch — CBC crypto desync, aborting ME connection"
);
return Err(ProxyError::Proxy("CRC mismatch (crypto desync)".into()));
}
let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap());

View File

@ -24,6 +24,8 @@ const NUM_DCS: usize = 5;
/// Timeout for individual DC ping attempt
const DC_PING_TIMEOUT_SECS: u64 = 5;
/// Timeout for direct TG DC TCP connect readiness.
const DIRECT_CONNECT_TIMEOUT_SECS: u64 = 10;
// ============= RTT Tracking =============
@ -375,7 +377,16 @@ impl UpstreamManager {
let std_stream: std::net::TcpStream = socket.into();
let stream = TcpStream::from_std(std_stream)?;
stream.writable().await?;
let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS);
match tokio::time::timeout(connect_timeout, stream.writable()).await {
Ok(Ok(())) => {}
Ok(Err(e)) => return Err(ProxyError::Io(e)),
Err(_) => {
return Err(ProxyError::ConnectionTimeout {
addr: target.to_string(),
});
}
}
if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e));
}