Flush-response experiments

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
Alexey 2026-02-22 23:53:10 +03:00
parent d552ae84d0
commit 197f9867e0
No known key found for this signature in database
7 changed files with 227 additions and 126 deletions

View File

@ -159,10 +159,13 @@ pub const MAX_TLS_CHUNK_SIZE: usize = 16384 + 256;
/// Generate padding length for Secure Intermediate protocol. /// Generate padding length for Secure Intermediate protocol.
/// Total (data + padding) must not be divisible by 4 per MTProto spec. /// Total (data + padding) must not be divisible by 4 per MTProto spec.
pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize { pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize {
if data_len % 4 == 0 { let rem = data_len % 4;
(rng.range(3) + 1) as usize // 1-3 match rem {
} else { 0 => (rng.range(3) + 1) as usize, // {1, 2, 3}
rng.range(4) as usize // 0-3 1 => rng.range(3) as usize, // {0, 1, 2}
2 => [0usize, 1, 3][rng.range(3) as usize], // {0, 1, 3}
3 => [0usize, 2, 3][rng.range(3) as usize], // {0, 2, 3}
_ => unreachable!(),
} }
} }
@ -332,4 +335,24 @@ mod tests {
assert_eq!(TG_DATACENTERS_V4.len(), 5); assert_eq!(TG_DATACENTERS_V4.len(), 5);
assert_eq!(TG_DATACENTERS_V6.len(), 5); assert_eq!(TG_DATACENTERS_V6.len(), 5);
} }
// Regression test for `secure_padding_len`: for every payload length in 0..1000 the
// returned padding must be at most 3 bytes and must never make (data + padding) a
// multiple of 4.
#[test]
fn secure_padding_never_produces_aligned_total() {
let rng = SecureRandom::new();
for data_len in 0..1000 {
// The padding is drawn from the RNG, so each length is sampled 100 times
// instead of once to exercise the different possible outcomes.
for _ in 0..100 {
let padding = secure_padding_len(data_len, &rng);
// Range check: the function must never return more than 3 padding bytes.
assert!(
padding <= 3,
"padding out of range: data_len={data_len}, padding={padding}"
);
// Core invariant: the padded total must not be divisible by 4.
assert_ne!(
(data_len + padding) % 4,
0,
"invariant violated: data_len={data_len}, padding={padding}, total={}",
data_len + padding
);
}
}
}
} }

View File

