Middle-End protocol hardening

- Secure framing / hot-path fix: enforced a single length + padding contract across the framing layer. Replaced legacy runtime `len % 4` recovery with strict validation to eliminate undefined behavior paths.

- ME RPC aligned with C reference contract: handshake now includes `flags + sender_pid + peer_pid`. Added negotiated CRC mode (CRC32 / CRC32C) and applied the negotiated mode consistently in read/write paths.

- Sequence fail-fast semantics: immediate connection termination on first sequence mismatch with dedicated counter increment.

- Keepalive reworked to RPC ping/pong: removed raw CBC keepalive frames. Introduced stale ping tracker with proper timeout accounting.

- Route/backpressure observability improvements: increased per-connection route queue to 4096. Added `RouteResult` with explicit failure reasons (NoConn, ChannelClosed, QueueFull) and per-reason counters.

- Direct-DC secure mode-gate relaxation: removed TLS/secure conflict in Direct-DC handshake path.
This commit is contained in:
Alexey 2026-02-23 02:28:00 +03:00
parent 69be44b2b6
commit 6ff29e43d3
No known key found for this signature in database
16 changed files with 407 additions and 137 deletions

View File

@ -37,7 +37,7 @@ pub(crate) fn default_replay_window_secs() -> u64 {
} }
pub(crate) fn default_handshake_timeout() -> u64 { pub(crate) fn default_handshake_timeout() -> u64 {
30 15
} }
pub(crate) fn default_connect_timeout() -> u64 { pub(crate) fn default_connect_timeout() -> u64 {
@ -52,11 +52,11 @@ pub(crate) fn default_ack_timeout() -> u64 {
300 300
} }
pub(crate) fn default_me_one_retry() -> u8 { pub(crate) fn default_me_one_retry() -> u8 {
12 3
} }
pub(crate) fn default_me_one_timeout() -> u64 { pub(crate) fn default_me_one_timeout() -> u64 {
1200 1500
} }
pub(crate) fn default_listen_addr() -> String { pub(crate) fn default_listen_addr() -> String {
@ -83,7 +83,7 @@ pub(crate) fn default_unknown_dc_log_path() -> Option<String> {
} }
pub(crate) fn default_pool_size() -> usize { pub(crate) fn default_pool_size() -> usize {
16 2
} }
pub(crate) fn default_keepalive_interval() -> u64 { pub(crate) fn default_keepalive_interval() -> u64 {

View File

@ -118,7 +118,7 @@ impl Default for NetworkConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
ipv4: true, ipv4: true,
ipv6: Some(false), ipv6: None,
prefer: 4, prefer: 4,
multipath: false, multipath: false,
stun_servers: default_stun_servers(), stun_servers: default_stun_servers(),
@ -140,7 +140,7 @@ pub struct GeneralConfig {
#[serde(default = "default_true")] #[serde(default = "default_true")]
pub fast_mode: bool, pub fast_mode: bool,
#[serde(default = "default_true")] #[serde(default)]
pub use_middle_proxy: bool, pub use_middle_proxy: bool,
#[serde(default)] #[serde(default)]
@ -157,7 +157,7 @@ pub struct GeneralConfig {
pub middle_proxy_nat_ip: Option<IpAddr>, pub middle_proxy_nat_ip: Option<IpAddr>,
/// Enable STUN-based NAT probing to discover public IP:port for ME KDF. /// Enable STUN-based NAT probing to discover public IP:port for ME KDF.
#[serde(default = "default_true")] #[serde(default)]
pub middle_proxy_nat_probe: bool, pub middle_proxy_nat_probe: bool,
/// Optional STUN server address (host:port) for NAT probing. /// Optional STUN server address (host:port) for NAT probing.
@ -283,15 +283,15 @@ impl Default for GeneralConfig {
modes: ProxyModes::default(), modes: ProxyModes::default(),
prefer_ipv6: false, prefer_ipv6: false,
fast_mode: true, fast_mode: true,
use_middle_proxy: true, use_middle_proxy: false,
ad_tag: None, ad_tag: None,
proxy_secret_path: None, proxy_secret_path: None,
middle_proxy_nat_ip: None, middle_proxy_nat_ip: None,
middle_proxy_nat_probe: true, middle_proxy_nat_probe: false,
middle_proxy_nat_stun: None, middle_proxy_nat_stun: None,
middle_proxy_nat_stun_servers: Vec::new(), middle_proxy_nat_stun_servers: Vec::new(),
middle_proxy_pool_size: default_pool_size(), middle_proxy_pool_size: default_pool_size(),
middle_proxy_warm_standby: 8, middle_proxy_warm_standby: 0,
me_keepalive_enabled: true, me_keepalive_enabled: true,
me_keepalive_interval_secs: default_keepalive_interval(), me_keepalive_interval_secs: default_keepalive_interval(),
me_keepalive_jitter_secs: default_keepalive_jitter(), me_keepalive_jitter_secs: default_keepalive_jitter(),
@ -302,7 +302,7 @@ impl Default for GeneralConfig {
me_reconnect_max_concurrent_per_dc: 1, me_reconnect_max_concurrent_per_dc: 1,
me_reconnect_backoff_base_ms: default_reconnect_backoff_base_ms(), me_reconnect_backoff_base_ms: default_reconnect_backoff_base_ms(),
me_reconnect_backoff_cap_ms: default_reconnect_backoff_cap_ms(), me_reconnect_backoff_cap_ms: default_reconnect_backoff_cap_ms(),
me_reconnect_fast_retry_count: 11, me_reconnect_fast_retry_count: 1,
stun_iface_mismatch_ignore: false, stun_iface_mismatch_ignore: false,
unknown_dc_log_path: default_unknown_dc_log_path(), unknown_dc_log_path: default_unknown_dc_log_path(),
log_level: LogLevel::Normal, log_level: LogLevel::Normal,
@ -455,7 +455,7 @@ pub struct AntiCensorshipConfig {
pub fake_cert_len: usize, pub fake_cert_len: usize,
/// Enable TLS certificate emulation using cached real certificates. /// Enable TLS certificate emulation using cached real certificates.
#[serde(default = "default_true")] #[serde(default)]
pub tls_emulation: bool, pub tls_emulation: bool,
/// Directory to store TLS front cache (on disk). /// Directory to store TLS front cache (on disk).
@ -489,7 +489,7 @@ impl Default for AntiCensorshipConfig {
mask_port: default_mask_port(), mask_port: default_mask_port(),
mask_unix_sock: None, mask_unix_sock: None,
fake_cert_len: default_fake_cert_len(), fake_cert_len: default_fake_cert_len(),
tls_emulation: true, tls_emulation: false,
tls_front_dir: default_tls_front_dir(), tls_front_dir: default_tls_front_dir(),
server_hello_delay_min_ms: default_server_hello_delay_min_ms(), server_hello_delay_min_ms: default_server_hello_delay_min_ms(),
server_hello_delay_max_ms: default_server_hello_delay_max_ms(), server_hello_delay_max_ms: default_server_hello_delay_max_ms(),
@ -619,9 +619,9 @@ pub struct ListenerConfig {
/// - omitted — show no links (default) /// - omitted — show no links (default)
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum ShowLink { pub enum ShowLink {
/// Don't show any links. /// Don't show any links (default when omitted).
None, None,
/// Show links for all configured users (default). /// Show links for all configured users.
All, All,
/// Show links for specific users. /// Show links for specific users.
Specific(Vec<String>), Specific(Vec<String>),
@ -629,7 +629,7 @@ pub enum ShowLink {
impl Default for ShowLink { impl Default for ShowLink {
fn default() -> Self { fn default() -> Self {
ShowLink::All ShowLink::None
} }
} }

View File

@ -55,6 +55,11 @@ pub fn crc32(data: &[u8]) -> u32 {
crc32fast::hash(data) crc32fast::hash(data)
} }
/// CRC32C (Castagnoli)
pub fn crc32c(data: &[u8]) -> u32 {
crc32c::crc32c(data)
}
/// Build the exact prekey buffer used by Telegram Middle Proxy KDF. /// Build the exact prekey buffer used by Telegram Middle Proxy KDF.
/// ///
/// Returned buffer layout (IPv4): /// Returned buffer layout (IPv4):

View File

@ -5,5 +5,8 @@ pub mod hash;
pub mod random; pub mod random;
pub use aes::{AesCtr, AesCbc}; pub use aes::{AesCtr, AesCbc};
pub use hash::{sha256, sha256_hmac, sha1, md5, crc32, derive_middleproxy_keys, build_middleproxy_prekey}; pub use hash::{
build_middleproxy_prekey, crc32, crc32c, derive_middleproxy_keys, md5, sha1, sha256,
sha256_hmac,
};
pub use random::SecureRandom; pub use random::SecureRandom;

View File

@ -100,6 +100,14 @@ fn render_metrics(stats: &Stats) -> String {
let _ = writeln!(out, "# TYPE telemt_me_keepalive_failed_total counter"); let _ = writeln!(out, "# TYPE telemt_me_keepalive_failed_total counter");
let _ = writeln!(out, "telemt_me_keepalive_failed_total {}", stats.get_me_keepalive_failed()); let _ = writeln!(out, "telemt_me_keepalive_failed_total {}", stats.get_me_keepalive_failed());
let _ = writeln!(out, "# HELP telemt_me_keepalive_pong_total ME keepalive pong replies");
let _ = writeln!(out, "# TYPE telemt_me_keepalive_pong_total counter");
let _ = writeln!(out, "telemt_me_keepalive_pong_total {}", stats.get_me_keepalive_pong());
let _ = writeln!(out, "# HELP telemt_me_keepalive_timeout_total ME keepalive ping timeouts");
let _ = writeln!(out, "# TYPE telemt_me_keepalive_timeout_total counter");
let _ = writeln!(out, "telemt_me_keepalive_timeout_total {}", stats.get_me_keepalive_timeout());
let _ = writeln!(out, "# HELP telemt_me_reconnect_attempts_total ME reconnect attempts"); let _ = writeln!(out, "# HELP telemt_me_reconnect_attempts_total ME reconnect attempts");
let _ = writeln!(out, "# TYPE telemt_me_reconnect_attempts_total counter"); let _ = writeln!(out, "# TYPE telemt_me_reconnect_attempts_total counter");
let _ = writeln!(out, "telemt_me_reconnect_attempts_total {}", stats.get_me_reconnect_attempts()); let _ = writeln!(out, "telemt_me_reconnect_attempts_total {}", stats.get_me_reconnect_attempts());
@ -108,6 +116,30 @@ fn render_metrics(stats: &Stats) -> String {
let _ = writeln!(out, "# TYPE telemt_me_reconnect_success_total counter"); let _ = writeln!(out, "# TYPE telemt_me_reconnect_success_total counter");
let _ = writeln!(out, "telemt_me_reconnect_success_total {}", stats.get_me_reconnect_success()); let _ = writeln!(out, "telemt_me_reconnect_success_total {}", stats.get_me_reconnect_success());
let _ = writeln!(out, "# HELP telemt_me_crc_mismatch_total ME CRC mismatches");
let _ = writeln!(out, "# TYPE telemt_me_crc_mismatch_total counter");
let _ = writeln!(out, "telemt_me_crc_mismatch_total {}", stats.get_me_crc_mismatch());
let _ = writeln!(out, "# HELP telemt_me_seq_mismatch_total ME sequence mismatches");
let _ = writeln!(out, "# TYPE telemt_me_seq_mismatch_total counter");
let _ = writeln!(out, "telemt_me_seq_mismatch_total {}", stats.get_me_seq_mismatch());
let _ = writeln!(out, "# HELP telemt_me_route_drop_no_conn_total ME route drops: no conn");
let _ = writeln!(out, "# TYPE telemt_me_route_drop_no_conn_total counter");
let _ = writeln!(out, "telemt_me_route_drop_no_conn_total {}", stats.get_me_route_drop_no_conn());
let _ = writeln!(out, "# HELP telemt_me_route_drop_channel_closed_total ME route drops: channel closed");
let _ = writeln!(out, "# TYPE telemt_me_route_drop_channel_closed_total counter");
let _ = writeln!(out, "telemt_me_route_drop_channel_closed_total {}", stats.get_me_route_drop_channel_closed());
let _ = writeln!(out, "# HELP telemt_me_route_drop_queue_full_total ME route drops: queue full");
let _ = writeln!(out, "# TYPE telemt_me_route_drop_queue_full_total counter");
let _ = writeln!(out, "telemt_me_route_drop_queue_full_total {}", stats.get_me_route_drop_queue_full());
let _ = writeln!(out, "# HELP telemt_secure_padding_invalid_total Invalid secure frame lengths");
let _ = writeln!(out, "# TYPE telemt_secure_padding_invalid_total counter");
let _ = writeln!(out, "telemt_secure_padding_invalid_total {}", stats.get_secure_padding_invalid());
let _ = writeln!(out, "# HELP telemt_user_connections_total Per-user total connections"); let _ = writeln!(out, "# HELP telemt_user_connections_total Per-user total connections");
let _ = writeln!(out, "# TYPE telemt_user_connections_total counter"); let _ = writeln!(out, "# TYPE telemt_user_connections_total counter");
let _ = writeln!(out, "# HELP telemt_user_connections_current Per-user active connections"); let _ = writeln!(out, "# HELP telemt_user_connections_current Per-user active connections");

View File

@ -156,17 +156,37 @@ pub const MAX_TLS_RECORD_SIZE: usize = 16384;
/// RFC 8446 §5.2 allows up to 16384 + 256 bytes of ciphertext /// RFC 8446 §5.2 allows up to 16384 + 256 bytes of ciphertext
pub const MAX_TLS_CHUNK_SIZE: usize = 16384 + 256; pub const MAX_TLS_CHUNK_SIZE: usize = 16384 + 256;
/// Generate padding length for Secure Intermediate protocol. /// Secure Intermediate payload is expected to be 4-byte aligned.
/// Total (data + padding) must not be divisible by 4 per MTProto spec. pub fn is_valid_secure_payload_len(data_len: usize) -> bool {
pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize { data_len % 4 == 0
let rem = data_len % 4;
match rem {
0 => (rng.range(3) + 1) as usize, // {1, 2, 3}
1 => rng.range(3) as usize, // {0, 1, 2}
2 => [0usize, 1, 3][rng.range(3) as usize], // {0, 1, 3}
3 => [0usize, 2, 3][rng.range(3) as usize], // {0, 2, 3}
_ => unreachable!(),
} }
/// Compute Secure Intermediate payload length from wire length.
///
/// Returns `None` for invalid Secure lengths (e.g. divisible by 4).
pub fn secure_payload_len_from_wire_len(wire_len: usize) -> Option<usize> {
if wire_len < 4 {
return None;
}
let padding_len = wire_len % 4;
if padding_len == 0 || wire_len < padding_len {
return None;
}
let payload_len = wire_len - padding_len;
if !is_valid_secure_payload_len(payload_len) {
return None;
}
Some(payload_len)
}
/// Generate padding length for Secure Intermediate protocol.
/// Data must be 4-byte aligned; padding is 1..=3 so total is never divisible by 4.
pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize {
debug_assert!(
is_valid_secure_payload_len(data_len),
"Secure payload must be 4-byte aligned, got {data_len}"
);
(rng.range(3) + 1) as usize
} }
// ============= Timeouts ============= // ============= Timeouts =============
@ -301,6 +321,10 @@ pub mod rpc_flags {
pub const FLAG_QUICKACK: u32 = 0x80000000; pub const FLAG_QUICKACK: u32 = 0x80000000;
} }
pub mod rpc_crypto_flags {
pub const USE_CRC32C: u32 = 0x800;
}
pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5; pub const ME_CONNECT_TIMEOUT_SECS: u64 = 5;
pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10; pub const ME_HANDSHAKE_TIMEOUT_SECS: u64 = 10;
@ -339,7 +363,7 @@ mod tests {
#[test] #[test]
fn secure_padding_never_produces_aligned_total() { fn secure_padding_never_produces_aligned_total() {
let rng = SecureRandom::new(); let rng = SecureRandom::new();
for data_len in 0..1000 { for data_len in (0..1000).step_by(4) {
for _ in 0..100 { for _ in 0..100 {
let padding = secure_padding_len(data_len, &rng); let padding = secure_padding_len(data_len, &rng);
assert!( assert!(
@ -355,4 +379,22 @@ mod tests {
} }
} }
} }
#[test]
fn secure_wire_len_roundtrip_for_aligned_payload() {
for payload_len in (4..4096).step_by(4) {
for padding in 1..=3usize {
let wire_len = payload_len + padding;
let recovered = secure_payload_len_from_wire_len(wire_len);
assert_eq!(recovered, Some(payload_len));
}
}
}
#[test]
fn secure_wire_len_rejects_aligned_totals() {
for wire_len in (0..1024).step_by(4) {
assert_eq!(secure_payload_len_from_wire_len(wire_len), None);
}
}
} }

View File

@ -253,7 +253,11 @@ where
let mode_ok = match proto_tag { let mode_ok = match proto_tag {
ProtoTag::Secure => { ProtoTag::Secure => {
if is_tls { config.general.modes.tls } else { config.general.modes.secure } if is_tls {
config.general.modes.tls || config.general.modes.secure
} else {
config.general.modes.secure || config.general.modes.tls
}
} }
ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic, ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic,
}; };

View File

@ -165,6 +165,7 @@ where
frame_limit, frame_limit,
&user, &user,
&mut frame_counter, &mut frame_counter,
&stats,
).await { ).await {
Ok(Some((payload, quickack))) => { Ok(Some((payload, quickack))) => {
trace!(conn_id, bytes = payload.len(), "C->ME frame"); trace!(conn_id, bytes = payload.len(), "C->ME frame");
@ -238,6 +239,7 @@ async fn read_client_payload<R>(
max_frame: usize, max_frame: usize,
user: &str, user: &str,
frame_counter: &mut u64, frame_counter: &mut u64,
stats: &Stats,
) -> Result<Option<(Vec<u8>, bool)>> ) -> Result<Option<(Vec<u8>, bool)>>
where where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
@ -326,18 +328,29 @@ where
))); )));
} }
let secure_payload_len = if proto_tag == ProtoTag::Secure {
match secure_payload_len_from_wire_len(len) {
Some(payload_len) => payload_len,
None => {
stats.increment_secure_padding_invalid();
return Err(ProxyError::Proxy(format!(
"Invalid secure frame length: {len}"
)));
}
}
} else {
len
};
let mut payload = vec![0u8; len]; let mut payload = vec![0u8; len];
client_reader client_reader
.read_exact(&mut payload) .read_exact(&mut payload)
.await .await
.map_err(ProxyError::Io)?; .map_err(ProxyError::Io)?;
// Secure Intermediate: remove random padding (last len%4 bytes) // Secure Intermediate: strip validated trailing padding bytes.
if proto_tag == ProtoTag::Secure { if proto_tag == ProtoTag::Secure {
let rem = len % 4; payload.truncate(secure_payload_len);
if rem != 0 && payload.len() >= rem {
payload.truncate(len - rem);
}
} }
*frame_counter += 1; *frame_counter += 1;
return Ok(Some((payload, quickack))); return Ok(Some((payload, quickack)));
@ -400,6 +413,12 @@ where
} }
ProtoTag::Intermediate | ProtoTag::Secure => { ProtoTag::Intermediate | ProtoTag::Secure => {
let padding_len = if proto_tag == ProtoTag::Secure { let padding_len = if proto_tag == ProtoTag::Secure {
if !is_valid_secure_payload_len(data.len()) {
return Err(ProxyError::Proxy(format!(
"Secure payload must be 4-byte aligned, got {}",
data.len()
)));
}
secure_padding_len(data.len(), rng) secure_padding_len(data.len(), rng)
} else { } else {
0 0

View File

@ -21,8 +21,16 @@ pub struct Stats {
handshake_timeouts: AtomicU64, handshake_timeouts: AtomicU64,
me_keepalive_sent: AtomicU64, me_keepalive_sent: AtomicU64,
me_keepalive_failed: AtomicU64, me_keepalive_failed: AtomicU64,
me_keepalive_pong: AtomicU64,
me_keepalive_timeout: AtomicU64,
me_reconnect_attempts: AtomicU64, me_reconnect_attempts: AtomicU64,
me_reconnect_success: AtomicU64, me_reconnect_success: AtomicU64,
me_crc_mismatch: AtomicU64,
me_seq_mismatch: AtomicU64,
me_route_drop_no_conn: AtomicU64,
me_route_drop_channel_closed: AtomicU64,
me_route_drop_queue_full: AtomicU64,
secure_padding_invalid: AtomicU64,
user_stats: DashMap<String, UserStats>, user_stats: DashMap<String, UserStats>,
start_time: parking_lot::RwLock<Option<Instant>>, start_time: parking_lot::RwLock<Option<Instant>>,
} }
@ -49,14 +57,45 @@ impl Stats {
pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); } pub fn increment_handshake_timeouts(&self) { self.handshake_timeouts.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_keepalive_sent(&self) { self.me_keepalive_sent.fetch_add(1, Ordering::Relaxed); } pub fn increment_me_keepalive_sent(&self) { self.me_keepalive_sent.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_keepalive_failed(&self) { self.me_keepalive_failed.fetch_add(1, Ordering::Relaxed); } pub fn increment_me_keepalive_failed(&self) { self.me_keepalive_failed.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_keepalive_pong(&self) { self.me_keepalive_pong.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_keepalive_timeout(&self) { self.me_keepalive_timeout.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_keepalive_timeout_by(&self, value: u64) {
self.me_keepalive_timeout.fetch_add(value, Ordering::Relaxed);
}
pub fn increment_me_reconnect_attempt(&self) { self.me_reconnect_attempts.fetch_add(1, Ordering::Relaxed); } pub fn increment_me_reconnect_attempt(&self) { self.me_reconnect_attempts.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_reconnect_success(&self) { self.me_reconnect_success.fetch_add(1, Ordering::Relaxed); } pub fn increment_me_reconnect_success(&self) { self.me_reconnect_success.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_crc_mismatch(&self) { self.me_crc_mismatch.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_seq_mismatch(&self) { self.me_seq_mismatch.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_route_drop_no_conn(&self) { self.me_route_drop_no_conn.fetch_add(1, Ordering::Relaxed); }
pub fn increment_me_route_drop_channel_closed(&self) {
self.me_route_drop_channel_closed.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_me_route_drop_queue_full(&self) {
self.me_route_drop_queue_full.fetch_add(1, Ordering::Relaxed);
}
pub fn increment_secure_padding_invalid(&self) {
self.secure_padding_invalid.fetch_add(1, Ordering::Relaxed);
}
pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) } pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) }
pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) } pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) }
pub fn get_me_keepalive_sent(&self) -> u64 { self.me_keepalive_sent.load(Ordering::Relaxed) } pub fn get_me_keepalive_sent(&self) -> u64 { self.me_keepalive_sent.load(Ordering::Relaxed) }
pub fn get_me_keepalive_failed(&self) -> u64 { self.me_keepalive_failed.load(Ordering::Relaxed) } pub fn get_me_keepalive_failed(&self) -> u64 { self.me_keepalive_failed.load(Ordering::Relaxed) }
pub fn get_me_keepalive_pong(&self) -> u64 { self.me_keepalive_pong.load(Ordering::Relaxed) }
pub fn get_me_keepalive_timeout(&self) -> u64 { self.me_keepalive_timeout.load(Ordering::Relaxed) }
pub fn get_me_reconnect_attempts(&self) -> u64 { self.me_reconnect_attempts.load(Ordering::Relaxed) } pub fn get_me_reconnect_attempts(&self) -> u64 { self.me_reconnect_attempts.load(Ordering::Relaxed) }
pub fn get_me_reconnect_success(&self) -> u64 { self.me_reconnect_success.load(Ordering::Relaxed) } pub fn get_me_reconnect_success(&self) -> u64 { self.me_reconnect_success.load(Ordering::Relaxed) }
pub fn get_me_crc_mismatch(&self) -> u64 { self.me_crc_mismatch.load(Ordering::Relaxed) }
pub fn get_me_seq_mismatch(&self) -> u64 { self.me_seq_mismatch.load(Ordering::Relaxed) }
pub fn get_me_route_drop_no_conn(&self) -> u64 { self.me_route_drop_no_conn.load(Ordering::Relaxed) }
pub fn get_me_route_drop_channel_closed(&self) -> u64 {
self.me_route_drop_channel_closed.load(Ordering::Relaxed)
}
pub fn get_me_route_drop_queue_full(&self) -> u64 {
self.me_route_drop_queue_full.load(Ordering::Relaxed)
}
pub fn get_secure_padding_invalid(&self) -> u64 {
self.secure_padding_invalid.load(Ordering::Relaxed)
}
pub fn increment_user_connects(&self, user: &str) { pub fn increment_user_connects(&self, user: &str) {
self.user_stats.entry(user.to_string()).or_default() self.user_stats.entry(user.to_string()).or_default()

View File

@ -8,7 +8,9 @@ use std::io::{self, Error, ErrorKind};
use std::sync::Arc; use std::sync::Arc;
use tokio_util::codec::{Decoder, Encoder}; use tokio_util::codec::{Decoder, Encoder};
use crate::protocol::constants::{ProtoTag, secure_padding_len}; use crate::protocol::constants::{
ProtoTag, is_valid_secure_payload_len, secure_padding_len, secure_payload_len_from_wire_len,
};
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait}; use super::frame::{Frame, FrameMeta, FrameCodec as FrameCodecTrait};
@ -274,13 +276,13 @@ fn decode_secure(src: &mut BytesMut, max_size: usize) -> io::Result<Option<Frame
return Ok(None); return Ok(None);
} }
// Calculate padding (indicated by length not divisible by 4) let data_len = secure_payload_len_from_wire_len(len).ok_or_else(|| {
let padding_len = len % 4; Error::new(
let data_len = if padding_len != 0 { ErrorKind::InvalidData,
len - padding_len format!("invalid secure frame length: {len}"),
} else { )
len })?;
}; let padding_len = len - data_len;
meta.padding_len = padding_len as u8; meta.padding_len = padding_len as u8;
@ -303,6 +305,13 @@ fn encode_secure(frame: &Frame, dst: &mut BytesMut, rng: &SecureRandom) -> io::R
return Ok(()); return Ok(());
} }
if !is_valid_secure_payload_len(data.len()) {
return Err(Error::new(
ErrorKind::InvalidData,
format!("secure payload must be 4-byte aligned, got {}", data.len()),
));
}
// Generate padding that keeps total length non-divisible by 4. // Generate padding that keeps total length non-divisible by 4.
let padding_len = secure_padding_len(data.len(), rng); let padding_len = secure_padding_len(data.len(), rng);

View File

@ -232,11 +232,13 @@ impl<R: AsyncRead + Unpin> SecureIntermediateFrameReader<R> {
let mut data = vec![0u8; len]; let mut data = vec![0u8; len];
self.upstream.read_exact(&mut data).await?; self.upstream.read_exact(&mut data).await?;
// Strip padding (not aligned to 4) let payload_len = secure_payload_len_from_wire_len(len).ok_or_else(|| {
if len % 4 != 0 { Error::new(
let actual_len = len - (len % 4); ErrorKind::InvalidData,
data.truncate(actual_len); format!("Invalid secure frame length: {len}"),
} )
})?;
data.truncate(payload_len);
Ok((Bytes::from(data), meta)) Ok((Bytes::from(data), meta))
} }
@ -267,6 +269,13 @@ impl<W: AsyncWrite + Unpin> SecureIntermediateFrameWriter<W> {
return Ok(()); return Ok(());
} }
if !is_valid_secure_payload_len(data.len()) {
return Err(Error::new(
ErrorKind::InvalidData,
format!("Secure payload must be 4-byte aligned, got {}", data.len()),
));
}
// Add padding so total length is never divisible by 4 (MTProto Secure) // Add padding so total length is never divisible by 4 (MTProto Secure)
let padding_len = secure_padding_len(data.len(), &self.rng); let padding_len = secure_padding_len(data.len(), &self.rng);
let padding = self.rng.bytes(padding_len); let padding = self.rng.bytes(padding_len);
@ -550,9 +559,7 @@ mod tests {
writer.flush().await.unwrap(); writer.flush().await.unwrap();
let (received, _meta) = reader.read_frame().await.unwrap(); let (received, _meta) = reader.read_frame().await.unwrap();
// Received should have padding stripped to align to 4 assert_eq!(received.len(), data.len());
let expected_len = (data.len() / 4) * 4;
assert_eq!(received.len(), expected_len);
} }
#[tokio::test] #[tokio::test]

