diff --git a/src/protocol/constants.rs b/src/protocol/constants.rs index 826f2b2..9cbe633 100644 --- a/src/protocol/constants.rs +++ b/src/protocol/constants.rs @@ -159,10 +159,13 @@ pub const MAX_TLS_CHUNK_SIZE: usize = 16384 + 256; /// Generate padding length for Secure Intermediate protocol. /// Total (data + padding) must not be divisible by 4 per MTProto spec. pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize { - if data_len % 4 == 0 { - (rng.range(3) + 1) as usize // 1-3 - } else { - rng.range(4) as usize // 0-3 + let rem = data_len % 4; + match rem { + 0 => (rng.range(3) + 1) as usize, // {1, 2, 3} + 1 => rng.range(3) as usize, // {0, 1, 2} + 2 => [0usize, 1, 3][rng.range(3) as usize], // {0, 1, 3} + 3 => [0usize, 2, 3][rng.range(3) as usize], // {0, 2, 3} + _ => unreachable!(), } } @@ -332,4 +335,24 @@ mod tests { assert_eq!(TG_DATACENTERS_V4.len(), 5); assert_eq!(TG_DATACENTERS_V6.len(), 5); } + + #[test] + fn secure_padding_never_produces_aligned_total() { + let rng = SecureRandom::new(); + for data_len in 0..1000 { + for _ in 0..100 { + let padding = secure_padding_len(data_len, &rng); + assert!( + padding <= 3, + "padding out of range: data_len={data_len}, padding={padding}" + ); + assert_ne!( + (data_len + padding) % 4, + 0, + "invariant violated: data_len={data_len}, padding={padding}, total={}", + data_len + padding + ); + } + } + } } diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 0735d01..fe4e219 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -74,6 +74,34 @@ where trace!(conn_id, bytes = data.len(), flags, "ME->C data"); stats_clone.add_user_octets_to(&user_clone, data.len() as u64); write_client_payload(&mut writer, proto_tag, flags, &data, rng_clone.as_ref()).await?; + + // Drain all immediately queued ME responses and flush once. + while let Ok(next) = me_rx_task.try_recv() { + match next { + MeResponse::Data { flags, data } => { + trace!(conn_id, bytes = data.len(), flags, "ME->C data (batched)"); + stats_clone.add_user_octets_to(&user_clone, data.len() as u64); + write_client_payload( + &mut writer, + proto_tag, + flags, + &data, + rng_clone.as_ref(), + ).await?; + } + MeResponse::Ack(confirm) => { + trace!(conn_id, confirm, "ME->C quickack (batched)"); + write_client_ack(&mut writer, proto_tag, confirm).await?; + } + MeResponse::Close => { + debug!(conn_id, "ME sent close (batched)"); + let _ = writer.flush().await; + return Ok(()); + } + } + } + + writer.flush().await.map_err(ProxyError::Io)?; } Some(MeResponse::Ack(confirm)) => { trace!(conn_id, confirm, "ME->C quickack"); @@ -81,6 +109,7 @@ where } Some(MeResponse::Close) => { debug!(conn_id, "ME sent close"); + let _ = writer.flush().await; return Ok(()); } None => { @@ -99,8 +128,15 @@ where let mut main_result: Result<()> = Ok(()); let mut client_closed = false; + let mut frame_counter: u64 = 0; loop { - match read_client_payload(&mut crypto_reader, proto_tag, frame_limit, &user).await { + match read_client_payload( + &mut crypto_reader, + proto_tag, + frame_limit, + &user, + &mut frame_counter, + ).await { Ok(Some((payload, quickack))) => { trace!(conn_id, bytes = payload.len(), "C->ME frame"); stats.add_user_octets_from(&user, payload.len() as u64); @@ -168,73 +204,111 @@ async fn read_client_payload( proto_tag: ProtoTag, max_frame: usize, user: &str, + frame_counter: &mut u64, ) -> Result, bool)>> where R: AsyncRead + Unpin + Send + 'static, { - let (len, quickack) = match proto_tag { - ProtoTag::Abridged => { - let mut first = [0u8; 1]; - match client_reader.read_exact(&mut first).await { - Ok(_) => {} - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ProxyError::Io(e)), + loop { + let (len, quickack, raw_len_bytes) = match proto_tag { + ProtoTag::Abridged => { + let mut first = [0u8; 1]; + match client_reader.read_exact(&mut first).await { + Ok(_) => {} + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), + Err(e) => return Err(ProxyError::Io(e)), + } + + let quickack = (first[0] & 0x80) != 0; + let len_words = if (first[0] & 0x7f) == 0x7f { + let mut ext = [0u8; 3]; + client_reader + .read_exact(&mut ext) + .await + .map_err(ProxyError::Io)?; + u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize + } else { + (first[0] & 0x7f) as usize + }; + + let len = len_words + .checked_mul(4) + .ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?; + (len, quickack, None) } - - let quickack = (first[0] & 0x80) != 0; - let len_words = if (first[0] & 0x7f) == 0x7f { - let mut ext = [0u8; 3]; - client_reader - .read_exact(&mut ext) - .await - .map_err(ProxyError::Io)?; - u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize - } else { - (first[0] & 0x7f) as usize - }; - - let len = len_words - .checked_mul(4) - .ok_or_else(|| ProxyError::Proxy("Abridged frame length overflow".into()))?; - (len, quickack) - } - ProtoTag::Intermediate | ProtoTag::Secure => { - let mut len_buf = [0u8; 4]; - match client_reader.read_exact(&mut len_buf).await { - Ok(_) => {} - Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(ProxyError::Io(e)), + ProtoTag::Intermediate | ProtoTag::Secure => { + let mut len_buf = [0u8; 4]; + match client_reader.read_exact(&mut len_buf).await { + Ok(_) => {} + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None), + Err(e) => return Err(ProxyError::Io(e)), + } + let quickack = (len_buf[3] & 0x80) != 0; + ( + (u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize, + quickack, + Some(len_buf), + ) } - let quickack = (len_buf[3] & 0x80) != 0; - ((u32::from_le_bytes(len_buf) & 0x7fff_ffff) as usize, quickack) + }; + + if len == 0 { + continue; } - }; - - if len > max_frame { - warn!( - user = %user, - raw_len = len, - raw_len_hex = format_args!("0x{:08x}", len), - proto = ?proto_tag, - "Frame too large — possible crypto desync or TLS record error" - ); - return Err(ProxyError::Proxy(format!("Frame too large: {len} (max {max_frame})"))); - } - - let mut payload = vec![0u8; len]; - client_reader - .read_exact(&mut payload) - .await - .map_err(ProxyError::Io)?; - - // Secure Intermediate: remove random padding (last len%4 bytes) - if proto_tag == ProtoTag::Secure { - let rem = len % 4; - if rem != 0 && payload.len() >= rem { - payload.truncate(len - rem); + if len < 4 && proto_tag != ProtoTag::Abridged { + warn!( + user = %user, + len, + proto = ?proto_tag, + "Frame too small — corrupt or probe" + ); + return Err(ProxyError::Proxy(format!("Frame too small: {len}"))); } + + if len > max_frame { + let len_buf = raw_len_bytes.unwrap_or((len as u32).to_le_bytes()); + let looks_like_tls = raw_len_bytes + .map(|b| b[0] == 0x16 && b[1] == 0x03) + .unwrap_or(false); + let looks_like_http = raw_len_bytes + .map(|b| matches!(b[0], b'G' | b'P' | b'H' | b'C' | b'D')) + .unwrap_or(false); + warn!( + user = %user, + raw_len = len, + raw_len_hex = format_args!("0x{:08x}", len), + raw_bytes = format_args!( + "{:02x} {:02x} {:02x} {:02x}", + len_buf[0], len_buf[1], len_buf[2], len_buf[3] + ), + proto = ?proto_tag, + tls_like = looks_like_tls, + http_like = looks_like_http, + frames_ok = *frame_counter, + "Frame too large — crypto desync forensics" + ); + return Err(ProxyError::Proxy(format!( + "Frame too large: {len} (max {max_frame}), frames_ok={}", + *frame_counter + ))); + } + + let mut payload = vec![0u8; len]; + client_reader + .read_exact(&mut payload) + .await + .map_err(ProxyError::Io)?; + + // Secure Intermediate: remove random padding (last len%4 bytes) + if proto_tag == ProtoTag::Secure { + let rem = len % 4; + if rem != 0 && payload.len() >= rem { + payload.truncate(len - rem); + } + } + *frame_counter += 1; + return Ok(Some((payload, quickack))); } - Ok(Some((payload, quickack))) } async fn write_client_payload( @@ -264,8 +338,11 @@ where if quickack { first |= 0x80; } + let mut frame_buf = Vec::with_capacity(1 + data.len()); + frame_buf.push(first); + frame_buf.extend_from_slice(data); client_writer - .write_all(&[first]) + .write_all(&frame_buf) .await .map_err(ProxyError::Io)?; } else if len_words < (1 << 24) { @@ -274,8 +351,11 @@ where first |= 0x80; } let lw = (len_words as u32).to_le_bytes(); + let mut frame_buf = Vec::with_capacity(4 + data.len()); + frame_buf.extend_from_slice(&[first, lw[0], lw[1], lw[2]]); + frame_buf.extend_from_slice(data); client_writer - .write_all(&[first, lw[0], lw[1], lw[2]]) + .write_all(&frame_buf) .await .map_err(ProxyError::Io)?; } else { @@ -284,11 +364,6 @@ where data.len() ))); } - - client_writer - .write_all(data) - .await - .map_err(ProxyError::Io)?; } ProtoTag::Intermediate | ProtoTag::Secure => { let padding_len = if proto_tag == ProtoTag::Secure { @@ -296,35 +371,24 @@ where } else { 0 }; - let mut len = (data.len() + padding_len) as u32; + let mut len_val = (data.len() + padding_len) as u32; if quickack { - len |= 0x8000_0000; + len_val |= 0x8000_0000; } - client_writer - .write_all(&len.to_le_bytes()) - .await - .map_err(ProxyError::Io)?; - client_writer - .write_all(data) - .await - .map_err(ProxyError::Io)?; + let total = 4 + data.len() + padding_len; + let mut frame_buf = Vec::with_capacity(total); + frame_buf.extend_from_slice(&len_val.to_le_bytes()); + frame_buf.extend_from_slice(data); if padding_len > 0 { - let pad = rng.bytes(padding_len); - client_writer - .write_all(&pad) - .await - .map_err(ProxyError::Io)?; + frame_buf.extend_from_slice(&rng.bytes(padding_len)); } + client_writer + .write_all(&frame_buf) + .await + .map_err(ProxyError::Io)?; } } - // Avoid unconditional per-frame flush (throughput killer on large downloads). - // Flush only when low-latency ack semantics are requested or when - // CryptoWriter has buffered pending ciphertext that must be drained. - if quickack || client_writer.has_pending() { - client_writer.flush().await.map_err(ProxyError::Io)?; - } - Ok(()) } diff --git a/src/stream/frame_codec.rs b/src/stream/frame_codec.rs index 30bcc95..7547ae6 100644 --- a/src/stream/frame_codec.rs +++ b/src/stream/frame_codec.rs @@ -8,7 +8,7 @@ use std::io::{self, Error, ErrorKind}; use std::sync::Arc; use tokio_util::codec::{Decoder, Encoder}; -use crate::protocol::constants::ProtoTag; +use crate::protocol::constants::{ProtoTag, secure_padding_len}; use crate::crypto::SecureRandom; use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait}; @@ -303,14 +303,8 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::R return Ok(()); } - // Generate padding to make length not divisible by 4 - let padding_len = if data.len() % 4 == 0 { - // Add 1-3 bytes to make it non-aligned - (rng.range(3) + 1) as usize - } else { - // Already non-aligned, can add 0-3 - rng.range(4) as usize - }; + // Generate padding that keeps total length non-divisible by 4. + let padding_len = secure_padding_len(data.len(), rng); let total_len = data.len() + padding_len; dst.reserve(4 + total_len); @@ -625,4 +619,4 @@ mod tests { let result = codec.decode(&mut buf); assert!(result.is_err()); } -} \ No newline at end of file +} diff --git a/src/transport/middle_proxy/codec.rs b/src/transport/middle_proxy/codec.rs index 82f0960..1dccede 100644 --- a/src/transport/middle_proxy/codec.rs +++ b/src/transport/middle_proxy/codec.rs @@ -190,11 +190,26 @@ impl RpcWriter { self.writer.flush().await.map_err(ProxyError::Io) } - pub(crate) async fn send_keepalive(&mut self, payload: [u8; 4]) -> Result<()> { - // Keepalive is a frame with fl == 4 and 4 bytes payload. - let mut frame = Vec::with_capacity(8); - frame.extend_from_slice(&4u32.to_le_bytes()); - frame.extend_from_slice(&payload); - self.send(&frame).await + /// Sends a 4-byte keepalive marker directly into the CBC stream. + /// This is not an RPC frame and must not consume sequence numbers. + pub(crate) async fn send_keepalive(&mut self) -> Result<()> { + let mut buf = [0u8; 16]; + for i in 0..4 { + let start = i * 4; + let end = start + 4; + buf[start..end].copy_from_slice(&PADDING_FILLER); + } + + let cipher = AesCbc::new(self.key, self.iv); + let mut v = buf.to_vec(); + cipher + .encrypt_in_place(&mut v) + .map_err(|e| ProxyError::Crypto(format!("{e}")))?; + + if v.len() >= 16 { + self.iv.copy_from_slice(&v[v.len() - 16..]); + } + self.writer.write_all(&v).await.map_err(ProxyError::Io)?; + self.writer.flush().await.map_err(ProxyError::Io) } } diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 3572671..f65edd6 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -23,7 +23,6 @@ use super::reader::reader_loop; use super::MeResponse; const ME_ACTIVE_PING_SECS: u64 = 25; const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; -const ME_KEEPALIVE_PAYLOAD_LEN: usize = 4; #[derive(Clone)] pub struct MeWriter { @@ -361,7 +360,6 @@ impl MePool { // Additional connections up to pool_size total (round-robin across DCs), staggered to de-phase lifecycles. if self.me_warmup_stagger_enabled { - let mut delay_ms = 0u64; for (dc, addrs) in dc_addrs.iter() { for (ip, port) in addrs { if self.connection_count() >= pool_size { @@ -369,7 +367,7 @@ impl MePool { } let addr = SocketAddr::new(*ip, *port); let jitter = rand::rng().random_range(0..=self.me_warmup_step_jitter.as_millis() as u64); - delay_ms = delay_ms.saturating_add(self.me_warmup_step_delay.as_millis() as u64 + jitter); + let delay_ms = self.me_warmup_step_delay.as_millis() as u64 + jitter; tokio::time::sleep(Duration::from_millis(delay_ms)).await; if let Err(e) = self.connect_one(addr, rng.as_ref()).await { debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed (staggered)"); @@ -419,7 +417,6 @@ impl MePool { let draining = Arc::new(AtomicBool::new(false)); let (tx, mut rx) = mpsc::channel::(4096); let tx_for_keepalive = tx.clone(); - let keepalive_random = self.me_keepalive_payload_random; let stats = self.stats.clone(); let mut rpc_writer = RpcWriter { writer: hs.wr, @@ -440,11 +437,7 @@ impl MePool { if rpc_writer.send_and_flush(&payload).await.is_err() { break; } } Some(WriterCommand::Keepalive) => { - let mut payload = [0u8; ME_KEEPALIVE_PAYLOAD_LEN]; - if keepalive_random { - rand::rng().fill(&mut payload); - } - match rpc_writer.send_keepalive(payload).await { + match rpc_writer.send_keepalive().await { Ok(()) => { stats.increment_me_keepalive_sent(); } diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index c22ed68..83e4472 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -33,7 +33,6 @@ pub(crate) async fn reader_loop( ) -> Result<()> { let mut raw = enc_leftover; let mut expected_seq: i32 = 0; - let mut crc_errors = 0u32; let mut seq_mismatch = 0u32; loop { @@ -80,13 +79,15 @@ pub(crate) async fn reader_loop( let frame = dec.split_to(fl); let pe = fl - 4; let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); - if crc32(&frame[..pe]) != ec { - warn!("CRC mismatch in data frame"); - crc_errors += 1; - if crc_errors > 3 { - return Err(ProxyError::Proxy("Too many CRC mismatches".into())); - } - continue; + let actual_crc = crc32(&frame[..pe]); + if actual_crc != ec { + warn!( + frame_len = fl, + expected_crc = format_args!("0x{ec:08x}"), + actual_crc = format_args!("0x{actual_crc:08x}"), + "CRC mismatch — CBC crypto desync, aborting ME connection" + ); + return Err(ProxyError::Proxy("CRC mismatch (crypto desync)".into())); } let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap()); diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs index 7d8927d..0f458f2 100644 --- a/src/transport/upstream.rs +++ b/src/transport/upstream.rs @@ -24,6 +24,8 @@ const NUM_DCS: usize = 5; /// Timeout for individual DC ping attempt const DC_PING_TIMEOUT_SECS: u64 = 5; +/// Timeout for direct TG DC TCP connect readiness. +const DIRECT_CONNECT_TIMEOUT_SECS: u64 = 10; // ============= RTT Tracking ============= @@ -375,7 +377,16 @@ impl UpstreamManager { let std_stream: std::net::TcpStream = socket.into(); let stream = TcpStream::from_std(std_stream)?; - stream.writable().await?; + let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS); + match tokio::time::timeout(connect_timeout, stream.writable()).await { + Ok(Ok(())) => {} + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) => { + return Err(ProxyError::ConnectionTimeout { + addr: target.to_string(), + }); + } + } if let Some(e) = stream.take_error()? { return Err(ProxyError::Io(e)); }