telemt/src/transport/middle_proxy/reader.rs

424 lines
16 KiB
Rust

#![allow(clippy::too_many_arguments)]
use std::collections::HashMap;
use std::io::ErrorKind;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::time::Instant;
use bytes::{Bytes, BytesMut};
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::{Mutex, mpsc};
use tokio_util::sync::CancellationToken;
use tracing::{debug, trace, warn};
use crate::crypto::AesCbc;
use crate::error::{ProxyError, Result};
use crate::protocol::constants::*;
use crate::stats::Stats;
use super::codec::{RpcChecksumMode, WriterCommand, rpc_crc};
use super::registry::RouteResult;
use super::{ConnRegistry, MeResponse};
/// Upper bound on routing attempts made by `route_data_with_retry` for a single
/// data frame while the registry keeps reporting a full queue.
const DATA_ROUTE_MAX_ATTEMPTS: usize = 3;
/// Number of consecutive queue-full data drops for one connection after which
/// `should_close_on_queue_full_streak` reports that the connection should be closed.
const DATA_ROUTE_QUEUE_FULL_STARVATION_THRESHOLD: u8 = 3;
/// Tells whether a failed data delivery must close the connection right away.
///
/// Only outcomes that mean the receiver is gone (`NoConn`, `ChannelClosed`)
/// are fatal; queue-full outcomes are not.
fn should_close_on_route_result_for_data(result: RouteResult) -> bool {
    // Exhaustive on purpose: a new variant forces an explicit decision here.
    match result {
        RouteResult::NoConn | RouteResult::ChannelClosed => true,
        RouteResult::Routed | RouteResult::QueueFullBase | RouteResult::QueueFullHigh => false,
    }
}
/// Tells whether a failed ACK delivery must close the connection.
///
/// A full queue only drops the ACK; `NoConn` and `ChannelClosed` close the
/// connection.
fn should_close_on_route_result_for_ack(result: RouteResult) -> bool {
    // Exhaustive on purpose: a new variant forces an explicit decision here.
    match result {
        RouteResult::NoConn | RouteResult::ChannelClosed => true,
        RouteResult::Routed | RouteResult::QueueFullBase | RouteResult::QueueFullHigh => false,
    }
}
/// Returns `true` when the routing outcome is one of the two queue-full variants.
fn is_data_route_queue_full(result: RouteResult) -> bool {
    match result {
        RouteResult::QueueFullBase | RouteResult::QueueFullHigh => true,
        RouteResult::Routed | RouteResult::NoConn | RouteResult::ChannelClosed => false,
    }
}
/// Returns `true` once a connection's run of consecutive queue-full drops has
/// reached `DATA_ROUTE_QUEUE_FULL_STARVATION_THRESHOLD`.
fn should_close_on_queue_full_streak(streak: u8) -> bool {
    let limit = DATA_ROUTE_QUEUE_FULL_STARVATION_THRESHOLD;
    limit <= streak
}
/// Routes one data payload to `conn_id`, retrying while the target queue is full.
///
/// Each attempt calls `ConnRegistry::route_with_timeout` with `timeout_ms`.
/// Any outcome other than queue-full is returned immediately. A queue-full
/// outcome is retried after yielding to the scheduler, and is returned as-is
/// once `DATA_ROUTE_MAX_ATTEMPTS` attempts have been made.
async fn route_data_with_retry(
    reg: &ConnRegistry,
    conn_id: u64,
    flags: u32,
    data: Bytes,
    timeout_ms: u64,
) -> RouteResult {
    let mut attempts_made = 0usize;
    loop {
        // `Bytes::clone` shares the underlying buffer, so each retry gets its own handle.
        let message = MeResponse::Data {
            flags,
            data: data.clone(),
        };
        let outcome = reg.route_with_timeout(conn_id, message, timeout_ms).await;

        let queue_full = matches!(
            outcome,
            RouteResult::QueueFullBase | RouteResult::QueueFullHigh
        );
        if !queue_full {
            return outcome;
        }

        attempts_made = attempts_made.saturating_add(1);
        if attempts_made >= DATA_ROUTE_MAX_ATTEMPTS {
            return outcome;
        }

        // Give the consumer a chance to drain its queue before the next attempt.
        tokio::task::yield_now().await;
    }
}
/// Reads, decrypts and dispatches RPC frames received from one ME connection.
///
/// Bytes read from `rd` are appended to the still-encrypted buffer (seeded with
/// `enc_leftover`); every complete 16-byte block is decrypted with AES-CBC using
/// key `dk` and the running IV `div`, and the plaintext is appended to `dec`.
/// Frames are then cut out of `dec`: a 4-byte little-endian length, a 4-byte
/// sequence number, the payload, and a trailing 4-byte checksum verified with
/// `rpc_crc(crc_mode, ..)`. The first 4 payload bytes select the RPC type:
///
/// * `RPC_PROXY_ANS_U32` — data is routed to the connection via `reg`, with
///   queue-full retries and a per-connection starvation streak.
/// * `RPC_SIMPLE_ACK_U32` — the ACK is routed without waiting.
/// * `RPC_CLOSE_EXT_U32` / `RPC_CLOSE_CONN_U32` — the connection is told to
///   close and is unregistered.
/// * `RPC_PING_U32` — a PONG is queued on `tx`.
/// * `RPC_PONG_U32` — the RTT sample updates `rtt_stats`, `degraded` and
///   `writer_rtt_ema_ms_x10` for `writer_id`.
///
/// # Returns
///
/// `Ok(())` when `cancel` fires.
///
/// # Errors
///
/// * `ProxyError::Io` on a read error or when the peer closes the socket.
/// * `ProxyError::Crypto` when decryption fails.
/// * `ProxyError::Proxy` on a checksum mismatch, or when the writer command
///   channel is found closed while answering a PING.
/// * `ProxyError::SeqNoMismatch` when a frame arrives out of sequence.
pub(crate) async fn reader_loop(
    mut rd: tokio::io::ReadHalf<TcpStream>,
    dk: [u8; 32],
    mut div: [u8; 16],
    crc_mode: RpcChecksumMode,
    reg: Arc<ConnRegistry>,
    enc_leftover: BytesMut,
    mut dec: BytesMut,
    tx: mpsc::Sender<WriterCommand>,
    ping_tracker: Arc<Mutex<HashMap<i64, Instant>>>,
    rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
    stats: Arc<Stats>,
    writer_id: u64,
    degraded: Arc<AtomicBool>,
    writer_rtt_ema_ms_x10: Arc<AtomicU32>,
    reader_route_data_wait_ms: Arc<AtomicU64>,
    cancel: CancellationToken,
) -> Result<()> {
    let mut raw = enc_leftover;
    let mut expected_seq: i32 = 0;
    // Consecutive queue-full drops per connection id; cleared on success or close.
    let mut data_route_queue_full_streak = HashMap::<u64, u8>::new();
    // Allocated once: only `tmp[..n]` is ever consumed, so it needs no re-zeroing.
    let mut tmp = [0u8; 65_536];
    loop {
        let n = tokio::select! {
            res = rd.read(&mut tmp) => res.map_err(ProxyError::Io)?,
            _ = cancel.cancelled() => return Ok(()),
        };
        if n == 0 {
            stats.increment_me_reader_eof_total();
            return Err(ProxyError::Io(std::io::Error::new(
                ErrorKind::UnexpectedEof,
                "ME socket closed by peer",
            )));
        }
        raw.extend_from_slice(&tmp[..n]);

        // Decrypt only whole 16-byte blocks; any remainder waits for more input.
        let blocks = raw.len() / 16 * 16;
        if blocks > 0 {
            let mut chunk = raw.split_to(blocks);
            // The last ciphertext block becomes the IV for the next chunk, so
            // it has to be captured before the in-place decryption overwrites it.
            let mut new_iv = [0u8; 16];
            new_iv.copy_from_slice(&chunk[blocks - 16..blocks]);
            AesCbc::new(dk, div)
                .decrypt_in_place(&mut chunk[..])
                .map_err(|e| ProxyError::Crypto(format!("{e}")))?;
            div = new_iv;
            dec.extend_from_slice(&chunk);
        }

        while dec.len() >= 12 {
            let fl = u32::from_le_bytes(dec[0..4].try_into().unwrap()) as usize;
            if fl == 4 {
                // A length of 4 is a bare length word with no frame body; skip it.
                let _ = dec.split_to(4);
                continue;
            }
            if !(12..=(1 << 24)).contains(&fl) {
                warn!(frame_len = fl, "Invalid RPC frame len");
                dec.clear();
                break;
            }
            if dec.len() < fl {
                // Frame is not complete yet; wait for more decrypted bytes.
                break;
            }

            let frame = dec.split_to(fl).freeze();
            let pe = fl - 4;
            let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
            let actual_crc = rpc_crc(crc_mode, &frame[..pe]);
            if actual_crc != ec {
                stats.increment_me_crc_mismatch();
                warn!(
                    frame_len = fl,
                    expected_crc = format_args!("0x{ec:08x}"),
                    actual_crc = format_args!("0x{actual_crc:08x}"),
                    "CRC mismatch — CBC crypto desync, aborting ME connection"
                );
                return Err(ProxyError::Proxy("CRC mismatch (crypto desync)".into()));
            }

            let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap());
            if seq_no != expected_seq {
                stats.increment_me_seq_mismatch();
                warn!(seq_no, expected = expected_seq, "ME RPC seq mismatch");
                return Err(ProxyError::SeqNoMismatch {
                    expected: expected_seq,
                    got: seq_no,
                });
            }
            expected_seq = expected_seq.wrapping_add(1);

            let payload = frame.slice(8..pe);
            if payload.len() < 4 {
                continue;
            }
            let pt = u32::from_le_bytes(payload[0..4].try_into().unwrap());
            let body = payload.slice(4..);

            if pt == RPC_PROXY_ANS_U32 && body.len() >= 12 {
                let flags = u32::from_le_bytes(body[0..4].try_into().unwrap());
                let cid = u64::from_le_bytes(body[4..12].try_into().unwrap());
                let data = body.slice(12..);
                trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS");
                let route_wait_ms = reader_route_data_wait_ms.load(Ordering::Relaxed);
                let routed =
                    route_data_with_retry(reg.as_ref(), cid, flags, data, route_wait_ms).await;
                if matches!(routed, RouteResult::Routed) {
                    data_route_queue_full_streak.remove(&cid);
                    continue;
                }
                record_route_drop(&stats, routed);
                if should_close_on_route_result_for_data(routed) {
                    data_route_queue_full_streak.remove(&cid);
                    reg.unregister(cid).await;
                    send_close_conn(&tx, cid).await;
                    continue;
                }
                if is_data_route_queue_full(routed) {
                    let streak_now = {
                        let streak = data_route_queue_full_streak.entry(cid).or_insert(0);
                        *streak = streak.saturating_add(1);
                        *streak
                    };
                    if should_close_on_queue_full_streak(streak_now) {
                        data_route_queue_full_streak.remove(&cid);
                        reg.unregister(cid).await;
                        send_close_conn(&tx, cid).await;
                    }
                }
            } else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 {
                let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
                let cfm = u32::from_le_bytes(body[8..12].try_into().unwrap());
                trace!(cid, cfm, "RPC_SIMPLE_ACK");
                let routed = reg.route_nowait(cid, MeResponse::Ack(cfm)).await;
                if !matches!(routed, RouteResult::Routed) {
                    record_route_drop(&stats, routed);
                    if should_close_on_route_result_for_ack(routed) {
                        // Clear the streak here too, like every other close
                        // path, so a closed connection leaves no stale entry.
                        data_route_queue_full_streak.remove(&cid);
                        reg.unregister(cid).await;
                        send_close_conn(&tx, cid).await;
                    }
                }
            } else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 {
                let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
                debug!(cid, "RPC_CLOSE_EXT from ME");
                let _ = reg.route_nowait(cid, MeResponse::Close).await;
                reg.unregister(cid).await;
                data_route_queue_full_streak.remove(&cid);
            } else if pt == RPC_CLOSE_CONN_U32 && body.len() >= 8 {
                let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
                debug!(cid, "RPC_CLOSE_CONN from ME");
                let _ = reg.route_nowait(cid, MeResponse::Close).await;
                reg.unregister(cid).await;
                data_route_queue_full_streak.remove(&cid);
            } else if pt == RPC_PING_U32 && body.len() >= 8 {
                let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
                trace!(ping_id, "RPC_PING -> RPC_PONG");
                let mut pong = Vec::with_capacity(12);
                pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes());
                pong.extend_from_slice(&ping_id.to_le_bytes());
                match tx.try_send(WriterCommand::DataAndFlush(Bytes::from(pong))) {
                    Ok(()) => {}
                    Err(TrySendError::Full(_)) => {
                        debug!(ping_id, "PONG dropped: writer command channel is full");
                    }
                    Err(TrySendError::Closed(_)) => {
                        warn!("PONG send failed: writer channel closed");
                        // A closed channel never reopens, so this connection can
                        // no longer answer the ME. A plain `break` would only
                        // leave the frame loop and stall the frames still
                        // buffered in `dec`, so stop the reader instead.
                        return Err(ProxyError::Proxy("ME writer channel closed".into()));
                    }
                }
            } else if pt == RPC_PONG_U32 && body.len() >= 8 {
                let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
                stats.increment_me_keepalive_pong();
                // The tracker lock is released before the RTT table lock is taken.
                let sent_at = {
                    let mut guard = ping_tracker.lock().await;
                    guard.remove(&ping_id)
                };
                if let Some(sent) = sent_at {
                    let rtt = sent.elapsed().as_secs_f64() * 1000.0;
                    let mut rtt_guard = rtt_stats.lock().await;
                    // entry.0 is the baseline RTT, entry.1 the moving average (both in ms).
                    let entry = rtt_guard.entry(writer_id).or_insert((rtt, rtt));
                    entry.1 = entry.1 * 0.8 + rtt * 0.2;
                    if rtt < entry.0 {
                        entry.0 = rtt;
                    } else {
                        // allow slow baseline drift upward to avoid stale minimum
                        entry.0 = entry.0 * 0.99 + rtt * 0.01;
                    }
                    let degraded_now = entry.1 > entry.0 * 2.0;
                    degraded.store(degraded_now, Ordering::Relaxed);
                    writer_rtt_ema_ms_x10.store(
                        (entry.1 * 10.0).clamp(0.0, u32::MAX as f64) as u32,
                        Ordering::Relaxed,
                    );
                    trace!(
                        writer_id,
                        rtt_ms = rtt,
                        ema_ms = entry.1,
                        base_ms = entry.0,
                        degraded = degraded_now,
                        "ME RTT sample"
                    );
                }
            } else {
                debug!(
                    rpc_type = format_args!("0x{pt:08x}"),
                    len = body.len(),
                    "Unknown RPC"
                );
            }
        }
    }
}

