Merge pull request #208 from telemt/flow

Middle-End protocol hardening
This commit is contained in:
Alexey 2026-02-23 02:51:01 +03:00 committed by GitHub
commit aae3e2665e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 402 additions and 139 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "telemt"
version = "3.0.8"
version = "3.0.9"
edition = "2024"
[dependencies]
@ -20,6 +20,7 @@ sha1 = "0.10"
md-5 = "0.10"
hmac = "0.12"
crc32fast = "1.4"
crc32c = "0.6"
zeroize = { version = "1.8", features = ["derive"] }
# Network

View File

@ -37,7 +37,7 @@ pub(crate) fn default_replay_window_secs() -> u64 {
}
pub(crate) fn default_handshake_timeout() -> u64 {
30
15
}
pub(crate) fn default_connect_timeout() -> u64 {
@ -52,11 +52,11 @@ pub(crate) fn default_ack_timeout() -> u64 {
300
}
pub(crate) fn default_me_one_retry() -> u8 {
12
3
}
pub(crate) fn default_me_one_timeout() -> u64 {
1200
1500
}
pub(crate) fn default_listen_addr() -> String {
@ -83,7 +83,7 @@ pub(crate) fn default_unknown_dc_log_path() -> Option<String> {
}
pub(crate) fn default_pool_size() -> usize {
16
2
}
pub(crate) fn default_keepalive_interval() -> u64 {
@ -207,4 +207,4 @@ where
}
}
Ok(out)
}
}

View File