@ -74,6 +74,34 @@ where
trace!(conn_id, bytes = data.len(), flags, "ME->C data"); trace!(conn_id, bytes = data.len(), flags, "ME->C data");
stats_clone.add_user_octets_to(&user_clone, data.len() as u64); stats_clone.add_user_octets_to(&user_clone, data.len() as u64);
write_client_payload(&mut writer, proto_tag, flags, &data, rng_clone.as_ref()).await?; write_client_payload(&mut writer, proto_tag, flags, &data, rng_clone.as_ref()).await?;
// Drain all immediately queued ME responses and flush once.
while let Ok(next) = me_rx_task.try_recv() {
match next {
MeResponse::Data { flags, data } => {
trace!(conn_id, bytes = data.len(), flags, "ME->C data (batched)");
stats_clone.add_user_octets_to(&user_clone, data.len() as u64);
write_client_payload(
&mut writer,
proto_tag,
flags,
&data,
rng_clone.as_ref(),
).await?;
}
MeResponse::Ack(confirm) => {
trace!(conn_id, confirm, "ME->C quickack (batched)");
write_client_ack(&mut writer, proto_tag, confirm).await?;
}
MeResponse::Close => {
debug!(conn_id, "ME sent close (batched)");
let _ = writer.flush().await;
return Ok(());
}
}
}
writer.flush().await.map_err(ProxyError::Io)?;
} }
Some(MeResponse::Ack(confirm)) => { Some(MeResponse::Ack(confirm)) => {
trace!(conn_id, confirm, "ME->C quickack"); trace!(conn_id, confirm, "ME->C quickack");
@ -81,6 +109,7 @@ where
} }
Some(MeResponse::Close) => { Some(MeResponse::Close) => {
debug!(conn_id, "ME sent close"); debug!(conn_id, "ME sent close");
let _ = writer.flush().await;
return Ok(()); return Ok(());
} }
None => { None => {
@ -99,8 +128,15 @@ where
let mut main_result: Result<()> = Ok(()); let mut main_result: Result<()> = Ok(());
let mut client_closed = false; let mut client_closed = false;
let mut frame_counter: u64 = 0;
loop { loop {
match read_client_payload(&mut crypto_reader, proto_tag, frame_limit, &user).await { match read_client_payload(
&mut crypto_reader,
proto_tag,
frame_limit,
&user,
&mut frame_counter,
).await {
Ok(Some((payload, quickack))) => { Ok(Some((payload, quickack))) => {
trace!(conn_id, bytes = payload.len(), "C->ME frame"); trace!(conn_id, bytes = payload.len(), "C->ME frame");
stats.add_user_octets_from(&user, payload.len() as u64); stats.add_user_octets_from(&user, payload.len() as u64);
@ -168,11 +204,13 @@ async fn read_client_payload<R>(
proto_tag: ProtoTag, proto_tag: ProtoTag,
max_frame: usize, max_frame: usize,
user: &str, user: &str,
frame_counter: &mut u64,
) -> Result<Option<(Vec<u8>, bool)>> ) -> Result<Option<(Vec<u8>, bool)>>
where where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
{ {
let (len, quickack) = match proto_tag { loop {
let (len, quickack, raw_len_bytes) = match proto_tag {
ProtoTag::Abridged => { ProtoTag::Abridged => {
let mut first = [0u8; 1]; let mut first = [0u8; 1];
match client_reader.read_exact(&mut first).await { match client_reader.read_exact(&mut first).await {
@ -196,7 +234,7 @@ where
let len = len_words let len = len_words
.checked_mul(4) .checked_mul(4)
.ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?; .ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?;
(len, quickack) (len, quickack, None)
} }
ProtoTag::Intermediate | ProtoTag::Secure => { ProtoTag::Intermediate | ProtoTag::Secure => {
let mut len_buf = [0u8; 4]; let mut len_buf = [0u8; 4];
@ -206,19 +244,53 @@ where
Err(e) => return Err(ProxyError::Io(e)), Err(e) => return Err(ProxyError::Io(e)),
} }
let quickack = (len_buf[3] & 0x80) != 0; let quickack = (len_buf[3] & 0x80) != 0;
((u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize, quickack) (
(u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize,
quickack,
Some(len_buf),
)
} }
}; };
if len == 0 {
continue;
}
if len < 4 && proto_tag != ProtoTag::Abridged {
warn!(
user = %user,
len,
proto = ?proto_tag,
"Frame too small — corrupt or probe"
);
return Err(ProxyError::Proxy(format!("Frame too small: {len}")));
}
if len > max_frame { if len > max_frame {
let len_buf = raw_len_bytes.unwrap_or((len as u32).to_le_bytes());
let looks_like_tls = raw_len_bytes
.map(|b| b[0] == 0x16 && b[1] == 0x03)
.unwrap_or(false);
let looks_like_http = raw_len_bytes
.map(|b| matches!(b[0], b'G' | b'P' | b'H' | b'C' | b'D'))
.unwrap_or(false);
warn!( warn!(
user = %user, user = %user,
raw_len = len, raw_len = len,
raw_len_hex = format_args!("0x{:08x}", len), raw_len_hex = format_args!("0x{:08x}", len),
raw_bytes = format_args!(
"{:02x} {:02x} {:02x} {:02x}",
len_buf[0], len_buf[1], len_buf[2], len_buf[3]
),
proto = ?proto_tag, proto = ?proto_tag,
"Frame too large — possible crypto desync or TLS record error" tls_like = looks_like_tls,
http_like = looks_like_http,
frames_ok = *frame_counter,
"Frame too large — crypto desync forensics"
); );
return Err(ProxyError::Proxy(format!("Frame too large: {len} (max {max_frame})"))); return Err(ProxyError::Proxy(format!(
"Frame too large: {len} (max {max_frame}), frames_ok={}",
*frame_counter
)));
} }
let mut payload = vec![0u8; len]; let mut payload = vec![0u8; len];
@ -234,7 +306,9 @@ where
payload.truncate(len - rem); payload.truncate(len - rem);
} }
} }
Ok(Some((payload, quickack))) *frame_counter += 1;
return Ok(Some((payload, quickack)));
}
} }
async fn write_client_payload<W>( async fn write_client_payload<W>(
@ -264,8 +338,11 @@ where
if quickack { if quickack {
first |= 0x80; first |= 0x80;
} }
let mut frame_buf = Vec::with_capacity(1 + data.len());
frame_buf.push(first);
frame_buf.extend_from_slice(data);
client_writer client_writer
.write_all(&[first]) .write_all(&frame_buf)
.await .await
.map_err(ProxyError::Io)?; .map_err(ProxyError::Io)?;
} else if len_words < (1 << 24) { } else if len_words < (1 << 24) {
@ -274,8 +351,11 @@ where
first |= 0x80; first |= 0x80;
} }
let lw = (len_words as u32).to_le_bytes(); let lw = (len_words as u32).to_le_bytes();
let mut frame_buf = Vec::with_capacity(4 + data.len());
frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]);
frame_buf.extend_from_slice(data);
client_writer client_writer
.write_all(&[first, lw[0], lw[1], lw[2]]) .write_all(&frame_buf)
.await .await
.map_err(ProxyError::Io)?; .map_err(ProxyError::Io)?;
} else { } else {
@ -284,11 +364,6 @@ where
data.len() data.len()
))); )));
} }
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
} }
ProtoTag::Intermediate | ProtoTag::Secure => { ProtoTag::Intermediate | ProtoTag::Secure => {
let padding_len = if proto_tag == ProtoTag::Secure { let padding_len = if proto_tag == ProtoTag::Secure {
@ -296,34 +371,23 @@ where
} else { } else {
0 0
}; };
let mut len = (data.len() + padding_len) as u32; let mut len_val = (data.len() + padding_len) as u32;
if quickack { if quickack {
len |= 0x8000_0000; len_val |= 0x8000_0000;
} }
client_writer let total = 4 + data.len() + padding_len;
.write_all(&len.to_le_bytes()) let mut frame_buf = Vec::with_capacity(total);
.await frame_buf.extend_from_slice(&len_val.to_le_bytes());
.map_err(ProxyError::Io)?; frame_buf.extend_from_slice(data);
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
if padding_len > 0 { if padding_len > 0 {
let pad = rng.bytes(padding_len); frame_buf.extend_from_slice(&rng.bytes(padding_len));
}
client_writer client_writer
.write_all(&pad) .write_all(&frame_buf)
.await .await
.map_err(ProxyError::Io)?; .map_err(ProxyError::Io)?;
} }
} }
}
// Avoid unconditional per-frame flush (throughput killer on large downloads).
// Flush only when low-latency ack semantics are requested or when
// CryptoWriter has buffered pending ciphertext that must be drained.
if quickack || client_writer.has_pending() {
client_writer.flush().await.map_err(ProxyError::Io)?;
}
Ok(()) Ok(())
} }

