diff --git a/Cargo.toml b/Cargo.toml index f8ee25f..737312c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "telemt" -version = "3.0.8" +version = "3.0.9" edition = "2024" [dependencies] @@ -20,6 +20,7 @@ sha1 = "0.10" md-5 = "0.10" hmac = "0.12" crc32fast = "1.4" +crc32c = "0.6" zeroize = { version = "1.8", features = ["derive"] } # Network diff --git a/src/config/defaults.rs b/src/config/defaults.rs index f4180c9..2dee3e0 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -37,7 +37,7 @@ pub(crate) fn default_replay_window_secs() -> u64 { } pub(crate) fn default_handshake_timeout() -> u64 { - 30 + 15 } pub(crate) fn default_connect_timeout() -> u64 { @@ -52,11 +52,11 @@ pub(crate) fn default_ack_timeout() -> u64 { 300 } pub(crate) fn default_me_one_retry() -> u8 { - 12 + 3 } pub(crate) fn default_me_one_timeout() -> u64 { - 1200 + 1500 } pub(crate) fn default_listen_addr() -> String { @@ -83,7 +83,7 @@ pub(crate) fn default_unknown_dc_log_path() -> Option { } pub(crate) fn default_pool_size() -> usize { - 16 + 2 } pub(crate) fn default_keepalive_interval() -> u64 { @@ -207,4 +207,4 @@ where } } Ok(out) -} \ No newline at end of file +} diff --git a/src/config/types.rs b/src/config/types.rs index 193e234..503bb38 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -74,8 +74,8 @@ pub struct ProxyModes { impl Default for ProxyModes { fn default() -> Self { Self { - classic: true, - secure: true, + classic: false, + secure: false, tls: true, } } @@ -140,7 +140,7 @@ pub struct GeneralConfig { #[serde(default = "default_true")] pub fast_mode: bool, - #[serde(default = "default_true")] + #[serde(default)] pub use_middle_proxy: bool, #[serde(default)] @@ -157,7 +157,7 @@ pub struct GeneralConfig { pub middle_proxy_nat_ip: Option, /// Enable STUN-based NAT probing to discover public IP:port for ME KDF. - #[serde(default = "default_true")] + #[serde(default)] pub middle_proxy_nat_probe: bool, /// Optional STUN server address (host:port) for NAT probing. @@ -283,11 +283,11 @@ impl Default for GeneralConfig { modes: ProxyModes::default(), prefer_ipv6: false, fast_mode: true, - use_middle_proxy: true, + use_middle_proxy: false, ad_tag: None, proxy_secret_path: None, middle_proxy_nat_ip: None, - middle_proxy_nat_probe: true, + middle_proxy_nat_probe: false, middle_proxy_nat_stun: None, middle_proxy_nat_stun_servers: Vec::new(), middle_proxy_pool_size: default_pool_size(), @@ -299,10 +299,10 @@ impl Default for GeneralConfig { me_warmup_stagger_enabled: true, me_warmup_step_delay_ms: default_warmup_step_delay_ms(), me_warmup_step_jitter_ms: default_warmup_step_jitter_ms(), - me_reconnect_max_concurrent_per_dc: 1, + me_reconnect_max_concurrent_per_dc: 4, me_reconnect_backoff_base_ms: default_reconnect_backoff_base_ms(), me_reconnect_backoff_cap_ms: default_reconnect_backoff_cap_ms(), - me_reconnect_fast_retry_count: 11, + me_reconnect_fast_retry_count: 8, stun_iface_mismatch_ignore: false, unknown_dc_log_path: default_unknown_dc_log_path(), log_level: LogLevel::Normal, @@ -455,7 +455,7 @@ pub struct AntiCensorshipConfig { pub fake_cert_len: usize, /// Enable TLS certificate emulation using cached real certificates. - #[serde(default = "default_true")] + #[serde(default)] pub tls_emulation: bool, /// Directory to store TLS front cache (on disk). @@ -489,7 +489,7 @@ impl Default for AntiCensorshipConfig { mask_port: default_mask_port(), mask_unix_sock: None, fake_cert_len: default_fake_cert_len(), - tls_emulation: true, + tls_emulation: false, tls_front_dir: default_tls_front_dir(), server_hello_delay_min_ms: default_server_hello_delay_min_ms(), server_hello_delay_max_ms: default_server_hello_delay_max_ms(), @@ -619,9 +619,9 @@ pub struct ListenerConfig { /// - omitted — show no links (default) #[derive(Debug, Clone)] pub enum ShowLink { - /// Don't show any links. + /// Don't show any links (default when omitted). None, - /// Show links for all configured users (default). + /// Show links for all configured users. All, /// Show links for specific users. Specific(Vec), @@ -629,7 +629,7 @@ pub enum ShowLink { impl Default for ShowLink { fn default() -> Self { - ShowLink::All + ShowLink::None } } diff --git a/src/crypto/hash.rs b/src/crypto/hash.rs index 1586e50..d3f6f55 100644 --- a/src/crypto/hash.rs +++ b/src/crypto/hash.rs @@ -55,6 +55,11 @@ pub fn crc32(data: &[u8]) -> u32 { crc32fast::hash(data) } +/// CRC32C (Castagnoli) +pub fn crc32c(data: &[u8]) -> u32 { + crc32c::crc32c(data) +} + /// Build the exact prekey buffer used by Telegram Middle Proxy KDF. /// /// Returned buffer layout (IPv4): diff --git a/src/crypto/mod.rs b/src/crypto/mod.rs index 40951c6..266a3cb 100644 --- a/src/crypto/mod.rs +++ b/src/crypto/mod.rs @@ -5,5 +5,8 @@ pub mod hash; pub mod random; pub use aes::{AesCtr, AesCbc}; -pub use hash::{sha256, sha256_hmac, sha1, md5, crc32, derive_middleproxy_keys, build_middleproxy_prekey}; +pub use hash::{ + build_middleproxy_prekey, crc32, crc32c, derive_middleproxy_keys, md5, sha1, sha256, + sha256_hmac, +}; pub use random::SecureRandom; diff --git a/src/metrics.rs b/src/metrics.rs index 940a0d8..e00091f 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -100,6 +100,14 @@ fn render_metrics(stats: &Stats) -> String { let _ = writeln!(out, "# TYPE telemt_me_keepalive_failed_total counter"); let _ = writeln!(out, "telemt_me_keepalive_failed_total {}", stats.get_me_keepalive_failed()); + let _ = writeln!(out, "# HELP telemt_me_keepalive_pong_total ME keepalive pong replies"); + let _ = writeln!(out, "# TYPE telemt_me_keepalive_pong_total counter"); + let _ = writeln!(out, "telemt_me_keepalive_pong_total {}", stats.get_me_keepalive_pong()); + + let _ = writeln!(out, "# HELP telemt_me_keepalive_timeout_total ME keepalive ping timeouts"); + let _ = writeln!(out, "# TYPE telemt_me_keepalive_timeout_total counter"); + let _ = writeln!(out, "telemt_me_keepalive_timeout_total {}", stats.get_me_keepalive_timeout()); + let _ = writeln!(out, "# HELP telemt_me_reconnect_attempts_total ME reconnect attempts"); let _ = writeln!(out, "# TYPE telemt_me_reconnect_attempts_total counter"); let _ = writeln!(out, "telemt_me_reconnect_attempts_total {}", stats.get_me_reconnect_attempts()); @@ -108,6 +116,30 @@ fn render_metrics(stats: &Stats) -> String { let _ = writeln!(out, "# TYPE telemt_me_reconnect_success_total counter"); let _ = writeln!(out, "telemt_me_reconnect_success_total {}", stats.get_me_reconnect_success()); + let _ = writeln!(out, "# HELP telemt_me_crc_mismatch_total ME CRC mismatches"); + let _ = writeln!(out, "# TYPE telemt_me_crc_mismatch_total counter"); + let _ = writeln!(out, "telemt_me_crc_mismatch_total {}", stats.get_me_crc_mismatch()); + + let _ = writeln!(out, "# HELP telemt_me_seq_mismatch_total ME sequence mismatches"); + let _ = writeln!(out, "# TYPE telemt_me_seq_mismatch_total counter"); + let _ = writeln!(out, "telemt_me_seq_mismatch_total {}", stats.get_me_seq_mismatch()); + + let _ = writeln!(out, "# HELP telemt_me_route_drop_no_conn_total ME route drops: no conn"); + let _ = writeln!(out, "# TYPE telemt_me_route_drop_no_conn_total counter"); + let _ = writeln!(out, "telemt_me_route_drop_no_conn_total {}", stats.get_me_route_drop_no_conn()); + + let _ = writeln!(out, "# HELP telemt_me_route_drop_channel_closed_total ME route drops: channel closed"); + let _ = writeln!(out, "# TYPE telemt_me_route_drop_channel_closed_total counter"); + let _ = writeln!(out, "telemt_me_route_drop_channel_closed_total {}", stats.get_me_route_drop_channel_closed()); + + let _ = writeln!(out, "# HELP telemt_me_route_drop_queue_full_total ME route drops: queue full"); + let _ = writeln!(out, "# TYPE telemt_me_route_drop_queue_full_total counter"); + let _ = writeln!(out, "telemt_me_route_drop_queue_full_total {}", stats.get_me_route_drop_queue_full()); + + let _ = writeln!(out, "# HELP telemt_secure_padding_invalid_total Invalid secure frame lengths"); + let _ = writeln!(out, "# TYPE telemt_secure_padding_invalid_total counter"); + let _ = writeln!(out, "telemt_secure_padding_invalid_total {}", stats.get_secure_padding_invalid()); + let _ = writeln!(out, "# HELP telemt_user_connections_total Per-user total connections"); let _ = writeln!(out, "# TYPE telemt_user_connections_total counter"); let _ = writeln!(out, "# HELP telemt_user_connections_current Per-user active connections"); diff --git a/src/protocol/constants.rs b/src/protocol/constants.rs index 9cbe633..c930a1b 100644 --- a/src/protocol/constants.rs +++ b/src/protocol/constants.rs @@ -156,17 +156,28 @@ pub const MAX_TLS_RECORD_SIZE: usize = 16384; /// RFC 8446 §5.2 allows up to 16384 + 256 bytes of ciphertext pub const MAX_TLS_CHUNK_SIZE: usize = 16384 + 256; -/// Generate padding length for Secure Intermediate protocol. -/// Total (data + padding) must not be divisible by 4 per MTProto spec. -pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize { - let rem = data_len % 4; - match rem { - 0 => (rng.range(3) + 1) as usize, // {1, 2, 3} - 1 => rng.range(3) as usize, // {0, 1, 2} - 2 => [0usize, 1, 3][rng.range(3) as usize], // {0, 1, 3} - 3 => [0usize, 2, 3][rng.range(3) as usize], // {0, 2, 3} - _ => unreachable!(), +/// Secure Intermediate payload is expected to be 4-byte aligned. +pub fn is_valid_secure_payload_len(data_len: usize) -> bool { + data_len % 4 == 0 +} + +/// Compute Secure Intermediate payload length from wire length. +/// Secure mode strips up to 3 random tail bytes by truncating to 4-byte boundary. +pub fn secure_payload_len_from_wire_len(wire_len: usize) -> Option { + if wire_len < 4 { + return None; } + Some(wire_len - (wire_len % 4)) +} + +/// Generate padding length for Secure Intermediate protocol. +/// Data must be 4-byte aligned; padding is 1..=3 so total is never divisible by 4. +pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize { + debug_assert!( + is_valid_secure_payload_len(data_len), + "Secure payload must be 4-byte aligned, got {data_len}" + ); + (rng.range(3) + 1) as usize } // ============= Timeouts ============= @@ -300,6 +311,10 @@ pub mod rpc_flags { pub const FLAG_ABRIDGED: u32 = 0x40000000; pub const FLAG_QUICKACK: u32 = 0x80000000; } + + pub mod rpc_crypto_flags { + pub const USE_CRC32C: u32 = 0x800; + } pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5; pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10; @@ -339,7 +354,7 @@ mod tests { #[test] fn secure_padding_never_produces_aligned_total() { let rng = SecureRandom::new(); - for data_len in 0..1000 { + for data_len in (0..1000).step_by(4) { for _ in 0..100 { let padding = secure_padding_len(data_len, &rng); assert!( @@ -355,4 +370,23 @@ mod tests { } } } + + #[test] + fn secure_wire_len_roundtrip_for_aligned_payload() { + for payload_len in (4..4096).step_by(4) { + for padding in 0..=3usize { + let wire_len = payload_len + padding; + let recovered = secure_payload_len_from_wire_len(wire_len); + assert_eq!(recovered, Some(payload_len)); + } + } + } + + #[test] + fn secure_wire_len_rejects_too_short_frames() { + assert_eq!(secure_payload_len_from_wire_len(0), None); + assert_eq!(secure_payload_len_from_wire_len(1), None); + assert_eq!(secure_payload_len_from_wire_len(2), None); + assert_eq!(secure_payload_len_from_wire_len(3), None); + } } diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 8d48c8b..d060dc7 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -253,7 +253,11 @@ where let mode_ok = match proto_tag { ProtoTag::Secure => { - if is_tls { config.general.modes.tls } else { config.general.modes.secure } + if is_tls { + config.general.modes.tls || config.general.modes.secure + } else { + config.general.modes.secure || config.general.modes.tls + } } ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic, }; diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index b7f22ae..7b97049 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -165,6 +165,7 @@ where frame_limit, &user, &mut frame_counter, + &stats, ).await { Ok(Some((payload, quickack))) => { trace!(conn_id, bytes = payload.len(), "C->ME frame"); @@ -238,6 +239,7 @@ async fn read_client_payload( max_frame: usize, user: &str, frame_counter: &mut u64, + stats: &Stats, ) -> Result, bool)>> where R: AsyncRead + Unpin + Send + 'static, @@ -326,18 +328,29 @@ where ))); } + let secure_payload_len = if proto_tag == ProtoTag::Secure { + match secure_payload_len_from_wire_len(len) { + Some(payload_len) => payload_len, + None => { + stats.increment_secure_padding_invalid(); + return Err(ProxyError::Proxy(format!( + "Invalid secure frame length: {len}" + ))); + } + } + } else { + len + }; + let mut payload = vec![0u8; len]; client_reader .read_exact(&mut payload) .await .map_err(ProxyError::Io)?; - // Secure Intermediate: remove random padding (last len%4 bytes) + // Secure Intermediate: strip validated trailing padding bytes. if proto_tag == ProtoTag::Secure { - let rem = len % 4; - if rem != 0 && payload.len() >= rem { - payload.truncate(len - rem); - } + payload.truncate(secure_payload_len); } *frame_counter += 1; return Ok(Some((payload, quickack))); @@ -400,6 +413,12 @@ where } ProtoTag::Intermediate | ProtoTag::Secure => { let padding_len = if proto_tag == ProtoTag::Secure { + if !is_valid_secure_payload_len(data.len()) { + return Err(ProxyError::Proxy(format!( + "Secure payload must be 4-byte aligned, got {}", + data.len() + ))); + } secure_padding_len(data.len(), rng) } else { 0 diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 38318cc..e480ec6 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -21,8 +21,16 @@ pub struct Stats { handshake_timeouts: AtomicU64, me_keepalive_sent: AtomicU64, me_keepalive_failed: AtomicU64, + me_keepalive_pong: AtomicU64, + me_keepalive_timeout: AtomicU64, me_reconnect_attempts: AtomicU64, me_reconnect_success: AtomicU64, + me_crc_mismatch: AtomicU64, + me_seq_mismatch: AtomicU64, + me_route_drop_no_conn: AtomicU64, + me_route_drop_channel_closed: AtomicU64, + me_route_drop_queue_full: AtomicU64, + secure_padding_invalid: AtomicU64, user_stats: DashMap, start_time: parking_lot::RwLock>, } @@ -49,14 +57,45 @@ impl Stats { pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); } pub fn increment_me_keepalive_sent(&self) { self.me_keepalive_sent.fetch_add(1, Ordering::Relaxed); } pub fn increment_me_keepalive_failed(&self) { self.me_keepalive_failed.fetch_add(1, Ordering::Relaxed); } + pub fn increment_me_keepalive_pong(&self) { self.me_keepalive_pong.fetch_add(1, Ordering::Relaxed); } + pub fn increment_me_keepalive_timeout(&self) { self.me_keepalive_timeout.fetch_add(1, Ordering::Relaxed); } + pub fn increment_me_keepalive_timeout_by(&self, value: u64) { + self.me_keepalive_timeout.fetch_add(value, Ordering::Relaxed); + } pub fn increment_me_reconnect_attempt(&self) { self.me_reconnect_attempts.fetch_add(1, Ordering::Relaxed); } pub fn increment_me_reconnect_success(&self) { self.me_reconnect_success.fetch_add(1, Ordering::Relaxed); } + pub fn increment_me_crc_mismatch(&self) { self.me_crc_mismatch.fetch_add(1, Ordering::Relaxed); } + pub fn increment_me_seq_mismatch(&self) { self.me_seq_mismatch.fetch_add(1, Ordering::Relaxed); } + pub fn increment_me_route_drop_no_conn(&self) { self.me_route_drop_no_conn.fetch_add(1, Ordering::Relaxed); } + pub fn increment_me_route_drop_channel_closed(&self) { + self.me_route_drop_channel_closed.fetch_add(1, Ordering::Relaxed); + } + pub fn increment_me_route_drop_queue_full(&self) { + self.me_route_drop_queue_full.fetch_add(1, Ordering::Relaxed); + } + pub fn increment_secure_padding_invalid(&self) { + self.secure_padding_invalid.fetch_add(1, Ordering::Relaxed); + } pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) } pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) } pub fn get_me_keepalive_sent(&self) -> u64 { self.me_keepalive_sent.load(Ordering::Relaxed) } pub fn get_me_keepalive_failed(&self) -> u64 { self.me_keepalive_failed.load(Ordering::Relaxed) } + pub fn get_me_keepalive_pong(&self) -> u64 { self.me_keepalive_pong.load(Ordering::Relaxed) } + pub fn get_me_keepalive_timeout(&self) -> u64 { self.me_keepalive_timeout.load(Ordering::Relaxed) } pub fn get_me_reconnect_attempts(&self) -> u64 { self.me_reconnect_attempts.load(Ordering::Relaxed) } pub fn get_me_reconnect_success(&self) -> u64 { self.me_reconnect_success.load(Ordering::Relaxed) } + pub fn get_me_crc_mismatch(&self) -> u64 { self.me_crc_mismatch.load(Ordering::Relaxed) } + pub fn get_me_seq_mismatch(&self) -> u64 { self.me_seq_mismatch.load(Ordering::Relaxed) } + pub fn get_me_route_drop_no_conn(&self) -> u64 { self.me_route_drop_no_conn.load(Ordering::Relaxed) } + pub fn get_me_route_drop_channel_closed(&self) -> u64 { + self.me_route_drop_channel_closed.load(Ordering::Relaxed) + } + pub fn get_me_route_drop_queue_full(&self) -> u64 { + self.me_route_drop_queue_full.load(Ordering::Relaxed) + } + pub fn get_secure_padding_invalid(&self) -> u64 { + self.secure_padding_invalid.load(Ordering::Relaxed) + } pub fn increment_user_connects(&self, user: &str) { self.user_stats.entry(user.to_string()).or_default() diff --git a/src/stream/frame_codec.rs b/src/stream/frame_codec.rs index 7547ae6..6b90892 100644 --- a/src/stream/frame_codec.rs +++ b/src/stream/frame_codec.rs @@ -8,7 +8,9 @@ use std::io::{self, Error, ErrorKind}; use std::sync::Arc; use tokio_util::codec::{Decoder, Encoder}; -use crate::protocol::constants::{ProtoTag, secure_padding_len}; +use crate::protocol::constants::{ + ProtoTag, is_valid_secure_payload_len, secure_padding_len, secure_payload_len_from_wire_len, +}; use crate::crypto::SecureRandom; use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait}; @@ -274,13 +276,13 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result io::R return Ok(()); } + if !is_valid_secure_payload_len(data.len()) { + return Err(Error::new( + ErrorKind::InvalidData, + format!("secure payload must be 4-byte aligned, got {}", data.len()), + )); + } + // Generate padding that keeps total length non-divisible by 4. let padding_len = secure_padding_len(data.len(), rng); diff --git a/src/stream/frame_stream.rs b/src/stream/frame_stream.rs index 1ea6d1b..1726a06 100644 --- a/src/stream/frame_stream.rs +++ b/src/stream/frame_stream.rs @@ -232,11 +232,13 @@ impl SecureIntermediateFrameReader { let mut data = vec![0u8; len]; self.upstream.read_exact(&mut data).await?; - // Strip padding (not aligned to 4) - if len % 4 != 0 { - let actual_len = len - (len % 4); - data.truncate(actual_len); - } + let payload_len = secure_payload_len_from_wire_len(len).ok_or_else(|| { + Error::new( + ErrorKind::InvalidData, + format!("Invalid secure frame length: {len}"), + ) + })?; + data.truncate(payload_len); Ok((Bytes::from(data), meta)) } @@ -267,6 +269,13 @@ impl SecureIntermediateFrameWriter { return Ok(()); } + if !is_valid_secure_payload_len(data.len()) { + return Err(Error::new( + ErrorKind::InvalidData, + format!("Secure payload must be 4-byte aligned, got {}", data.len()), + )); + } + // Add padding so total length is never divisible by 4 (MTProto Secure) let padding_len = secure_padding_len(data.len(), &self.rng); let padding = self.rng.bytes(padding_len); @@ -550,9 +559,7 @@ mod tests { writer.flush().await.unwrap(); let (received, _meta) = reader.read_frame().await.unwrap(); - // Received should have padding stripped to align to 4 - let expected_len = (data.len() / 4) * 4; - assert_eq!(received.len(), expected_len); + assert_eq!(received.len(), data.len()); } #[tokio::test] diff --git a/src/transport/middle_proxy/codec.rs b/src/transport/middle_proxy/codec.rs index 1dccede..dd9589e 100644 --- a/src/transport/middle_proxy/codec.rs +++ b/src/transport/middle_proxy/codec.rs @@ -1,6 +1,6 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use crate::crypto::{AesCbc, crc32}; +use crate::crypto::{AesCbc, crc32, crc32c}; use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; @@ -8,17 +8,46 @@ use crate::protocol::constants::*; pub(crate) enum WriterCommand { Data(Vec), DataAndFlush(Vec), - Keepalive, Close, } -pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8]) -> Vec { +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum RpcChecksumMode { + Crc32, + Crc32c, +} + +impl RpcChecksumMode { + pub(crate) fn from_handshake_flags(flags: u32) -> Self { + if (flags & rpc_crypto_flags::USE_CRC32C) != 0 { + Self::Crc32c + } else { + Self::Crc32 + } + } + + pub(crate) fn advertised_flags(self) -> u32 { + match self { + Self::Crc32 => 0, + Self::Crc32c => rpc_crypto_flags::USE_CRC32C, + } + } +} + +pub(crate) fn rpc_crc(mode: RpcChecksumMode, data: &[u8]) -> u32 { + match mode { + RpcChecksumMode::Crc32 => crc32(data), + RpcChecksumMode::Crc32c => crc32c(data), + } +} + +pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8], crc_mode: RpcChecksumMode) -> Vec { let total_len = (4 + 4 + payload.len() + 4) as u32; let mut frame = Vec::with_capacity(total_len as usize); frame.extend_from_slice(&total_len.to_le_bytes()); frame.extend_from_slice(&seq_no.to_le_bytes()); frame.extend_from_slice(payload); - let c = crc32(&frame); + let c = rpc_crc(crc_mode, &frame); frame.extend_from_slice(&c.to_le_bytes()); frame } @@ -45,7 +74,7 @@ pub(crate) async fn read_rpc_frame_plaintext( let crc_offset = total_len - 4; let expected_crc = u32::from_le_bytes(full[crc_offset..crc_offset + 4].try_into().unwrap()); - let actual_crc = crc32(&full[..crc_offset]); + let actual_crc = rpc_crc(RpcChecksumMode::Crc32, &full[..crc_offset]); if expected_crc != actual_crc { return Err(ProxyError::InvalidHandshake(format!( "CRC mismatch: 0x{expected_crc:08x} vs 0x{actual_crc:08x}" @@ -95,24 +124,52 @@ pub(crate) fn build_handshake_payload( our_port: u16, peer_ip: [u8; 4], peer_port: u16, + flags: u32, ) -> [u8; 32] { let mut p = [0u8; 32]; p[0..4].copy_from_slice(&RPC_HANDSHAKE_U32.to_le_bytes()); + p[4..8].copy_from_slice(&flags.to_le_bytes()); - // Keep C memory layout compatibility for PID IPv4 bytes. + // process_id sender_pid p[8..12].copy_from_slice(&our_ip); p[12..14].copy_from_slice(&our_port.to_le_bytes()); - let pid = (std::process::id() & 0xffff) as u16; - p[14..16].copy_from_slice(&pid.to_le_bytes()); + p[14..16].copy_from_slice(&process_pid16().to_le_bytes()); + p[16..20].copy_from_slice(&process_utime().to_le_bytes()); + + // process_id peer_pid + p[20..24].copy_from_slice(&peer_ip); + p[24..26].copy_from_slice(&peer_port.to_le_bytes()); + p[26..28].copy_from_slice(&0u16.to_le_bytes()); + p[28..32].copy_from_slice(&0u32.to_le_bytes()); + p +} + +pub(crate) fn parse_handshake_flags(payload: &[u8]) -> Result { + if payload.len() != 32 { + return Err(ProxyError::InvalidHandshake(format!( + "Bad handshake payload len: {}", + payload.len() + ))); + } + let hs_type = u32::from_le_bytes(payload[0..4].try_into().unwrap()); + if hs_type != RPC_HANDSHAKE_U32 { + return Err(ProxyError::InvalidHandshake(format!( + "Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}" + ))); + } + Ok(u32::from_le_bytes(payload[4..8].try_into().unwrap())) +} + +fn process_pid16() -> u16 { + (std::process::id() & 0xffff) as u16 +} + +fn process_utime() -> u32 { let utime = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs() as u32; - p[16..20].copy_from_slice(&utime.to_le_bytes()); - - p[20..24].copy_from_slice(&peer_ip); - p[24..26].copy_from_slice(&peer_port.to_le_bytes()); - p + utime } pub(crate) fn cbc_encrypt_padded( @@ -160,11 +217,12 @@ pub(crate) struct RpcWriter { pub(crate) key: [u8; 32], pub(crate) iv: [u8; 16], pub(crate) seq_no: i32, + pub(crate) crc_mode: RpcChecksumMode, } impl RpcWriter { pub(crate) async fn send(&mut self, payload: &[u8]) -> Result<()> { - let frame = build_rpc_frame(self.seq_no, payload); + let frame = build_rpc_frame(self.seq_no, payload, self.crc_mode); self.seq_no += 1; let pad = (16 - (frame.len() % 16)) % 16; @@ -189,27 +247,4 @@ impl RpcWriter { self.send(payload).await?; self.writer.flush().await.map_err(ProxyError::Io) } - - /// Sends a 4-byte keepalive marker directly into the CBC stream. - /// This is not an RPC frame and must not consume sequence numbers. - pub(crate) async fn send_keepalive(&mut self) -> Result<()> { - let mut buf = [0u8; 16]; - for i in 0..4 { - let start = i * 4; - let end = start + 4; - buf[start..end].copy_from_slice(&PADDING_FILLER); - } - - let cipher = AesCbc::new(self.key, self.iv); - let mut v = buf.to_vec(); - cipher - .encrypt_in_place(&mut v) - .map_err(|e| ProxyError::Crypto(format!("{e}")))?; - - if v.len() >= 16 { - self.iv.copy_from_slice(&v[v.len() - 16..]); - } - self.writer.write_all(&v).await.map_err(ProxyError::Io)?; - self.writer.flush().await.map_err(ProxyError::Io) - } } diff --git a/src/transport/middle_proxy/handshake.rs b/src/transport/middle_proxy/handshake.rs index 6814371..95a9d6e 100644 --- a/src/transport/middle_proxy/handshake.rs +++ b/src/transport/middle_proxy/handshake.rs @@ -18,13 +18,14 @@ use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_k use crate::error::{ProxyError, Result}; use crate::network::IpFamily; use crate::protocol::constants::{ - ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32, RPC_HANDSHAKE_ERROR_U32, - RPC_HANDSHAKE_U32, RPC_PING_U32, RPC_PONG_U32, RPC_NONCE_U32, + ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32, + RPC_HANDSHAKE_ERROR_U32, rpc_crypto_flags, }; use super::codec::{ - build_handshake_payload, build_nonce_payload, build_rpc_frame, cbc_decrypt_inplace, - cbc_encrypt_padded, parse_nonce_payload, read_rpc_frame_plaintext, + RpcChecksumMode, build_handshake_payload, build_nonce_payload, build_rpc_frame, + cbc_decrypt_inplace, cbc_encrypt_padded, parse_handshake_flags, parse_nonce_payload, + read_rpc_frame_plaintext, rpc_crc, }; use super::wire::{extract_ip_material, IpMaterial}; use super::MePool; @@ -37,6 +38,7 @@ pub(crate) struct HandshakeOutput { pub read_iv: [u8; 16], pub write_key: [u8; 32], pub write_iv: [u8; 16], + pub crc_mode: RpcChecksumMode, pub handshake_ms: f64, } @@ -146,7 +148,7 @@ impl MePool { let ks = self.key_selector().await; let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce); - let nonce_frame = build_rpc_frame(-2, &nonce_payload); + let nonce_frame = build_rpc_frame(-2, &nonce_payload, RpcChecksumMode::Crc32); let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]); debug!( key_selector = format_args!("0x{ks:08x}"), @@ -284,8 +286,15 @@ impl MePool { srv_v6_opt.as_ref(), ); - let hs_payload = build_handshake_payload(hs_our_ip, local_addr.port(), hs_peer_ip, peer_addr.port()); - let hs_frame = build_rpc_frame(-1, &hs_payload); + let requested_crc_mode = RpcChecksumMode::Crc32c; + let hs_payload = build_handshake_payload( + hs_our_ip, + local_addr.port(), + hs_peer_ip, + peer_addr.port(), + requested_crc_mode.advertised_flags(), + ); + let hs_frame = build_rpc_frame(-1, &hs_payload, RpcChecksumMode::Crc32); if diag_level >= 1 { info!( write_key = %hex_dump(&wk), @@ -314,7 +323,7 @@ impl MePool { ); } - let (encrypted_hs, mut write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?; + let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?; if diag_level >= 1 { info!( hs_cipher = %hex_dump(&encrypted_hs), @@ -328,6 +337,7 @@ impl MePool { let mut enc_buf = BytesMut::with_capacity(256); let mut dec_buf = BytesMut::with_capacity(256); let mut read_iv = ri; + let mut negotiated_crc_mode = RpcChecksumMode::Crc32; let mut handshake_ok = false; while Instant::now() < deadline && !handshake_ok { @@ -375,17 +385,23 @@ impl MePool { let frame = dec_buf.split_to(fl); let pe = fl - 4; let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); - let ac = crate::crypto::crc32(&frame[..pe]); + let ac = rpc_crc(RpcChecksumMode::Crc32, &frame[..pe]); if ec != ac { return Err(ProxyError::InvalidHandshake(format!( "HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}" ))); } - let hs_type = u32::from_le_bytes(frame[8..12].try_into().unwrap()); + let hs_payload = &frame[8..pe]; + if hs_payload.len() < 4 { + return Err(ProxyError::InvalidHandshake( + "Handshake payload too short".to_string(), + )); + } + let hs_type = u32::from_le_bytes(hs_payload[0..4].try_into().unwrap()); if hs_type == RPC_HANDSHAKE_ERROR_U32 { - let err_code = if frame.len() >= 16 { - i32::from_le_bytes(frame[12..16].try_into().unwrap()) + let err_code = if hs_payload.len() >= 8 { + i32::from_le_bytes(hs_payload[4..8].try_into().unwrap()) } else { -1 }; @@ -393,11 +409,21 @@ impl MePool { "ME rejected handshake (error={err_code})" ))); } - if hs_type != RPC_HANDSHAKE_U32 { + let hs_flags = parse_handshake_flags(hs_payload)?; + if hs_flags & 0xff != 0 { return Err(ProxyError::InvalidHandshake(format!( - "Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}" + "Unsupported handshake flags: 0x{hs_flags:08x}" ))); } + negotiated_crc_mode = if (hs_flags & requested_crc_mode.advertised_flags()) != 0 { + RpcChecksumMode::from_handshake_flags(hs_flags) + } else if (hs_flags & rpc_crypto_flags::USE_CRC32C) != 0 { + return Err(ProxyError::InvalidHandshake(format!( + "Peer negotiated unsupported CRC flags: 0x{hs_flags:08x}" + ))); + } else { + RpcChecksumMode::Crc32 + }; handshake_ok = true; break; @@ -418,6 +444,7 @@ impl MePool { read_iv, write_key: wk, write_iv, + crc_mode: negotiated_crc_mode, handshake_ms, }) } diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index f65edd6..062db67 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -17,10 +17,9 @@ use crate::network::IpFamily; use crate::protocol::constants::*; use super::ConnRegistry; -use super::registry::{BoundConn, ConnMeta}; +use super::registry::BoundConn; use super::codec::{RpcWriter, WriterCommand}; use super::reader::reader_loop; -use super::MeResponse; const ME_ACTIVE_PING_SECS: u64 = 25; const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; @@ -417,12 +416,12 @@ impl MePool { let draining = Arc::new(AtomicBool::new(false)); let (tx, mut rx) = mpsc::channel::(4096); let tx_for_keepalive = tx.clone(); - let stats = self.stats.clone(); let mut rpc_writer = RpcWriter { writer: hs.wr, key: hs.write_key, iv: hs.write_iv, seq_no: 0, + crc_mode: hs.crc_mode, }; let cancel_wr = cancel.clone(); tokio::spawn(async move { @@ -436,17 +435,6 @@ impl MePool { Some(WriterCommand::DataAndFlush(payload)) => { if rpc_writer.send_and_flush(&payload).await.is_err() { break; } } - Some(WriterCommand::Keepalive) => { - match rpc_writer.send_keepalive().await { - Ok(()) => { - stats.increment_me_keepalive_sent(); - } - Err(_) => { - stats.increment_me_keepalive_failed(); - break; - } - } - } Some(WriterCommand::Close) | None => break, } } @@ -469,7 +457,11 @@ impl MePool { let reg = self.registry.clone(); let writers_arc = self.writers_arc(); let ping_tracker = self.ping_tracker.clone(); + let ping_tracker_reader = ping_tracker.clone(); let rtt_stats = self.rtt_stats.clone(); + let stats_reader = self.stats.clone(); + let stats_ping = self.stats.clone(); + let stats_keepalive = self.stats.clone(); let pool = Arc::downgrade(self); let cancel_ping = cancel.clone(); let tx_ping = tx.clone(); @@ -489,12 +481,14 @@ impl MePool { hs.rd, hs.read_key, hs.read_iv, + hs.crc_mode, reg.clone(), BytesMut::new(), BytesMut::new(), tx.clone(), - ping_tracker.clone(), + ping_tracker_reader, rtt_stats.clone(), + stats_reader, writer_id, degraded.clone(), cancel_reader_token.clone(), @@ -535,7 +529,12 @@ impl MePool { p.extend_from_slice(&sent_id.to_le_bytes()); { let mut tracker = ping_tracker_ping.lock().await; + let before = tracker.len(); tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120)); + let expired = before.saturating_sub(tracker.len()); + if expired > 0 { + stats_ping.increment_me_keepalive_timeout_by(expired as u64); + } tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); } ping_id = ping_id.wrapping_add(1); @@ -558,18 +557,37 @@ impl MePool { if keepalive_enabled { let tx_keepalive = tx_for_keepalive; let cancel_keepalive = cancel_keepalive_token; + let ping_tracker_keepalive = ping_tracker.clone(); tokio::spawn(async move { // Per-writer jittered start to avoid phase sync. let jitter_cap_ms = keepalive_interval.as_millis() / 2; let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); let initial_jitter_ms = rand::rng().random_range(0..=effective_jitter_ms as u64); tokio::time::sleep(Duration::from_millis(initial_jitter_ms)).await; + let mut ping_id: i64 = rand::random::(); loop { tokio::select! { _ = cancel_keepalive.cancelled() => break, _ = tokio::time::sleep(keepalive_interval + Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64))) => {} } - if tx_keepalive.send(WriterCommand::Keepalive).await.is_err() { + let sent_id = ping_id; + ping_id = ping_id.wrapping_add(1); + let mut p = Vec::with_capacity(12); + p.extend_from_slice(&RPC_PING_U32.to_le_bytes()); + p.extend_from_slice(&sent_id.to_le_bytes()); + { + let mut tracker = ping_tracker_keepalive.lock().await; + let before = tracker.len(); + tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120)); + let expired = before.saturating_sub(tracker.len()); + if expired > 0 { + stats_keepalive.increment_me_keepalive_timeout_by(expired as u64); + } + tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); + } + stats_keepalive.increment_me_keepalive_sent(); + if tx_keepalive.send(WriterCommand::DataAndFlush(p)).await.is_err() { + stats_keepalive.increment_me_keepalive_failed(); break; } } diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index 83e4472..95bd0d8 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -10,30 +10,33 @@ use tokio::sync::{Mutex, mpsc}; use tokio_util::sync::CancellationToken; use tracing::{debug, trace, warn}; -use crate::crypto::{AesCbc, crc32}; +use crate::crypto::AesCbc; use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; +use crate::stats::Stats; -use super::codec::WriterCommand; +use super::codec::{RpcChecksumMode, WriterCommand, rpc_crc}; +use super::registry::RouteResult; use super::{ConnRegistry, MeResponse}; pub(crate) async fn reader_loop( mut rd: tokio::io::ReadHalf, dk: [u8; 32], mut div: [u8; 16], + crc_mode: RpcChecksumMode, reg: Arc, enc_leftover: BytesMut, mut dec: BytesMut, tx: mpsc::Sender, ping_tracker: Arc>>, rtt_stats: Arc>>, + stats: Arc, _writer_id: u64, degraded: Arc, cancel: CancellationToken, ) -> Result<()> { let mut raw = enc_leftover; let mut expected_seq: i32 = 0; - let mut seq_mismatch = 0u32; loop { let mut tmp = [0u8; 16_384]; @@ -79,8 +82,9 @@ pub(crate) async fn reader_loop( let frame = dec.split_to(fl); let pe = fl - 4; let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); - let actual_crc = crc32(&frame[..pe]); + let actual_crc = rpc_crc(crc_mode, &frame[..pe]); if actual_crc != ec { + stats.increment_me_crc_mismatch(); warn!( frame_len = fl, expected_crc = format_args!("0x{ec:08x}"), @@ -92,15 +96,14 @@ pub(crate) async fn reader_loop( let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap()); if seq_no != expected_seq { + stats.increment_me_seq_mismatch(); warn!(seq_no, expected = expected_seq, "ME RPC seq mismatch"); - seq_mismatch += 1; - if seq_mismatch > 10 { - return Err(ProxyError::Proxy("Too many seq mismatches".into())); - } - expected_seq = seq_no.wrapping_add(1); - } else { - expected_seq = expected_seq.wrapping_add(1); + return Err(ProxyError::SeqNoMismatch { + expected: expected_seq, + got: seq_no, + }); } + expected_seq = expected_seq.wrapping_add(1); let payload = &frame[8..pe]; if payload.len() < 4 { @@ -117,7 +120,13 @@ pub(crate) async fn reader_loop( trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS"); let routed = reg.route(cid, MeResponse::Data { flags, data }).await; - if !routed { + if !matches!(routed, RouteResult::Routed) { + match routed { + RouteResult::NoConn => stats.increment_me_route_drop_no_conn(), + RouteResult::ChannelClosed => stats.increment_me_route_drop_channel_closed(), + RouteResult::QueueFull => stats.increment_me_route_drop_queue_full(), + RouteResult::Routed => {} + } reg.unregister(cid).await; send_close_conn(&tx, cid).await; } @@ -127,7 +136,13 @@ pub(crate) async fn reader_loop( trace!(cid, cfm, "RPC_SIMPLE_ACK"); let routed = reg.route(cid, MeResponse::Ack(cfm)).await; - if !routed { + if !matches!(routed, RouteResult::Routed) { + match routed { + RouteResult::NoConn => stats.increment_me_route_drop_no_conn(), + RouteResult::ChannelClosed => stats.increment_me_route_drop_channel_closed(), + RouteResult::QueueFull => stats.increment_me_route_drop_queue_full(), + RouteResult::Routed => {} + } reg.unregister(cid).await; send_close_conn(&tx, cid).await; } @@ -153,6 +168,7 @@ pub(crate) async fn reader_loop( } } else if pt == RPC_PONG_U32 && body.len() >= 8 { let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap()); + stats.increment_me_keepalive_pong(); if let Some((sent, wid)) = { let mut guard = ping_tracker.lock().await; guard.remove(&ping_id) diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index ab4f280..4b25e00 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -1,13 +1,23 @@ use std::collections::{HashMap, HashSet}; use std::net::SocketAddr; use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::Arc; -use tokio::sync::{mpsc, Mutex, RwLock}; +use tokio::sync::{mpsc, RwLock}; +use tokio::sync::mpsc::error::TrySendError; use super::codec::WriterCommand; use super::MeResponse; +const ROUTE_CHANNEL_CAPACITY: usize = 4096; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RouteResult { + Routed, + NoConn, + ChannelClosed, + QueueFull, +} + #[derive(Clone)] pub struct ConnMeta { pub target_dc: i16, @@ -64,7 +74,7 @@ impl ConnRegistry { pub async fn register(&self) -> (u64, mpsc::Receiver) { let id = self.next_id.fetch_add(1, Ordering::Relaxed); - let (tx, rx) = mpsc::channel(1024); + let (tx, rx) = mpsc::channel(ROUTE_CHANNEL_CAPACITY); self.inner.write().await.map.insert(id, tx); (id, rx) } @@ -83,12 +93,16 @@ impl ConnRegistry { None } - pub async fn route(&self, id: u64, resp: MeResponse) -> bool { + pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult { let inner = self.inner.read().await; if let Some(tx) = inner.map.get(&id) { - tx.try_send(resp).is_ok() + match tx.try_send(resp) { + Ok(()) => RouteResult::Routed, + Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed, + Err(TrySendError::Full(_)) => RouteResult::QueueFull, + } } else { - false + RouteResult::NoConn } }