/// Bumps the drop counters that correspond to a failed routing outcome.
///
/// Queue-full outcomes increment both the aggregate and the per-queue counter;
/// `Routed` is a no-op.
fn record_route_drop(stats: &Stats, routed: RouteResult) {
    match routed {
        RouteResult::NoConn => stats.increment_me_route_drop_no_conn(),
        RouteResult::ChannelClosed => stats.increment_me_route_drop_channel_closed(),
        RouteResult::QueueFullBase => {
            stats.increment_me_route_drop_queue_full();
            stats.increment_me_route_drop_queue_full_base();
        }
        RouteResult::QueueFullHigh => {
            stats.increment_me_route_drop_queue_full();
            stats.increment_me_route_drop_queue_full_high();
        }
        RouteResult::Routed => {}
    }
}
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use crate::transport::middle_proxy::ConnRegistry;

    use super::{
        MeResponse, RouteResult, is_data_route_queue_full, route_data_with_retry,
        should_close_on_queue_full_streak, should_close_on_route_result_for_ack,
        should_close_on_route_result_for_data,
    };

    /// Data routing closes immediately only when the receiver is gone.
    #[test]
    fn data_route_only_fatal_results_close_immediately() {
        for tolerated in [
            RouteResult::Routed,
            RouteResult::QueueFullBase,
            RouteResult::QueueFullHigh,
        ] {
            assert!(!should_close_on_route_result_for_data(tolerated));
        }
        for fatal in [RouteResult::NoConn, RouteResult::ChannelClosed] {
            assert!(should_close_on_route_result_for_data(fatal));
        }
    }

    /// Queue-full outcomes are recognised, and closing starts at a streak of three.
    #[test]
    fn data_route_queue_full_uses_starvation_threshold() {
        for full in [RouteResult::QueueFullBase, RouteResult::QueueFullHigh] {
            assert!(is_data_route_queue_full(full));
        }
        assert!(!is_data_route_queue_full(RouteResult::NoConn));

        for (streak, expect_close) in [(1u8, false), (2, false), (3, true), (u8::MAX, true)] {
            assert_eq!(should_close_on_queue_full_streak(streak), expect_close);
        }
    }

    /// A full queue drops the ACK but keeps the connection open.
    #[test]
    fn ack_queue_full_is_soft_dropped_without_forced_close() {
        for tolerated in [
            RouteResult::Routed,
            RouteResult::QueueFullBase,
            RouteResult::QueueFullHigh,
        ] {
            assert!(!should_close_on_route_result_for_ack(tolerated));
        }
        for fatal in [RouteResult::NoConn, RouteResult::ChannelClosed] {
            assert!(should_close_on_route_result_for_ack(fatal));
        }
    }

    /// With free capacity the payload is delivered unchanged on the first attempt.
    #[tokio::test]
    async fn route_data_with_retry_returns_routed_when_channel_has_capacity() {
        let reg = ConnRegistry::with_route_channel_capacity(1);
        let (conn_id, mut rx) = reg.register().await;

        let payload = Bytes::from_static(b"a");
        let routed = route_data_with_retry(&reg, conn_id, 0, payload, 20).await;
        assert!(matches!(routed, RouteResult::Routed));

        let received = rx.recv().await;
        match received {
            Some(MeResponse::Data { flags, data }) => {
                assert_eq!(flags, 0);
                assert_eq!(data, Bytes::from_static(b"a"));
            }
            other => panic!("expected routed data response, got {other:?}"),
        }
    }

    /// With the single slot occupied, retries give up and report queue-full.
    #[tokio::test]
    async fn route_data_with_retry_stops_after_bounded_attempts() {
        let reg = ConnRegistry::with_route_channel_capacity(1);
        // Keep the receiver alive (but never drain it) so the queue stays full.
        let (conn_id, _rx) = reg.register().await;

        let prefill = reg.route_nowait(conn_id, MeResponse::Ack(1)).await;
        assert!(matches!(prefill, RouteResult::Routed));

        let payload = Bytes::from_static(b"a");
        let routed = route_data_with_retry(&reg, conn_id, 0, payload, 0).await;
        assert!(matches!(
            routed,
            RouteResult::QueueFullBase | RouteResult::QueueFullHigh
        ));
    }
}
/// Queues an `RPC_CLOSE_CONN` message for `conn_id` on the writer command channel.
///
/// The message is the little-endian RPC type followed by the little-endian
/// connection id. Delivery is best-effort: if the channel is full or closed
/// the message is dropped and only a debug line is logged.
async fn send_close_conn(tx: &mpsc::Sender<WriterCommand>, conn_id: u64) {
    let opcode = RPC_CLOSE_CONN_U32.to_le_bytes();
    let target = conn_id.to_le_bytes();
    let frame: Vec<u8> = [&opcode[..], &target[..]].concat();

    if let Err(failure) = tx.try_send(WriterCommand::DataAndFlush(Bytes::from(frame))) {
        match failure {
            TrySendError::Full(_) => debug!(
                conn_id,
                "ME close_conn signal skipped: writer command channel is full"
            ),
            TrySendError::Closed(_) => debug!(
                conn_id,
                "ME close_conn signal skipped: writer command channel is closed"
            ),
        }
    }
}