View File

@ -8,7 +8,7 @@ use std::io::{self, Error, ErrorKind};
use std::sync::Arc; use std::sync::Arc;
use tokio_util::codec::{Decoder, Encoder}; use tokio_util::codec::{Decoder, Encoder};
use crate::protocol::constants::ProtoTag; use crate::protocol::constants::{ProtoTag, secure_padding_len};
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait}; use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
@ -303,14 +303,8 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::R
return Ok(()); return Ok(());
} }
// Generate padding to make length not divisible by 4 // Generate padding that keeps total length non-divisible by 4.
let padding_len = if data.len() % 4 == 0 { let padding_len = secure_padding_len(data.len(), rng);
// Add 1-3 bytes to make it non-aligned
(rng.range(3) + 1) as usize
} else {
// Already non-aligned, can add 0-3
rng.range(4) as usize
};
let total_len = data.len() + padding_len; let total_len = data.len() + padding_len;
dst.reserve(4 + total_len); dst.reserve(4 + total_len);

View File

@ -190,11 +190,26 @@ impl RpcWriter {
self.writer.flush().await.map_err(ProxyError::Io) self.writer.flush().await.map_err(ProxyError::Io)
} }
pub(crate) async fn send_keepalive(&mut self, payload: [u8; 4]) -> Result<()> { /// Sends a 4-byte keepalive marker directly into the CBC stream.
// Keepalive is a frame with fl == 4 and 4 bytes payload. /// This is not an RPC frame and must not consume sequence numbers.
let mut frame = Vec::with_capacity(8); pub(crate) async fn send_keepalive(&mut self) -> Result<()> {
frame.extend_from_slice(&4u32.to_le_bytes()); let mut buf = [0u8; 16];
frame.extend_from_slice(&payload); for i in 0..4 {
self.send(&frame).await let start = i * 4;
let end = start + 4;
buf[start..end].copy_from_slice(&PADDING_FILLER);
}
let cipher = AesCbc::new(self.key, self.iv);
let mut v = buf.to_vec();
cipher
.encrypt_in_place(&mut v)
.map_err(|e| ProxyError::Crypto(format!("{e}")))?;
if v.len() >= 16 {
self.iv.copy_from_slice(&v[v.len() - 16..]);
}
self.writer.write_all(&v).await.map_err(ProxyError::Io)?;
self.writer.flush().await.map_err(ProxyError::Io)
} }
} }