View File

@ -1,6 +1,6 @@
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::crypto::{AesCbc, crc32}; use crate::crypto::{AesCbc, crc32, crc32c};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::protocol::constants::*; use crate::protocol::constants::*;
@ -8,17 +8,46 @@ use crate::protocol::constants::*;
pub(crate) enum WriterCommand { pub(crate) enum WriterCommand {
Data(Vec<u8>), Data(Vec<u8>),
DataAndFlush(Vec<u8>), DataAndFlush(Vec<u8>),
Keepalive,
Close, Close,
} }
pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8]) -> Vec<u8> { #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RpcChecksumMode {
Crc32,
Crc32c,
}
impl RpcChecksumMode {
pub(crate) fn from_handshake_flags(flags: u32) -> Self {
if (flags & rpc_crypto_flags::USE_CRC32C) != 0 {
Self::Crc32c
} else {
Self::Crc32
}
}
pub(crate) fn advertised_flags(self) -> u32 {
match self {
Self::Crc32 => 0,
Self::Crc32c => rpc_crypto_flags::USE_CRC32C,
}
}
}
pub(crate) fn rpc_crc(mode: RpcChecksumMode, data: &[u8]) -> u32 {
match mode {
RpcChecksumMode::Crc32 => crc32(data),
RpcChecksumMode::Crc32c => crc32c(data),
}
}
pub(crate) fn build_rpc_frame(seq_no: i32, payload: &[u8], crc_mode: RpcChecksumMode) -> Vec<u8> {
let total_len = (4 + 4 + payload.len() + 4) as u32; let total_len = (4 + 4 + payload.len() + 4) as u32;
let mut frame = Vec::with_capacity(total_len as usize); let mut frame = Vec::with_capacity(total_len as usize);
frame.extend_from_slice(&total_len.to_le_bytes()); frame.extend_from_slice(&total_len.to_le_bytes());
frame.extend_from_slice(&seq_no.to_le_bytes()); frame.extend_from_slice(&seq_no.to_le_bytes());
frame.extend_from_slice(payload); frame.extend_from_slice(payload);
let c = crc32(&frame); let c = rpc_crc(crc_mode, &frame);
frame.extend_from_slice(&c.to_le_bytes()); frame.extend_from_slice(&c.to_le_bytes());
frame frame
} }
@ -45,7 +74,7 @@ pub(crate) async fn read_rpc_frame_plaintext(
let crc_offset = total_len - 4; let crc_offset = total_len - 4;
let expected_crc = u32::from_le_bytes(full[crc_offset..crc_offset + 4].try_into().unwrap()); let expected_crc = u32::from_le_bytes(full[crc_offset..crc_offset + 4].try_into().unwrap());
let actual_crc = crc32(&full[..crc_offset]); let actual_crc = rpc_crc(RpcChecksumMode::Crc32, &full[..crc_offset]);
if expected_crc != actual_crc { if expected_crc != actual_crc {
return Err(ProxyError::InvalidHandshake(format!( return Err(ProxyError::InvalidHandshake(format!(
"CRC mismatch: 0x{expected_crc:08x} vs 0x{actual_crc:08x}" "CRC mismatch: 0x{expected_crc:08x} vs 0x{actual_crc:08x}"
@ -95,24 +124,52 @@ pub(crate) fn build_handshake_payload(
our_port: u16, our_port: u16,
peer_ip: [u8; 4], peer_ip: [u8; 4],
peer_port: u16, peer_port: u16,
flags: u32,
) -> [u8; 32] { ) -> [u8; 32] {
let mut p = [0u8; 32]; let mut p = [0u8; 32];
p[0..4].copy_from_slice(&RPC_HANDSHAKE_U32.to_le_bytes()); p[0..4].copy_from_slice(&RPC_HANDSHAKE_U32.to_le_bytes());
p[4..8].copy_from_slice(&flags.to_le_bytes());
// Keep C memory layout compatibility for PID IPv4 bytes. // process_id sender_pid
p[8..12].copy_from_slice(&our_ip); p[8..12].copy_from_slice(&our_ip);
p[12..14].copy_from_slice(&our_port.to_le_bytes()); p[12..14].copy_from_slice(&our_port.to_le_bytes());
let pid = (std::process::id() & 0xffff) as u16; p[14..16].copy_from_slice(&process_pid16().to_le_bytes());
p[14..16].copy_from_slice(&pid.to_le_bytes()); p[16..20].copy_from_slice(&process_utime().to_le_bytes());
// process_id peer_pid
p[20..24].copy_from_slice(&peer_ip);
p[24..26].copy_from_slice(&peer_port.to_le_bytes());
p[26..28].copy_from_slice(&0u16.to_le_bytes());
p[28..32].copy_from_slice(&0u32.to_le_bytes());
p
}
pub(crate) fn parse_handshake_flags(payload: &[u8]) -> Result<u32> {
if payload.len() != 32 {
return Err(ProxyError::InvalidHandshake(format!(
"Bad handshake payload len: {}",
payload.len()
)));
}
let hs_type = u32::from_le_bytes(payload[0..4].try_into().unwrap());
if hs_type != RPC_HANDSHAKE_U32 {
return Err(ProxyError::InvalidHandshake(format!(
"Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}"
)));
}
Ok(u32::from_le_bytes(payload[4..8].try_into().unwrap()))
}
fn process_pid16() -> u16 {
(std::process::id() & 0xffff) as u16
}
fn process_utime() -> u32 {
let utime = std::time::SystemTime::now() let utime = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default() .unwrap_or_default()
.as_secs() as u32; .as_secs() as u32;
p[16..20].copy_from_slice(&utime.to_le_bytes()); utime
p[20..24].copy_from_slice(&peer_ip);
p[24..26].copy_from_slice(&peer_port.to_le_bytes());
p
} }
pub(crate) fn cbc_encrypt_padded( pub(crate) fn cbc_encrypt_padded(
@ -160,11 +217,12 @@ pub(crate) struct RpcWriter {
pub(crate) key: [u8; 32], pub(crate) key: [u8; 32],
pub(crate) iv: [u8; 16], pub(crate) iv: [u8; 16],
pub(crate) seq_no: i32, pub(crate) seq_no: i32,
pub(crate) crc_mode: RpcChecksumMode,
} }
impl RpcWriter { impl RpcWriter {
pub(crate) async fn send(&mut self, payload: &[u8]) -> Result<()> { pub(crate) async fn send(&mut self, payload: &[u8]) -> Result<()> {
let frame = build_rpc_frame(self.seq_no, payload); let frame = build_rpc_frame(self.seq_no, payload, self.crc_mode);
self.seq_no += 1; self.seq_no += 1;
let pad = (16 - (frame.len() % 16)) % 16; let pad = (16 - (frame.len() % 16)) % 16;
@ -189,27 +247,4 @@ impl RpcWriter {
self.send(payload).await?; self.send(payload).await?;
self.writer.flush().await.map_err(ProxyError::Io) self.writer.flush().await.map_err(ProxyError::Io)
} }
/// Sends a 4-byte keepalive marker directly into the CBC stream.
/// This is not an RPC frame and must not consume sequence numbers.
pub(crate) async fn send_keepalive(&mut self) -> Result<()> {
let mut buf = [0u8; 16];
for i in 0..4 {
let start = i * 4;
let end = start + 4;
buf[start..end].copy_from_slice(&PADDING_FILLER);
}
let cipher = AesCbc::new(self.key, self.iv);
let mut v = buf.to_vec();
cipher
.encrypt_in_place(&mut v)
.map_err(|e| ProxyError::Crypto(format!("{e}")))?;
if v.len() >= 16 {
self.iv.copy_from_slice(&v[v.len() - 16..]);
}
self.writer.write_all(&v).await.map_err(ProxyError::Io)?;
self.writer.flush().await.map_err(ProxyError::Io)
}
} }

View File

@ -18,13 +18,14 @@ use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_k
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::network::IpFamily; use crate::network::IpFamily;
use crate::protocol::constants::{ use crate::protocol::constants::{
ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32, RPC_HANDSHAKE_ERROR_U32, ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32,
RPC_HANDSHAKE_U32, RPC_PING_U32, RPC_PONG_U32, RPC_NONCE_U32, RPC_HANDSHAKE_ERROR_U32, rpc_crypto_flags,
}; };
use super::codec::{ use super::codec::{
build_handshake_payload, build_nonce_payload, build_rpc_frame, cbc_decrypt_inplace, RpcChecksumMode, build_handshake_payload, build_nonce_payload, build_rpc_frame,
cbc_encrypt_padded, parse_nonce_payload, read_rpc_frame_plaintext, cbc_decrypt_inplace, cbc_encrypt_padded, parse_handshake_flags, parse_nonce_payload,
read_rpc_frame_plaintext, rpc_crc,
}; };
use super::wire::{extract_ip_material, IpMaterial}; use super::wire::{extract_ip_material, IpMaterial};
use super::MePool; use super::MePool;
@ -37,6 +38,7 @@ pub(crate) struct HandshakeOutput {
pub read_iv: [u8; 16], pub read_iv: [u8; 16],
pub write_key: [u8; 32], pub write_key: [u8; 32],
pub write_iv: [u8; 16], pub write_iv: [u8; 16],
pub crc_mode: RpcChecksumMode,
pub handshake_ms: f64, pub handshake_ms: f64,
} }
@ -146,7 +148,7 @@ impl MePool {
let ks = self.key_selector().await; let ks = self.key_selector().await;
let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce); let nonce_payload = build_nonce_payload(ks, crypto_ts, &my_nonce);
let nonce_frame = build_rpc_frame(-2, &nonce_payload); let nonce_frame = build_rpc_frame(-2, &nonce_payload, RpcChecksumMode::Crc32);
let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]); let dump = hex_dump(&nonce_frame[..nonce_frame.len().min(44)]);
debug!( debug!(
key_selector = format_args!("0x{ks:08x}"), key_selector = format_args!("0x{ks:08x}"),
@ -284,8 +286,15 @@ impl MePool {
srv_v6_opt.as_ref(), srv_v6_opt.as_ref(),
); );
let hs_payload = build_handshake_payload(hs_our_ip, local_addr.port(), hs_peer_ip, peer_addr.port()); let requested_crc_mode = RpcChecksumMode::Crc32c;
let hs_frame = build_rpc_frame(-1, &hs_payload); let hs_payload = build_handshake_payload(
hs_our_ip,
local_addr.port(),
hs_peer_ip,
peer_addr.port(),
requested_crc_mode.advertised_flags(),
);
let hs_frame = build_rpc_frame(-1, &hs_payload, RpcChecksumMode::Crc32);
if diag_level >= 1 { if diag_level >= 1 {
info!( info!(
write_key = %hex_dump(&wk), write_key = %hex_dump(&wk),
@ -314,7 +323,7 @@ impl MePool {
); );
} }
let (encrypted_hs, mut write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?; let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?;
if diag_level >= 1 { if diag_level >= 1 {
info!( info!(
hs_cipher = %hex_dump(&encrypted_hs), hs_cipher = %hex_dump(&encrypted_hs),
@ -328,6 +337,7 @@ impl MePool {
let mut enc_buf = BytesMut::with_capacity(256); let mut enc_buf = BytesMut::with_capacity(256);
let mut dec_buf = BytesMut::with_capacity(256); let mut dec_buf = BytesMut::with_capacity(256);
let mut read_iv = ri; let mut read_iv = ri;
let mut negotiated_crc_mode = RpcChecksumMode::Crc32;
let mut handshake_ok = false; let mut handshake_ok = false;
while Instant::now() < deadline && !handshake_ok { while Instant::now() < deadline && !handshake_ok {
@ -375,17 +385,23 @@ impl MePool {
let frame = dec_buf.split_to(fl); let frame = dec_buf.split_to(fl);
let pe = fl - 4; let pe = fl - 4;
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
let ac = crate::crypto::crc32(&frame[..pe]); let ac = rpc_crc(RpcChecksumMode::Crc32, &frame[..pe]);
if ec != ac { if ec != ac {
return Err(ProxyError::InvalidHandshake(format!( return Err(ProxyError::InvalidHandshake(format!(
"HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}" "HS CRC mismatch: 0x{ec:08x} vs 0x{ac:08x}"
))); )));
} }
let hs_type = u32::from_le_bytes(frame[8..12].try_into().unwrap()); let hs_payload = &frame[8..pe];
if hs_payload.len() < 4 {
return Err(ProxyError::InvalidHandshake(
"Handshake payload too short".to_string(),
));
}
let hs_type = u32::from_le_bytes(hs_payload[0..4].try_into().unwrap());
if hs_type == RPC_HANDSHAKE_ERROR_U32 { if hs_type == RPC_HANDSHAKE_ERROR_U32 {
let err_code = if frame.len() >= 16 { let err_code = if hs_payload.len() >= 8 {
i32::from_le_bytes(frame[12..16].try_into().unwrap()) i32::from_le_bytes(hs_payload[4..8].try_into().unwrap())
} else { } else {
-1 -1
}; };
@ -393,11 +409,21 @@ impl MePool {
"ME rejected handshake (error={err_code})" "ME rejected handshake (error={err_code})"
))); )));
} }
if hs_type != RPC_HANDSHAKE_U32 { let hs_flags = parse_handshake_flags(hs_payload)?;
if hs_flags & 0xff != 0 {
return Err(ProxyError::InvalidHandshake(format!( return Err(ProxyError::InvalidHandshake(format!(
"Expected HANDSHAKE 0x{RPC_HANDSHAKE_U32:08x}, got 0x{hs_type:08x}" "Unsupported handshake flags: 0x{hs_flags:08x}"
))); )));
} }
negotiated_crc_mode = if (hs_flags & requested_crc_mode.advertised_flags()) != 0 {
RpcChecksumMode::from_handshake_flags(hs_flags)
} else if (hs_flags & rpc_crypto_flags::USE_CRC32C) != 0 {
return Err(ProxyError::InvalidHandshake(format!(
"Peer negotiated unsupported CRC flags: 0x{hs_flags:08x}"
)));
} else {
RpcChecksumMode::Crc32
};
handshake_ok = true; handshake_ok = true;
break; break;
@ -418,6 +444,7 @@ impl MePool {
read_iv, read_iv,
write_key: wk, write_key: wk,
write_iv, write_iv,
crc_mode: negotiated_crc_mode,
handshake_ms, handshake_ms,
}) })
} }

View File

@ -17,10 +17,9 @@ use crate::network::IpFamily;
use crate::protocol::constants::*; use crate::protocol::constants::*;
use super::ConnRegistry; use super::ConnRegistry;
use super::registry::{BoundConn, ConnMeta}; use super::registry::BoundConn;
use super::codec::{RpcWriter, WriterCommand}; use super::codec::{RpcWriter, WriterCommand};
use super::reader::reader_loop; use super::reader::reader_loop;
use super::MeResponse;
const ME_ACTIVE_PING_SECS: u64 = 25; const ME_ACTIVE_PING_SECS: u64 = 25;
const ME_ACTIVE_PING_JITTER_SECS: i64 = 5; const ME_ACTIVE_PING_JITTER_SECS: i64 = 5;
@ -417,12 +416,12 @@ impl MePool {
let draining = Arc::new(AtomicBool::new(false)); let draining = Arc::new(AtomicBool::new(false));
let (tx, mut rx) = mpsc::channel::<WriterCommand>(4096); let (tx, mut rx) = mpsc::channel::<WriterCommand>(4096);
let tx_for_keepalive = tx.clone(); let tx_for_keepalive = tx.clone();
let stats = self.stats.clone();
let mut rpc_writer = RpcWriter { let mut rpc_writer = RpcWriter {
writer: hs.wr, writer: hs.wr,
key: hs.write_key, key: hs.write_key,
iv: hs.write_iv, iv: hs.write_iv,
seq_no: 0, seq_no: 0,
crc_mode: hs.crc_mode,
}; };
let cancel_wr = cancel.clone(); let cancel_wr = cancel.clone();
tokio::spawn(async move { tokio::spawn(async move {
@ -436,17 +435,6 @@ impl MePool {
Some(WriterCommand::DataAndFlush(payload)) => { Some(WriterCommand::DataAndFlush(payload)) => {
if rpc_writer.send_and_flush(&payload).await.is_err() { break; } if rpc_writer.send_and_flush(&payload).await.is_err() { break; }
} }
Some(WriterCommand::Keepalive) => {
match rpc_writer.send_keepalive().await {
Ok(()) => {
stats.increment_me_keepalive_sent();
}
Err(_) => {
stats.increment_me_keepalive_failed();
break;
}
}
}
Some(WriterCommand::Close) | None => break, Some(WriterCommand::Close) | None => break,
} }
} }
@ -469,7 +457,11 @@ impl MePool {
let reg = self.registry.clone(); let reg = self.registry.clone();
let writers_arc = self.writers_arc(); let writers_arc = self.writers_arc();
let ping_tracker = self.ping_tracker.clone(); let ping_tracker = self.ping_tracker.clone();
let ping_tracker_reader = ping_tracker.clone();
let rtt_stats = self.rtt_stats.clone(); let rtt_stats = self.rtt_stats.clone();
let stats_reader = self.stats.clone();
let stats_ping = self.stats.clone();
let stats_keepalive = self.stats.clone();
let pool = Arc::downgrade(self); let pool = Arc::downgrade(self);
let cancel_ping = cancel.clone(); let cancel_ping = cancel.clone();
let tx_ping = tx.clone(); let tx_ping = tx.clone();
@ -489,12 +481,14 @@ impl MePool {
hs.rd, hs.rd,
hs.read_key, hs.read_key,
hs.read_iv, hs.read_iv,
hs.crc_mode,
reg.clone(), reg.clone(),
BytesMut::new(), BytesMut::new(),
BytesMut::new(), BytesMut::new(),
tx.clone(), tx.clone(),
ping_tracker.clone(), ping_tracker_reader,
rtt_stats.clone(), rtt_stats.clone(),
stats_reader,
writer_id, writer_id,
degraded.clone(), degraded.clone(),
cancel_reader_token.clone(), cancel_reader_token.clone(),
@ -535,7 +529,12 @@ impl MePool {
p.extend_from_slice(&sent_id.to_le_bytes()); p.extend_from_slice(&sent_id.to_le_bytes());
{ {
let mut tracker = ping_tracker_ping.lock().await; let mut tracker = ping_tracker_ping.lock().await;
let before = tracker.len();
tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120)); tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120));
let expired = before.saturating_sub(tracker.len());
if expired > 0 {
stats_ping.increment_me_keepalive_timeout_by(expired as u64);
}
tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); tracker.insert(sent_id, (std::time::Instant::now(), writer_id));
} }
ping_id = ping_id.wrapping_add(1); ping_id = ping_id.wrapping_add(1);
@ -558,18 +557,37 @@ impl MePool {
if keepalive_enabled { if keepalive_enabled {
let tx_keepalive = tx_for_keepalive; let tx_keepalive = tx_for_keepalive;
let cancel_keepalive = cancel_keepalive_token; let cancel_keepalive = cancel_keepalive_token;
let ping_tracker_keepalive = ping_tracker.clone();
tokio::spawn(async move { tokio::spawn(async move {
// Per-writer jittered start to avoid phase sync. // Per-writer jittered start to avoid phase sync.
let jitter_cap_ms = keepalive_interval.as_millis() / 2; let jitter_cap_ms = keepalive_interval.as_millis() / 2;
let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1);
let initial_jitter_ms = rand::rng().random_range(0..=effective_jitter_ms as u64); let initial_jitter_ms = rand::rng().random_range(0..=effective_jitter_ms as u64);
tokio::time::sleep(Duration::from_millis(initial_jitter_ms)).await; tokio::time::sleep(Duration::from_millis(initial_jitter_ms)).await;
let mut ping_id: i64 = rand::random::<i64>();
loop { loop {
tokio::select! { tokio::select! {
_ = cancel_keepalive.cancelled() => break, _ = cancel_keepalive.cancelled() => break,
_ = tokio::time::sleep(keepalive_interval + Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64))) => {} _ = tokio::time::sleep(keepalive_interval + Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64))) => {}
} }
if tx_keepalive.send(WriterCommand::Keepalive).await.is_err() { let sent_id = ping_id;
ping_id = ping_id.wrapping_add(1);
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_PING_U32.to_le_bytes());
p.extend_from_slice(&sent_id.to_le_bytes());
{
let mut tracker = ping_tracker_keepalive.lock().await;
let before = tracker.len();
tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120));
let expired = before.saturating_sub(tracker.len());
if expired > 0 {
stats_keepalive.increment_me_keepalive_timeout_by(expired as u64);
}
tracker.insert(sent_id, (std::time::Instant::now(), writer_id));
}
stats_keepalive.increment_me_keepalive_sent();
if tx_keepalive.send(WriterCommand::DataAndFlush(p)).await.is_err() {
stats_keepalive.increment_me_keepalive_failed();
break; break;
} }
} }

View File

@ -10,30 +10,33 @@ use tokio::sync::{Mutex, mpsc};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use tracing::{debug, trace, warn}; use tracing::{debug, trace, warn};
use crate::crypto::{AesCbc, crc32}; use crate::crypto::AesCbc;
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::protocol::constants::*; use crate::protocol::constants::*;
use crate::stats::Stats;
use super::codec::WriterCommand; use super::codec::{RpcChecksumMode, WriterCommand, rpc_crc};
use super::registry::RouteResult;
use super::{ConnRegistry, MeResponse}; use super::{ConnRegistry, MeResponse};
pub(crate) async fn reader_loop( pub(crate) async fn reader_loop(
mut rd: tokio::io::ReadHalf<TcpStream>, mut rd: tokio::io::ReadHalf<TcpStream>,
dk: [u8; 32], dk: [u8; 32],
mut div: [u8; 16], mut div: [u8; 16],
crc_mode: RpcChecksumMode,
reg: Arc<ConnRegistry>, reg: Arc<ConnRegistry>,
enc_leftover: BytesMut, enc_leftover: BytesMut,
mut dec: BytesMut, mut dec: BytesMut,
tx: mpsc::Sender<WriterCommand>, tx: mpsc::Sender<WriterCommand>,
ping_tracker: Arc<Mutex<HashMap<i64, (Instant, u64)>>>, ping_tracker: Arc<Mutex<HashMap<i64, (Instant, u64)>>>,
rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>, rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
stats: Arc<Stats>,
_writer_id: u64, _writer_id: u64,
degraded: Arc<AtomicBool>, degraded: Arc<AtomicBool>,
cancel: CancellationToken, cancel: CancellationToken,
) -> Result<()> { ) -> Result<()> {
let mut raw = enc_leftover; let mut raw = enc_leftover;
let mut expected_seq: i32 = 0; let mut expected_seq: i32 = 0;
let mut seq_mismatch = 0u32;
loop { loop {
let mut tmp = [0u8; 16_384]; let mut tmp = [0u8; 16_384];
@ -79,8 +82,9 @@ pub(crate) async fn reader_loop(
let frame = dec.split_to(fl); let frame = dec.split_to(fl);
let pe = fl - 4; let pe = fl - 4;
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap()); let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
let actual_crc = crc32(&frame[..pe]); let actual_crc = rpc_crc(crc_mode, &frame[..pe]);
if actual_crc != ec { if actual_crc != ec {
stats.increment_me_crc_mismatch();
warn!( warn!(
frame_len = fl, frame_len = fl,
expected_crc = format_args!("0x{ec:08x}"), expected_crc = format_args!("0x{ec:08x}"),
@ -92,15 +96,14 @@ pub(crate) async fn reader_loop(
let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap()); let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap());
if seq_no != expected_seq { if seq_no != expected_seq {
stats.increment_me_seq_mismatch();
warn!(seq_no, expected = expected_seq, "ME RPC seq mismatch"); warn!(seq_no, expected = expected_seq, "ME RPC seq mismatch");
seq_mismatch += 1; return Err(ProxyError::SeqNoMismatch {
if seq_mismatch > 10 { expected: expected_seq,
return Err(ProxyError::Proxy("Too many seq mismatches".into())); got: seq_no,
});
} }
expected_seq = seq_no.wrapping_add(1);
} else {
expected_seq = expected_seq.wrapping_add(1); expected_seq = expected_seq.wrapping_add(1);
}
let payload = &frame[8..pe]; let payload = &frame[8..pe];
if payload.len() < 4 { if payload.len() < 4 {
@ -117,7 +120,13 @@ pub(crate) async fn reader_loop(
trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS"); trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS");
let routed = reg.route(cid, MeResponse::Data { flags, data }).await; let routed = reg.route(cid, MeResponse::Data { flags, data }).await;
if !routed { if !matches!(routed, RouteResult::Routed) {
match routed {
RouteResult::NoConn => stats.increment_me_route_drop_no_conn(),
RouteResult::ChannelClosed => stats.increment_me_route_drop_channel_closed(),
RouteResult::QueueFull => stats.increment_me_route_drop_queue_full(),
RouteResult::Routed => {}
}
reg.unregister(cid).await; reg.unregister(cid).await;
send_close_conn(&tx, cid).await; send_close_conn(&tx, cid).await;
} }
@ -127,7 +136,13 @@ pub(crate) async fn reader_loop(
trace!(cid, cfm, "RPC_SIMPLE_ACK"); trace!(cid, cfm, "RPC_SIMPLE_ACK");
let routed = reg.route(cid, MeResponse::Ack(cfm)).await; let routed = reg.route(cid, MeResponse::Ack(cfm)).await;
if !routed { if !matches!(routed, RouteResult::Routed) {
match routed {
RouteResult::NoConn => stats.increment_me_route_drop_no_conn(),
RouteResult::ChannelClosed => stats.increment_me_route_drop_channel_closed(),
RouteResult::QueueFull => stats.increment_me_route_drop_queue_full(),
RouteResult::Routed => {}
}
reg.unregister(cid).await; reg.unregister(cid).await;
send_close_conn(&tx, cid).await; send_close_conn(&tx, cid).await;
} }
@ -153,6 +168,7 @@ pub(crate) async fn reader_loop(
} }
} else if pt == RPC_PONG_U32 && body.len() >= 8 { } else if pt == RPC_PONG_U32 && body.len() >= 8 {
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap()); let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
stats.increment_me_keepalive_pong();
if let Some((sent, wid)) = { if let Some((sent, wid)) = {
let mut guard = ping_tracker.lock().await; let mut guard = ping_tracker.lock().await;
guard.remove(&ping_id) guard.remove(&ping_id)

View File

@ -1,13 +1,23 @@
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex, RwLock}; use tokio::sync::{mpsc, RwLock};
use tokio::sync::mpsc::error::TrySendError;
use super::codec::WriterCommand; use super::codec::WriterCommand;
use super::MeResponse; use super::MeResponse;
const ROUTE_CHANNEL_CAPACITY: usize = 4096;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouteResult {
Routed,
NoConn,
ChannelClosed,
QueueFull,
}
#[derive(Clone)] #[derive(Clone)]
pub struct ConnMeta { pub struct ConnMeta {
pub target_dc: i16, pub target_dc: i16,
@ -64,7 +74,7 @@ impl ConnRegistry {
pub async fn register(&self) -> (u64, mpsc::Receiver<MeResponse>) { pub async fn register(&self) -> (u64, mpsc::Receiver<MeResponse>) {
let id = self.next_id.fetch_add(1, Ordering::Relaxed); let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = mpsc::channel(1024); let (tx, rx) = mpsc::channel(ROUTE_CHANNEL_CAPACITY);
self.inner.write().await.map.insert(id, tx); self.inner.write().await.map.insert(id, tx);
(id, rx) (id, rx)
} }
@ -83,12 +93,16 @@ impl ConnRegistry {
None None
} }
pub async fn route(&self, id: u64, resp: MeResponse) -> bool { pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult {
let inner = self.inner.read().await; let inner = self.inner.read().await;
if let Some(tx) = inner.map.get(&id) { if let Some(tx) = inner.map.get(&id) {
tx.try_send(resp).is_ok() match tx.try_send(resp) {
Ok(()) => RouteResult::Routed,
Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed,
Err(TrySendError::Full(_)) => RouteResult::QueueFull,
}
} else { } else {
false RouteResult::NoConn
} }
} }