@ -74,8 +74,8 @@ pub struct ProxyModes {
impl Default for ProxyModes {
fn default() -> Self {
Self {
classic: true,
secure: true,
classic: false,
secure: false,
tls: true,
}
}
@ -140,7 +140,7 @@ pub struct GeneralConfig {
#[serde(default = "default_true")]
pub fast_mode: bool,
#[serde(default = "default_true")]
#[serde(default)]
pub use_middle_proxy: bool,
#[serde(default)]
@ -157,7 +157,7 @@ pub struct GeneralConfig {
pub middle_proxy_nat_ip: Option<IpAddr>,
/// Enable STUN-based NAT probing to discover public IP:port for ME KDF.
#[serde(default = "default_true")]
#[serde(default)]
pub middle_proxy_nat_probe: bool,
/// Optional STUN server address (host:port) for NAT probing.
@ -283,11 +283,11 @@ impl Default for GeneralConfig {
modes: ProxyModes::default(),
prefer_ipv6: false,
fast_mode: true,
use_middle_proxy: true,
use_middle_proxy: false,
ad_tag: None,
proxy_secret_path: None,
middle_proxy_nat_ip: None,
middle_proxy_nat_probe: true,
middle_proxy_nat_probe: false,
middle_proxy_nat_stun: None,
middle_proxy_nat_stun_servers: Vec::new(),
middle_proxy_pool_size: default_pool_size(),
@ -299,10 +299,10 @@ impl Default for GeneralConfig {
me_warmup_stagger_enabled: true,
me_warmup_step_delay_ms: default_warmup_step_delay_ms(),
me_warmup_step_jitter_ms: default_warmup_step_jitter_ms(),
me_reconnect_max_concurrent_per_dc: 1,
me_reconnect_max_concurrent_per_dc: 4,
me_reconnect_backoff_base_ms: default_reconnect_backoff_base_ms(),
me_reconnect_backoff_cap_ms: default_reconnect_backoff_cap_ms(),
me_reconnect_fast_retry_count: 11,
me_reconnect_fast_retry_count: 8,
stun_iface_mismatch_ignore: false,
unknown_dc_log_path: default_unknown_dc_log_path(),
log_level: LogLevel::Normal,
@ -455,7 +455,7 @@ pub struct AntiCensorshipConfig {
pub fake_cert_len: usize,
/// Enable TLS certificate emulation using cached real certificates.
#[serde(default = "default_true")]
#[serde(default)]
pub tls_emulation: bool,
/// Directory to store TLS front cache (on disk).
@ -489,7 +489,7 @@ impl Default for AntiCensorshipConfig {
mask_port: default_mask_port(),
mask_unix_sock: None,
fake_cert_len: default_fake_cert_len(),
tls_emulation: true,
tls_emulation: false,
tls_front_dir: default_tls_front_dir(),
server_hello_delay_min_ms: default_server_hello_delay_min_ms(),
server_hello_delay_max_ms: default_server_hello_delay_max_ms(),
@ -619,9 +619,9 @@ pub struct ListenerConfig {
/// - omitted — show no links (default)
#[derive(Debug, Clone)]
pub enum ShowLink {
/// Don't show any links.
/// Don't show any links (default when omitted).
None,
/// Show links for all configured users (default).
/// Show links for all configured users.
All,
/// Show links for specific users.
Specific(Vec<String>),
@ -629,7 +629,7 @@ pub enum ShowLink {
impl Default for ShowLink {
fn default() -> Self {
ShowLink::All
ShowLink::None
}
}

View File

@ -55,6 +55,11 @@ pub fn crc32(data: &[u8]) -> u32 {
crc32fast::hash(data)
}
/// CRC32C (Castagnoli)
pub fn crc32c(data: &[u8]) -> u32 {
crc32c::crc32c(data)
}
/// Build the exact prekey buffer used by Telegram Middle Proxy KDF.
///
/// Returned buffer layout (IPv4):

View File

@ -5,5 +5,8 @@ pub mod hash;
pub mod random;
pub use aes::{AesCtr, AesCbc};
pub use hash::{sha256, sha256_hmac, sha1, md5, crc32, derive_middleproxy_keys, build_middleproxy_prekey};
pub use hash::{
build_middleproxy_prekey, crc32, crc32c, derive_middleproxy_keys, md5, sha1, sha256,
sha256_hmac,
};
pub use random::SecureRandom;

View File

@ -100,6 +100,14 @@ fn render_metrics(stats: &Stats) -> String {
let _ = writeln!(out, "# TYPE telemt_me_keepalive_failed_total counter");
let _ = writeln!(out, "telemt_me_keepalive_failed_total {}", stats.get_me_keepalive_failed());
let _ = writeln!(out, "# HELP telemt_me_keepalive_pong_total ME keepalive pong replies");
let _ = writeln!(out, "# TYPE telemt_me_keepalive_pong_total counter");
let _ = writeln!(out, "telemt_me_keepalive_pong_total {}", stats.get_me_keepalive_pong());
let _ = writeln!(out, "# HELP telemt_me_keepalive_timeout_total ME keepalive ping timeouts");
let _ = writeln!(out, "# TYPE telemt_me_keepalive_timeout_total counter");
let _ = writeln!(out, "telemt_me_keepalive_timeout_total {}", stats.get_me_keepalive_timeout());
let _ = writeln!(out, "# HELP telemt_me_reconnect_attempts_total ME reconnect attempts");
let _ = writeln!(out, "# TYPE telemt_me_reconnect_attempts_total counter");
let _ = writeln!(out, "telemt_me_reconnect_attempts_total {}", stats.get_me_reconnect_attempts());
@ -108,6 +116,30 @@ fn render_metrics(stats: &Stats) -> String {
let _ = writeln!(out, "# TYPE telemt_me_reconnect_success_total counter");
let _ = writeln!(out, "telemt_me_reconnect_success_total {}", stats.get_me_reconnect_success());
let _ = writeln!(out, "# HELP telemt_me_crc_mismatch_total ME CRC mismatches");
let _ = writeln!(out, "# TYPE telemt_me_crc_mismatch_total counter");
let _ = writeln!(out, "telemt_me_crc_mismatch_total {}", stats.get_me_crc_mismatch());
let _ = writeln!(out, "# HELP telemt_me_seq_mismatch_total ME sequence mismatches");
let _ = writeln!(out, "# TYPE telemt_me_seq_mismatch_total counter");
let _ = writeln!(out, "telemt_me_seq_mismatch_total {}", stats.get_me_seq_mismatch());
let _ = writeln!(out, "# HELP telemt_me_route_drop_no_conn_total ME route drops: no conn");
let _ = writeln!(out, "# TYPE telemt_me_route_drop_no_conn_total counter");
let _ = writeln!(out, "telemt_me_route_drop_no_conn_total {}", stats.get_me_route_drop_no_conn());
let _ = writeln!(out, "# HELP telemt_me_route_drop_channel_closed_total ME route drops: channel closed");
let _ = writeln!(out, "# TYPE telemt_me_route_drop_channel_closed_total counter");
let _ = writeln!(out, "telemt_me_route_drop_channel_closed_total {}", stats.get_me_route_drop_channel_closed());
let _ = writeln!(out, "# HELP telemt_me_route_drop_queue_full_total ME route drops: queue full");
let _ = writeln!(out, "# TYPE telemt_me_route_drop_queue_full_total counter");
let _ = writeln!(out, "telemt_me_route_drop_queue_full_total {}", stats.get_me_route_drop_queue_full());
let _ = writeln!(out, "# HELP telemt_secure_padding_invalid_total Invalid secure frame lengths");
let _ = writeln!(out, "# TYPE telemt_secure_padding_invalid_total counter");
let _ = writeln!(out, "telemt_secure_padding_invalid_total {}", stats.get_secure_padding_invalid());
let _ = writeln!(out, "# HELP telemt_user_connections_total Per-user total connections");
let _ = writeln!(out, "# TYPE telemt_user_connections_total counter");
let _ = writeln!(out, "# HELP telemt_user_connections_current Per-user active connections");

View File

@ -156,17 +156,28 @@ pub const MAX_TLS_RECORD_SIZE: usize = 16384;
/// RFC 8446 §5.2 allows up to 16384 + 256 bytes of ciphertext
pub const MAX_TLS_CHUNK_SIZE: usize = 16384 + 256;
/// Generate padding length for Secure Intermediate protocol.
/// Total (data + padding) must not be divisible by 4 per MTProto spec.
pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize {
let rem = data_len % 4;
match rem {
0 => (rng.range(3) + 1) as usize, // {1, 2, 3}
1 => rng.range(3) as usize, // {0, 1, 2}
2 => [0usize, 1, 3][rng.range(3) as usize], // {0, 1, 3}
3 => [0usize, 2, 3][rng.range(3) as usize], // {0, 2, 3}
_ => unreachable!(),
/// Secure Intermediate payload is expected to be 4-byte aligned.
pub fn is_valid_secure_payload_len(data_len: usize) -> bool {
data_len % 4 == 0
}
/// Compute Secure Intermediate payload length from wire length.
/// Secure mode strips up to 3 random tail bytes by truncating to 4-byte boundary.
pub fn secure_payload_len_from_wire_len(wire_len: usize) -> Option<usize> {
if wire_len < 4 {
return None;
}
Some(wire_len - (wire_len % 4))
}
/// Generate padding length for Secure Intermediate protocol.
/// Data must be 4-byte aligned; padding is 1..=3 so total is never divisible by 4.
pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize {
debug_assert!(
is_valid_secure_payload_len(data_len),
"Secure payload must be 4-byte aligned, got {data_len}"
);
(rng.range(3) + 1) as usize
}
// ============= Timeouts =============
@ -300,6 +311,10 @@ pub mod rpc_flags {
pub const FLAG_ABRIDGED: u32 = 0x40000000;
pub const FLAG_QUICKACK: u32 = 0x80000000;
}
pub mod rpc_crypto_flags {
pub const USE_CRC32C: u32 = 0x800;
}
pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5;
pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10;
@ -339,7 +354,7 @@ mod tests {
#[test]
fn secure_padding_never_produces_aligned_total() {
let rng = SecureRandom::new();
for data_len in 0..1000 {
for data_len in (0..1000).step_by(4) {
for _ in 0..100 {
let padding = secure_padding_len(data_len, &rng);
assert!(
@ -355,4 +370,23 @@ mod tests {
}
}
}
#[test]
fn secure_wire_len_roundtrip_for_aligned_payload() {
for payload_len in (4..4096).step_by(4) {
for padding in 0..=3usize {
let wire_len = payload_len + padding;
let recovered = secure_payload_len_from_wire_len(wire_len);
assert_eq!(recovered, Some(payload_len));
}
}
}
#[test]
fn secure_wire_len_rejects_too_short_frames() {
assert_eq!(secure_payload_len_from_wire_len(0), None);
assert_eq!(secure_payload_len_from_wire_len(1), None);
assert_eq!(secure_payload_len_from_wire_len(2), None);
assert_eq!(secure_payload_len_from_wire_len(3), None);
}
}

View File

@ -253,7 +253,11 @@ where
let mode_ok = match proto_tag {
ProtoTag::Secure => {
if is_tls { config.general.modes.tls } else { config.general.modes.secure }
if is_tls {
config.general.modes.tls || config.general.modes.secure
} else {
config.general.modes.secure || config.general.modes.tls
}
}
ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic,
};

View File

@ -165,6 +165,7 @@ where
frame_limit,
&user,
&mut frame_counter,
&stats,
).await {
Ok(Some((payload, quickack))) => {
trace!(conn_id, bytes = payload.len(), "C->ME frame");
@ -238,6 +239,7 @@ async fn read_client_payload<R>(
max_frame: usize,
user: &str,
frame_counter: &mut u64,
stats: &Stats,
) -> Result<Option<(Vec<u8>, bool)>>
where
R: AsyncRead + Unpin + Send + 'static,
@ -326,18 +328,29 @@ where
)));
}
let secure_payload_len = if proto_tag == ProtoTag::Secure {
match secure_payload_len_from_wire_len(len) {
Some(payload_len) => payload_len,
None => {
stats.increment_secure_padding_invalid();
return Err(ProxyError::Proxy(format!(
"Invalid secure frame length: {len}"
)));
}
}
} else {
len
};
let mut payload = vec![0u8; len];
client_reader
.read_exact(&mut payload)
.await
.map_err(ProxyError::Io)?;
// Secure Intermediate: remove random padding (last len%4 bytes)
// Secure Intermediate: strip validated trailing padding bytes.
if proto_tag == ProtoTag::Secure {
let rem = len % 4;
if rem != 0 && payload.len() >= rem {
payload.truncate(len - rem);
}
payload.truncate(secure_payload_len);
}
*frame_counter += 1;
return Ok(Some((payload, quickack)));
@ -400,6 +413,12 @@ where
}
ProtoTag::Intermediate | ProtoTag::Secure => {
let padding_len = if proto_tag == ProtoTag::Secure {
if !is_valid_secure_payload_len(data.len()) {
return Err(ProxyError::Proxy(format!(
"Secure payload must be 4-byte aligned, got {}",
data.len()
)));
}
secure_padding_len(data.len(), rng)
} else {
0

View File

@ -21,8 +21,16 @@ pub struct Stats {
handshake_timeouts: AtomicU64,
me_keepalive_sent: AtomicU64,
me_keepalive_failed: AtomicU64,
me_keepalive_pong: AtomicU64,
me_keepalive_timeout: AtomicU64,
me_reconnect_attempts: AtomicU64,
me_reconnect_success: AtomicU64,
me_crc_mismatch: AtomicU64,
me_seq_mismatch: AtomicU64,
me_route_drop_no_conn: AtomicU64,
me_route_drop_channel_closed: AtomicU64,
me_route_drop_queue_full: AtomicU64,
secure_padding_invalid: AtomicU64,
user_stats: DashMap<String, UserStats>,
start_time: parking_lot::RwLock<Option<Instant>>,
}
@ -49,14 +57,45 @@ impl Stats {
pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_keepalive_sent(&self) { self.me_keepalive_sent.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_keepalive_failed(&self) { self.me_keepalive_failed.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_keepalive_pong(&self) { self.me_keepalive_pong.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_keepalive_timeout(&self) { self.me_keepalive_timeout.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_keepalive_timeout_by(&self, value: u64) {
self.me_keepalive_timeout.fetch_add(value, Ordering::Relaxed);
}
pub fn increment_me_reconnect_attempt(&self) { self.me_reconnect_attempts.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_reconnect_success(&self) { self.me_reconnect_success.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_crc_mismatch(&self) { self.me_crc_mismatch.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_seq_mismatch(&self) { self.me_seq_mismatch.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_route_drop_no_conn(&self) { self.me_route_drop_no_conn.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_route_drop_channel_closed(&self) {
self.me_route_drop_channel_closed.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_me_route_drop_queue_full(&self) {
self.me_route_drop_queue_full.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_secure_padding_invalid(&self) {
self.secure_padding_invalid.fetch_add(1, Ordering::Relaxed);
}
pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) }
pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) }
pub fn get_me_keepalive_sent(&self) -> u64 { self.me_keepalive_sent.load(Ordering::Relaxed) }
pub fn get_me_keepalive_failed(&self) -> u64 { self.me_keepalive_failed.load(Ordering::Relaxed) }
pub fn get_me_keepalive_pong(&self) -> u64 { self.me_keepalive_pong.load(Ordering::Relaxed) }
pub fn get_me_keepalive_timeout(&self) -> u64 { self.me_keepalive_timeout.load(Ordering::Relaxed) }
pub fn get_me_reconnect_attempts(&self) -> u64 { self.me_reconnect_attempts.load(Ordering::Relaxed) }
pub fn get_me_reconnect_success(&self) -> u64 { self.me_reconnect_success.load(Ordering::Relaxed) }
pub fn get_me_crc_mismatch(&self) -> u64 { self.me_crc_mismatch.load(Ordering::Relaxed) }
pub fn get_me_seq_mismatch(&self) -> u64 { self.me_seq_mismatch.load(Ordering::Relaxed) }
pub fn get_me_route_drop_no_conn(&self) -> u64 { self.me_route_drop_no_conn.load(Ordering::Relaxed) }
pub fn get_me_route_drop_channel_closed(&self) -> u64 {
self.me_route_drop_channel_closed.load(Ordering::Relaxed)
}
pub fn get_me_route_drop_queue_full(&self) -> u64 {
self.me_route_drop_queue_full.load(Ordering::Relaxed)
}
pub fn get_secure_padding_invalid(&self) -> u64 {
self.secure_padding_invalid.load(Ordering::Relaxed)
}
pub fn increment_user_connects(&self, user: &str) {
self.user_stats.entry(user.to_string()).or_default()

View File

@ -8,7 +8,9 @@ use std::io::{self, Error, ErrorKind};
use std::sync::Arc;
use tokio_util::codec::{Decoder, Encoder};
use crate::protocol::constants::{ProtoTag, secure_padding_len};
use crate::protocol::constants::{
ProtoTag, is_valid_secure_payload_len, secure_padding_len, secure_payload_len_from_wire_len,
};
use crate::crypto::SecureRandom;
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
@ -274,13 +276,13 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame
return Ok(None);
}
// Calculate padding (indicated by length not divisible by 4)
let padding_len = len % 4;
let data_len = if padding_len != 0 {
len - padding_len
} else {
len
};
let data_len = secure_payload_len_from_wire_len(len).ok_or_else(|| {
Error::new(
ErrorKind::InvalidData,
format!("invalid secure frame length: {len}"),
)
})?;
let padding_len = len - data_len;
meta.padding_len = padding_len as u8;
@ -303,6 +305,13 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::R
return Ok(());
}
if !is_valid_secure_payload_len(data.len()) {
return Err(Error::new(
ErrorKind::InvalidData,
format!("secure payload must be 4-byte aligned, got {}", data.len()),
));
}
// Generate padding that keeps total length non-divisible by 4.
let padding_len = secure_padding_len(data.len(), rng);

View File

@ -232,11 +232,13 @@ impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?;
// Strip padding (not aligned to 4)
if len % 4 != 0 {
let actual_len = len - (len % 4);
data.truncate(actual_len);
}
let payload_len = secure_payload_len_from_wire_len(len).ok_or_else(|| {
Error::new(
ErrorKind::InvalidData,
format!("Invalid secure frame length: {len}"),
)
})?;
data.truncate(payload_len);
Ok((Bytes::from(data), meta))
}
@ -267,6 +269,13 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
return Ok(());
}
if !is_valid_secure_payload_len(data.len()) {
return Err(Error::new(
ErrorKind::InvalidData,
format!("Secure payload must be 4-byte aligned, got {}", data.len()),
));
}
// Add padding so total length is never divisible by 4 (MTProto Secure)
let padding_len = secure_padding_len(data.len(), &self.rng);
let padding = self.rng.bytes(padding_len);
@ -550,9 +559,7 @@ mod tests {
writer.flush().await.unwrap();
let (received, _meta) = reader.read_frame().await.unwrap();
// Received should have padding stripped to align to 4
let expected_len = (data.len() / 4) * 4;
assert_eq!(received.len(), expected_len);
assert_eq!(received.len(), data.len());
}
#[tokio::test]

View File

@ -1,6 +1,6 @@
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::crypto::{AesCbc, crc32};
use crate::crypto::{AesCbc, crc32, crc32c};
use crate::error::{ProxyError, Result};
use crate::protocol::constants::*;
@ -8,17 +8,46 @@ use crate::protocol::constants::*;
pub(crate) enum WriterCommand {
Data(Vec<u8>),
DataAndFlush(Vec<u8>),
Keepalive,
Close,
}
pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8]) -> Vec<u8> {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RpcChecksumMode {
Crc32,
Crc32c,
}
impl RpcChecksumMode {
pub(crate) fn from_handshake_flags(flags: u32) -> Self {
if (flags & rpc_crypto_flags::USE_CRC32C) != 0 {
Self::Crc32c
} else {
Self::Crc32
}
}
pub(crate) fn advertised_flags(self) -> u32 {
match self {
Self::Crc32 => 0,
Self::Crc32c => rpc_crypto_flags::USE_CRC32C,
}
}
}
pub(crate) fn rpc_crc(mode: RpcChecksumMode, data: &[u8]) -> u32 {
match mode {
RpcChecksumMode::Crc32 => crc32(data),
RpcChecksumMode::Crc32c => crc32c(data),
}
}
pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8], crc_mode: RpcChecksumMode) -> Vec<u8> {
let total_len = (4 + 4 + payload.len() + 4) as u32;
let mut frame = Vec::with_capacity(total_len as usize);
frame.extend_from_slice(&total_len.to_le_bytes());
frame.extend_from_slice(&seq_no.to_le_bytes());
frame.extend_from_slice(payload);
let c = crc32(&frame);
let c = rpc_crc(crc_mode, &frame);
frame.extend_from_slice(&c.to_le_bytes());
frame
}
@ -45,7 +74,7 @@ pub(crate) async fn read_rpc_frame_plaintext(
let crc_offset = total_len - 4;
let expected_crc = u32::from_le_bytes(full[crc_offset..crc_offset + 4].try_into().unwrap());
let actual_crc = crc32(&full[..crc_offset]);
let actual_crc = rpc_crc(RpcChecksumMode::Crc32, &full[..crc_offset]);
if expected_crc != actual_crc {
return Err(ProxyError::InvalidHandshake(format!(
"CRC mismatch: 0x{expected_crc:08x} vs 0x{actual_crc:08x}"
@ -95,24 +124,52 @@ pub(crate) fn build_handshake_payload(
our_port: u16,
peer_ip: [u8; 4],
peer_port: u16,
flags: u32,
) -> [u8; 32] {
let mut p = [0u8; 32];
p[0..4].copy_from_slice(&RPC_HANDSHAKE_U32.to_le_bytes());
p[4..8].copy_from_slice(&flags.to_le_bytes());
// Keep C memory layout compatibility for PID IPv4 bytes.
// process_id sender_pid
p[8..12].copy_from_slice(&our_ip);
p[12..14].copy_from_slice(&our_port.to_le_bytes());
let pid = (std::process::id() & 0xffff) as u16;
p[14..16].copy_from_slice(&pid.to_le_bytes());
p[14..16].copy_from_slice(&process_pid16().to_le_bytes());
p[16..20].copy_from_slice(&process_utime().to_le_bytes());
// process_id peer_pid
p[20..24].copy_from_slice(&peer_ip);
p[24..26].copy_from_slice(&peer_port.to_le_bytes());
p[26..28].copy_from_slice(&0u16.to_le_bytes());
p[28..32].copy_from_slice(&0u32.to_le_bytes());
p
}
pub(crate) fn parse_handshake_flags(payload: &[u8]) -> Result<u32> {
if payload.len() != 32 {
return Err(ProxyError::InvalidHandshake(format!(
"Bad handshake payload len: {}",
payload.len()
)));
}
let hs_type = u32::from_le_bytes(payload[0..4].try_into().unwrap());
if hs_type != RPC_HANDSHAKE_U32 {
return Err(ProxyError::InvalidHandshake(format!(
"Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}"
)));
}
Ok(u32::from_le_bytes(payload[4..8].try_into().unwrap()))
}
fn process_pid16() -> u16 {
(std::process::id() & 0xffff) as u16
}
fn process_utime() -> u32 {
let utime = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as u32;
p[16..20].copy_from_slice(&utime.to_le_bytes());
p[20..24].copy_from_slice(&peer_ip);
p[24..26].copy_from_slice(&peer_port.to_le_bytes());
p
utime
}
pub(crate) fn cbc_encrypt_padded(
@ -160,11 +217,12 @@ pub(crate) struct RpcWriter {
pub(crate) key: [u8; 32],
pub(crate) iv: [u8; 16],
pub(crate) seq_no: i32,
pub(crate) crc_mode: RpcChecksumMode,
}
impl RpcWriter {
pub(crate) async fn send(&mut self, payload: &[u8]) -> Result<()> {
let frame = build_rpc_frame(self.seq_no, payload);
let frame = build_rpc_frame(self.seq_no, payload, self.crc_mode);
self.seq_no += 1;
let pad = (16 - (frame.len() % 16)) % 16;
@ -189,27 +247,4 @@ impl RpcWriter {
self.send(payload).await?;
self.writer.flush().await.map_err(ProxyError::Io)
}
/// Sends a 4-byte keepalive marker directly into the CBC stream.
/// This is not an RPC frame and must not consume sequence numbers.
pub(crate) async fn send_keepalive(&mut self) -> Result<()> {
let mut buf = [0u8; 16];
for i in 0..4 {
let start = i * 4;
let end = start + 4;
buf[start..end].copy_from_slice(&PADDING_FILLER);
}
let cipher = AesCbc::new(self.key, self.iv);
let mut v = buf.to_vec();
cipher
.encrypt_in_place(&mut v)
.map_err(|e| ProxyError::Crypto(format!("{e}")))?;
if v.len() >= 16 {
self.iv.copy_from_slice(&v[v.len() - 16..]);
}
self.writer.write_all(&v).await.map_err(ProxyError::Io)?;
self.writer.flush().await.map_err(ProxyError::Io)
}
}

View File

@ -18,13 +18,14 @@ use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_k
use crate::error::{ProxyError, Result};
use crate::network::IpFamily;
use crate::protocol::constants::{
ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32, RPC_HANDSHAKE_ERROR_U32,
RPC_HANDSHAKE_U32, RPC_PING_U32, RPC_PONG_U32, RPC_NONCE_U32,
ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32,
RPC_HANDSHAKE_ERROR_U32, rpc_crypto_flags,
};
use super::codec::{
build_handshake_payload, build_nonce_payload, build_rpc_frame, cbc_decrypt_inplace,
cbc_encrypt_padded, parse_nonce_payload, read_rpc_frame_plaintext,
RpcChecksumMode, build_handshake_payload, build_nonce_payload, build_rpc_frame,
cbc_decrypt_inplace, cbc_encrypt_padded, parse_handshake_flags, parse_nonce_payload,
read_rpc_frame_plaintext, rpc_crc,
};
use super::wire::{extract_ip_material, IpMaterial};
use super::MePool;
@ -37,6 +38,7 @@ pub(crate) struct HandshakeOutput {
pub read_iv: [u8; 16],
pub write_key: [u8; 32],
pub write_iv: [u8; 16],
pub crc_mode: RpcChecksumMode,
pub handshake_ms: f64,
}
@ -146,7 +148,7 @@ impl MePool {
let ks = self.key_selector().await;
let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce);
let nonce_frame = build_rpc_frame(-2, &nonce_payload);
let nonce_frame = build_rpc_frame(-2, &nonce_payload, RpcChecksumMode::Crc32);
let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]);
debug!(
key_selector = format_args!("0x{ks:08x}"),
@ -284,8 +286,15 @@ impl MePool {
srv_v6_opt.as_ref(),
);
let hs_payload = build_handshake_payload(hs_our_ip, local_addr.port(), hs_peer_ip, peer_addr.port());
let hs_frame = build_rpc_frame(-1, &hs_payload);
let requested_crc_mode = RpcChecksumMode::Crc32c;
let hs_payload = build_handshake_payload(
hs_our_ip,
local_addr.port(),
hs_peer_ip,
peer_addr.port(),
requested_crc_mode.advertised_flags(),
);
let hs_frame = build_rpc_frame(-1, &hs_payload, RpcChecksumMode::Crc32);
if diag_level >= 1 {
info!(
write_key = %hex_dump(&wk),
@ -314,7 +323,7 @@ impl MePool {
);
}
let (encrypted_hs, mut write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?;
let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?;
if diag_level >= 1 {
info!(
hs_cipher = %hex_dump(&encrypted_hs),
@ -328,6 +337,7 @@ impl MePool {
let mut enc_buf = BytesMut::with_capacity(256);
let mut dec_buf = BytesMut::with_capacity(256);
let mut read_iv = ri;
let mut negotiated_crc_mode = RpcChecksumMode::Crc32;
let mut handshake_ok = false;
while Instant::now() < deadline && !handshake_ok {
@ -375,17 +385,23 @@ impl MePool {
let frame = dec_buf.split_to(fl);
let pe = fl - 4;
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
let ac = crate::crypto::crc32(&frame[..pe]);
let ac = rpc_crc(RpcChecksumMode::Crc32, &frame[..pe]);
if ec != ac {
return Err(ProxyError::InvalidHandshake(format!(
"HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}"
)));
}
let hs_type = u32::from_le_bytes(frame[8..12].try_into().unwrap());
let hs_payload = &frame[8..pe];
if hs_payload.len() < 4 {
return Err(ProxyError::InvalidHandshake(
"Handshake payload too short".to_string(),
));
}
let hs_type = u32::from_le_bytes(hs_payload[0..4].try_into().unwrap());
if hs_type == RPC_HANDSHAKE_ERROR_U32 {
let err_code = if frame.len() >= 16 {
i32::from_le_bytes(frame[12..16].try_into().unwrap())
let err_code = if hs_payload.len() >= 8 {
i32::from_le_bytes(hs_payload[4..8].try_into().unwrap())
} else {
-1
};
@ -393,11 +409,21 @@ impl MePool {
"ME rejected handshake (error={err_code})"
)));
}
if hs_type != RPC_HANDSHAKE_U32 {
let hs_flags = parse_handshake_flags(hs_payload)?;
if hs_flags & 0xff != 0 {
return Err(ProxyError::InvalidHandshake(format!(
"Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}"
"Unsupported handshake flags: 0x{hs_flags:08x}"
)));
}
negotiated_crc_mode = if (hs_flags & requested_crc_mode.advertised_flags()) != 0 {
RpcChecksumMode::from_handshake_flags(hs_flags)
} else if (hs_flags & rpc_crypto_flags::USE_CRC32C) != 0 {
return Err(ProxyError::InvalidHandshake(format!(
"Peer negotiated unsupported CRC flags: 0x{hs_flags:08x}"
)));
} else {
RpcChecksumMode::Crc32
};
handshake_ok = true;
break;
@ -418,6 +444,7 @@ impl MePool {
read_iv,
write_key: wk,
write_iv,
crc_mode: negotiated_crc_mode,
handshake_ms,
})
}

View File

@ -17,10 +17,9 @@ use crate::network::IpFamily;
use crate::protocol::constants::*;
use super::ConnRegistry;
use super::registry::{BoundConn, ConnMeta};
use super::registry::BoundConn;
use super::codec::{RpcWriter, WriterCommand};
use super::reader::reader_loop;
use super::MeResponse;
const ME_ACTIVE_PING_SECS: u64 = 25;
const ME_ACTIVE_PING_JITTER_SECS: i64 = 5;
@ -417,12 +416,12 @@ impl MePool {
let draining = Arc::new(AtomicBool::new(false));
let (tx, mut rx) = mpsc::channel::<WriterCommand>(4096);
let tx_for_keepalive = tx.clone();
let stats = self.stats.clone();
let mut rpc_writer = RpcWriter {
writer: hs.wr,
key: hs.write_key,
iv: hs.write_iv,
seq_no: 0,
crc_mode: hs.crc_mode,
};
let cancel_wr = cancel.clone();
tokio::spawn(async move {
@ -436,17 +435,6 @@ impl MePool {
Some(WriterCommand::DataAndFlush(payload)) => {
if rpc_writer.send_and_flush(&payload).await.is_err() { break; }
}
Some(WriterCommand::Keepalive) => {
match rpc_writer.send_keepalive().await {
Ok(()) => {
stats.increment_me_keepalive_sent();
}
Err(_) => {
stats.increment_me_keepalive_failed();
break;
}
}
}
Some(WriterCommand::Close) | None => break,
}
}
@ -469,7 +457,11 @@ impl MePool {
let reg = self.registry.clone();
let writers_arc = self.writers_arc();
let ping_tracker = self.ping_tracker.clone();
let ping_tracker_reader = ping_tracker.clone();
let rtt_stats = self.rtt_stats.clone();
let stats_reader = self.stats.clone();
let stats_ping = self.stats.clone();
let stats_keepalive = self.stats.clone();
let pool = Arc::downgrade(self);
let cancel_ping = cancel.clone();
let tx_ping = tx.clone();
@ -489,12 +481,14 @@ impl MePool {
hs.rd,
hs.read_key,
hs.read_iv,
hs.crc_mode,
reg.clone(),
BytesMut::new(),
BytesMut::new(),
tx.clone(),
ping_tracker.clone(),
ping_tracker_reader,
rtt_stats.clone(),
stats_reader,
writer_id,
degraded.clone(),
cancel_reader_token.clone(),
@ -535,7 +529,12 @@ impl MePool {
p.extend_from_slice(&sent_id.to_le_bytes());
{
let mut tracker = ping_tracker_ping.lock().await;
let before = tracker.len();
tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120));
let expired = before.saturating_sub(tracker.len());
if expired > 0 {
stats_ping.increment_me_keepalive_timeout_by(expired as u64);
}
tracker.insert(sent_id, (std::time::Instant::now(), writer_id));
}
ping_id = ping_id.wrapping_add(1);
@ -558,18 +557,37 @@ impl MePool {
if keepalive_enabled {
let tx_keepalive = tx_for_keepalive;
let cancel_keepalive = cancel_keepalive_token;
let ping_tracker_keepalive = ping_tracker.clone();
tokio::spawn(async move {
// Per-writer jittered start to avoid phase sync.
let jitter_cap_ms = keepalive_interval.as_millis() / 2;
let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1);
let initial_jitter_ms = rand::rng().random_range(0..=effective_jitter_ms as u64);
tokio::time::sleep(Duration::from_millis(initial_jitter_ms)).await;
let mut ping_id: i64 = rand::random::<i64>();
loop {
tokio::select! {
_ = cancel_keepalive.cancelled() => break,
_ = tokio::time::sleep(keepalive_interval + Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64))) => {}
}
if tx_keepalive.send(WriterCommand::Keepalive).await.is_err() {
let sent_id = ping_id;
ping_id = ping_id.wrapping_add(1);
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_PING_U32.to_le_bytes());
p.extend_from_slice(&sent_id.to_le_bytes());
{
let mut tracker = ping_tracker_keepalive.lock().await;
let before = tracker.len();
tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120));
let expired = before.saturating_sub(tracker.len());
if expired > 0 {
stats_keepalive.increment_me_keepalive_timeout_by(expired as u64);
}
tracker.insert(sent_id, (std::time::Instant::now(), writer_id));
}
stats_keepalive.increment_me_keepalive_sent();
if tx_keepalive.send(WriterCommand::DataAndFlush(p)).await.is_err() {
stats_keepalive.increment_me_keepalive_failed();
break;
}
}

View File

@ -10,30 +10,33 @@ use tokio::sync::{Mutex, mpsc};
use tokio_util::sync::CancellationToken;
use tracing::{debug, trace, warn};
use crate::crypto::{AesCbc, crc32};
use crate::crypto::AesCbc;
use crate::error::{ProxyError, Result};
use crate::protocol::constants::*;
use crate::stats::Stats;
use super::codec::WriterCommand;
use super::codec::{RpcChecksumMode, WriterCommand, rpc_crc};
use super::registry::RouteResult;
use super::{ConnRegistry, MeResponse};
pub(crate) async fn reader_loop(
mut rd: tokio::io::ReadHalf<TcpStream>,
dk: [u8; 32],
mut div: [u8; 16],
crc_mode: RpcChecksumMode,
reg: Arc<ConnRegistry>,
enc_leftover: BytesMut,
mut dec: BytesMut,
tx: mpsc::Sender<WriterCommand>,
ping_tracker: Arc<Mutex<HashMap<i64, (Instant, u64)>>>,
rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
stats: Arc<Stats>,
_writer_id: u64,
degraded: Arc<AtomicBool>,
cancel: CancellationToken,
) -> Result<()> {
let mut raw = enc_leftover;
let mut expected_seq: i32 = 0;
let mut seq_mismatch = 0u32;
loop {
let mut tmp = [0u8; 16_384];
@ -79,8 +82,9 @@ pub(crate) async fn reader_loop(
let frame = dec.split_to(fl);
let pe = fl - 4;
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
let actual_crc = crc32(&frame[..pe]);
let actual_crc = rpc_crc(crc_mode, &frame[..pe]);
if actual_crc != ec {
stats.increment_me_crc_mismatch();
warn!(
frame_len = fl,
expected_crc = format_args!("0x{ec:08x}"),
@ -92,15 +96,14 @@ pub(crate) async fn reader_loop(
let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap());
if seq_no != expected_seq {
stats.increment_me_seq_mismatch();
warn!(seq_no, expected = expected_seq, "ME RPC seq mismatch");
seq_mismatch += 1;
if seq_mismatch > 10 {
return Err(ProxyError::Proxy("Too many seq mismatches".into()));
}
expected_seq = seq_no.wrapping_add(1);
} else {
expected_seq = expected_seq.wrapping_add(1);
return Err(ProxyError::SeqNoMismatch {
expected: expected_seq,
got: seq_no,
});
}
expected_seq = expected_seq.wrapping_add(1);
let payload = &frame[8..pe];
if payload.len() < 4 {
@ -117,7 +120,13 @@ pub(crate) async fn reader_loop(
trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS");
let routed = reg.route(cid, MeResponse::Data { flags, data }).await;
if !routed {
if !matches!(routed, RouteResult::Routed) {
match routed {
RouteResult::NoConn => stats.increment_me_route_drop_no_conn(),
RouteResult::ChannelClosed => stats.increment_me_route_drop_channel_closed(),
RouteResult::QueueFull => stats.increment_me_route_drop_queue_full(),
RouteResult::Routed => {}
}
reg.unregister(cid).await;
send_close_conn(&tx, cid).await;
}
@ -127,7 +136,13 @@ pub(crate) async fn reader_loop(
trace!(cid, cfm, "RPC_SIMPLE_ACK");
let routed = reg.route(cid, MeResponse::Ack(cfm)).await;
if !routed {
if !matches!(routed, RouteResult::Routed) {
match routed {
RouteResult::NoConn => stats.increment_me_route_drop_no_conn(),
RouteResult::ChannelClosed => stats.increment_me_route_drop_channel_closed(),
RouteResult::QueueFull => stats.increment_me_route_drop_queue_full(),
RouteResult::Routed => {}
}
reg.unregister(cid).await;
send_close_conn(&tx, cid).await;
}
@ -153,6 +168,7 @@ pub(crate) async fn reader_loop(
}
} else if pt == RPC_PONG_U32 && body.len() >= 8 {
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
stats.increment_me_keepalive_pong();
if let Some((sent, wid)) = {
let mut guard = ping_tracker.lock().await;
guard.remove(&ping_id)

View File

@ -1,13 +1,23 @@
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex, RwLock};
use tokio::sync::{mpsc, RwLock};
use tokio::sync::mpsc::error::TrySendError;
use super::codec::WriterCommand;
use super::MeResponse;
const ROUTE_CHANNEL_CAPACITY: usize = 4096;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouteResult {
Routed,
NoConn,
ChannelClosed,
QueueFull,
}
#[derive(Clone)]
pub struct ConnMeta {
pub target_dc: i16,
@ -64,7 +74,7 @@ impl ConnRegistry {
pub async fn register(&self) -> (u64, mpsc::Receiver<MeResponse>) {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = mpsc::channel(1024);
let (tx, rx) = mpsc::channel(ROUTE_CHANNEL_CAPACITY);
self.inner.write().await.map.insert(id, tx);
(id, rx)
}
@ -83,12 +93,16 @@ impl ConnRegistry {
None
}
pub async fn route(&self, id: u64, resp: MeResponse) -> bool {
pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult {
let inner = self.inner.read().await;
if let Some(tx) = inner.map.get(&id) {
tx.try_send(resp).is_ok()
match tx.try_send(resp) {
Ok(()) => RouteResult::Routed,
Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed,
Err(TrySendError::Full(_)) => RouteResult::QueueFull,
}
} else {
false
RouteResult::NoConn
}
}