View File

@ -23,7 +23,6 @@ use super::reader::reader_loop;
use super::MeResponse; use super::MeResponse;
const ME_ACTIVE_PING_SECS: u64 = 25; const ME_ACTIVE_PING_SECS: u64 = 25;
const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; const ME_ACTIVE_PING_JITTER_SECS: i64 = 5;
const ME_KEEPALIVE_PAYLOAD_LEN: usize = 4;
#[derive(Clone)] #[derive(Clone)]
pub struct MeWriter { pub struct MeWriter {
@ -361,7 +360,6 @@ impl MePool {
// Additional connections up to pool_size total (round-robin across DCs), staggered to de-phase lifecycles. // Additional connections up to pool_size total (round-robin across DCs), staggered to de-phase lifecycles.
if self.me_warmup_stagger_enabled { if self.me_warmup_stagger_enabled {
let mut delay_ms = 0u64;
for (dc, addrs) in dc_addrs.iter() { for (dc, addrs) in dc_addrs.iter() {
for (ip, port) in addrs { for (ip, port) in addrs {
if self.connection_count() >= pool_size { if self.connection_count() >= pool_size {
@ -369,7 +367,7 @@ impl MePool {
} }
let addr = SocketAddr::new(*ip, *port); let addr = SocketAddr::new(*ip, *port);
let jitter = rand::rng().random_range(0..=self.me_warmup_step_jitter.as_millis() as u64); let jitter = rand::rng().random_range(0..=self.me_warmup_step_jitter.as_millis() as u64);
delay_ms = delay_ms.saturating_add(self.me_warmup_step_delay.as_millis() as u64 + jitter); let delay_ms = self.me_warmup_step_delay.as_millis() as u64 + jitter;
tokio::time::sleep(Duration::from_millis(delay_ms)).await; tokio::time::sleep(Duration::from_millis(delay_ms)).await;
if let Err(e) = self.connect_one(addr, rng.as_ref()).await { if let Err(e) = self.connect_one(addr, rng.as_ref()).await {
debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed (staggered)"); debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed (staggered)");
@ -419,7 +417,6 @@ impl MePool {
let draining = Arc::new(AtomicBool::new(false)); let draining = Arc::new(AtomicBool::new(false));
let (tx, mut rx) = mpsc::channel::<WriterCommand>(4096); let (tx, mut rx) = mpsc::channel::<WriterCommand>(4096);
let tx_for_keepalive = tx.clone(); let tx_for_keepalive = tx.clone();
let keepalive_random = self.me_keepalive_payload_random;
let stats = self.stats.clone(); let stats = self.stats.clone();
let mut rpc_writer = RpcWriter { let mut rpc_writer = RpcWriter {
writer: hs.wr, writer: hs.wr,
@ -440,11 +437,7 @@ impl MePool {
if rpc_writer.send_and_flush(&payload).await.is_err() { break; } if rpc_writer.send_and_flush(&payload).await.is_err() { break; }
} }
Some(WriterCommand::Keepalive) => { Some(WriterCommand::Keepalive) => {
let mut payload = [0u8; ME_KEEPALIVE_PAYLOAD_LEN]; match rpc_writer.send_keepalive().await {
if keepalive_random {
rand::rng().fill(&mut payload);
}
match rpc_writer.send_keepalive(payload).await {
Ok(()) => { Ok(()) => {
stats.increment_me_keepalive_sent(); stats.increment_me_keepalive_sent();
} }

View File

@ -33,7 +33,6 @@ pub(crate) async fn reader_loop(
) -> Result<()> { ) -> Result<()> {
let mut raw = enc_leftover; let mut raw = enc_leftover;
let mut expected_seq: i32 = 0; let mut expected_seq: i32 = 0;
let mut crc_errors = 0u32;
let mut seq_mismatch = 0u32; let mut seq_mismatch = 0u32;
loop { loop {
@ -80,13 +79,15 @@ pub(crate) async fn reader_loop(
let frame = dec.split_to(fl); let frame = dec.split_to(fl);
let pe = fl - 4; let pe = fl - 4;
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
if crc32(&frame[..pe]) != ec { let actual_crc = crc32(&frame[..pe]);
warn!("CRC mismatch in data frame"); if actual_crc != ec {
crc_errors += 1; warn!(
if crc_errors > 3 { frame_len = fl,
return Err(ProxyError::Proxy("Too many CRC mismatches".into())); expected_crc = format_args!("0x{ec:08x}"),
} actual_crc = format_args!("0x{actual_crc:08x}"),
continue; "CRC mismatch — CBC crypto desync, aborting ME connection"
);
return Err(ProxyError::Proxy("CRC mismatch (crypto desync)".into()));
} }
let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap()); let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap());

View File

@ -24,6 +24,8 @@ const NUM_DCS: usize = 5;
/// Timeout for individual DC ping attempt /// Timeout for individual DC ping attempt
const DC_PING_TIMEOUT_SECS: u64 = 5; const DC_PING_TIMEOUT_SECS: u64 = 5;
/// Timeout for direct TG DC TCP connect readiness.
const DIRECT_CONNECT_TIMEOUT_SECS: u64 = 10;
// ============= RTT Tracking ============= // ============= RTT Tracking =============
@ -375,7 +377,16 @@ impl UpstreamManager {
let std_stream: std::net::TcpStream = socket.into(); let std_stream: std::net::TcpStream = socket.into();
let stream = TcpStream::from_std(std_stream)?; let stream = TcpStream::from_std(std_stream)?;
stream.writable().await?; let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS);
match tokio::time::timeout(connect_timeout, stream.writable()).await {
Ok(Ok(())) => {}
Ok(Err(e)) => return Err(ProxyError::Io(e)),
Err(_) => {
return Err(ProxyError::ConnectionTimeout {
addr: target.to_string(),
});
}
}
if let Some(e) = stream.take_error()? { if let Some(e) = stream.take_error()? {
return Err(ProxyError::Io(e)); return Err(ProxyError::Io(e));
} }