mirror of https://github.com/telemt/telemt.git
424 lines
16 KiB
Rust
424 lines
16 KiB
Rust
#![allow(clippy::too_many_arguments)]
|
|
|
|
use std::collections::HashMap;
|
|
use std::io::ErrorKind;
|
|
use std::sync::Arc;
|
|
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
|
|
use std::time::Instant;
|
|
|
|
use bytes::{Bytes, BytesMut};
|
|
use tokio::io::AsyncReadExt;
|
|
use tokio::net::TcpStream;
|
|
use tokio::sync::mpsc::error::TrySendError;
|
|
use tokio::sync::{Mutex, mpsc};
|
|
use tokio_util::sync::CancellationToken;
|
|
use tracing::{debug, trace, warn};
|
|
|
|
use crate::crypto::AesCbc;
|
|
use crate::error::{ProxyError, Result};
|
|
use crate::protocol::constants::*;
|
|
use crate::stats::Stats;
|
|
|
|
use super::codec::{RpcChecksumMode, WriterCommand, rpc_crc};
|
|
use super::registry::RouteResult;
|
|
use super::{ConnRegistry, MeResponse};
|
|
|
|
const DATA_ROUTE_MAX_ATTEMPTS: usize = 3;
|
|
const DATA_ROUTE_QUEUE_FULL_STARVATION_THRESHOLD: u8 = 3;
|
|
|
|
fn should_close_on_route_result_for_data(result: RouteResult) -> bool {
|
|
matches!(result, RouteResult::NoConn | RouteResult::ChannelClosed)
|
|
}
|
|
|
|
fn should_close_on_route_result_for_ack(result: RouteResult) -> bool {
|
|
matches!(result, RouteResult::NoConn | RouteResult::ChannelClosed)
|
|
}
|
|
|
|
fn is_data_route_queue_full(result: RouteResult) -> bool {
|
|
matches!(
|
|
result,
|
|
RouteResult::QueueFullBase | RouteResult::QueueFullHigh
|
|
)
|
|
}
|
|
|
|
fn should_close_on_queue_full_streak(streak: u8) -> bool {
|
|
streak >= DATA_ROUTE_QUEUE_FULL_STARVATION_THRESHOLD
|
|
}
|
|
|
|
async fn route_data_with_retry(
|
|
reg: &ConnRegistry,
|
|
conn_id: u64,
|
|
flags: u32,
|
|
data: Bytes,
|
|
timeout_ms: u64,
|
|
) -> RouteResult {
|
|
let mut attempt = 0usize;
|
|
loop {
|
|
let routed = reg
|
|
.route_with_timeout(
|
|
conn_id,
|
|
MeResponse::Data {
|
|
flags,
|
|
data: data.clone(),
|
|
},
|
|
timeout_ms,
|
|
)
|
|
.await;
|
|
match routed {
|
|
RouteResult::QueueFullBase | RouteResult::QueueFullHigh => {
|
|
attempt = attempt.saturating_add(1);
|
|
if attempt >= DATA_ROUTE_MAX_ATTEMPTS {
|
|
return routed;
|
|
}
|
|
tokio::task::yield_now().await;
|
|
}
|
|
_ => return routed,
|
|
}
|
|
}
|
|
}
|
|
|
|
pub(crate) async fn reader_loop(
|
|
mut rd: tokio::io::ReadHalf<TcpStream>,
|
|
dk: [u8; 32],
|
|
mut div: [u8; 16],
|
|
crc_mode: RpcChecksumMode,
|
|
reg: Arc<ConnRegistry>,
|
|
enc_leftover: BytesMut,
|
|
mut dec: BytesMut,
|
|
tx: mpsc::Sender<WriterCommand>,
|
|
ping_tracker: Arc<Mutex<HashMap<i64, Instant>>>,
|
|
rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
|
|
stats: Arc<Stats>,
|
|
writer_id: u64,
|
|
degraded: Arc<AtomicBool>,
|
|
writer_rtt_ema_ms_x10: Arc<AtomicU32>,
|
|
reader_route_data_wait_ms: Arc<AtomicU64>,
|
|
cancel: CancellationToken,
|
|
) -> Result<()> {
|
|
let mut raw = enc_leftover;
|
|
let mut expected_seq: i32 = 0;
|
|
let mut data_route_queue_full_streak = HashMap::<u64, u8>::new();
|
|
|
|
loop {
|
|
let mut tmp = [0u8; 65_536];
|
|
let n = tokio::select! {
|
|
res = rd.read(&mut tmp) => res.map_err(ProxyError::Io)?,
|
|
_ = cancel.cancelled() => return Ok(()),
|
|
};
|
|
if n == 0 {
|
|
stats.increment_me_reader_eof_total();
|
|
return Err(ProxyError::Io(std::io::Error::new(
|
|
ErrorKind::UnexpectedEof,
|
|
"ME socket closed by peer",
|
|
)));
|
|
}
|
|
raw.extend_from_slice(&tmp[..n]);
|
|
|
|
let blocks = raw.len() / 16 * 16;
|
|
if blocks > 0 {
|
|
let mut chunk = raw.split_to(blocks);
|
|
let mut new_iv = [0u8; 16];
|
|
new_iv.copy_from_slice(&chunk[blocks - 16..blocks]);
|
|
AesCbc::new(dk, div)
|
|
.decrypt_in_place(&mut chunk[..])
|
|
.map_err(|e| ProxyError::Crypto(format!("{e}")))?;
|
|
div = new_iv;
|
|
dec.extend_from_slice(&chunk);
|
|
}
|
|
|
|
while dec.len() >= 12 {
|
|
let fl = u32::from_le_bytes(dec[0..4].try_into().unwrap()) as usize;
|
|
if fl == 4 {
|
|
let _ = dec.split_to(4);
|
|
continue;
|
|
}
|
|
if !(12..=(1 << 24)).contains(&fl) {
|
|
warn!(frame_len = fl, "Invalid RPC frame len");
|
|
dec.clear();
|
|
break;
|
|
}
|
|
if dec.len() < fl {
|
|
break;
|
|
}
|
|
|
|
let frame = dec.split_to(fl).freeze();
|
|
let pe = fl - 4;
|
|
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
|
|
let actual_crc = rpc_crc(crc_mode, &frame[..pe]);
|
|
if actual_crc != ec {
|
|
stats.increment_me_crc_mismatch();
|
|
warn!(
|
|
frame_len = fl,
|
|
expected_crc = format_args!("0x{ec:08x}"),
|
|
actual_crc = format_args!("0x{actual_crc:08x}"),
|
|
"CRC mismatch — CBC crypto desync, aborting ME connection"
|
|
);
|
|
return Err(ProxyError::Proxy("CRC mismatch (crypto desync)".into()));
|
|
}
|
|
|
|
let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap());
|
|
if seq_no != expected_seq {
|
|
stats.increment_me_seq_mismatch();
|
|
warn!(seq_no, expected = expected_seq, "ME RPC seq mismatch");
|
|
return Err(ProxyError::SeqNoMismatch {
|
|
expected: expected_seq,
|
|
got: seq_no,
|
|
});
|
|
}
|
|
expected_seq = expected_seq.wrapping_add(1);
|
|
|
|
let payload = frame.slice(8..pe);
|
|
if payload.len() < 4 {
|
|
continue;
|
|
}
|
|
|
|
let pt = u32::from_le_bytes(payload[0..4].try_into().unwrap());
|
|
let body = payload.slice(4..);
|
|
|
|
if pt == RPC_PROXY_ANS_U32 && body.len() >= 12 {
|
|
let flags = u32::from_le_bytes(body[0..4].try_into().unwrap());
|
|
let cid = u64::from_le_bytes(body[4..12].try_into().unwrap());
|
|
let data = body.slice(12..);
|
|
trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS");
|
|
|
|
let route_wait_ms = reader_route_data_wait_ms.load(Ordering::Relaxed);
|
|
let routed =
|
|
route_data_with_retry(reg.as_ref(), cid, flags, data, route_wait_ms).await;
|
|
if matches!(routed, RouteResult::Routed) {
|
|
data_route_queue_full_streak.remove(&cid);
|
|
continue;
|
|
}
|
|
match routed {
|
|
RouteResult::NoConn => stats.increment_me_route_drop_no_conn(),
|
|
RouteResult::ChannelClosed => stats.increment_me_route_drop_channel_closed(),
|
|
RouteResult::QueueFullBase => {
|
|
stats.increment_me_route_drop_queue_full();
|
|
stats.increment_me_route_drop_queue_full_base();
|
|
}
|
|
RouteResult::QueueFullHigh => {
|
|
stats.increment_me_route_drop_queue_full();
|
|
stats.increment_me_route_drop_queue_full_high();
|
|
}
|
|
RouteResult::Routed => {}
|
|
}
|
|
if should_close_on_route_result_for_data(routed) {
|
|
data_route_queue_full_streak.remove(&cid);
|
|
reg.unregister(cid).await;
|
|
send_close_conn(&tx, cid).await;
|
|
continue;
|
|
}
|
|
if is_data_route_queue_full(routed) {
|
|
let streak = data_route_queue_full_streak.entry(cid).or_insert(0);
|
|
*streak = streak.saturating_add(1);
|
|
if should_close_on_queue_full_streak(*streak) {
|
|
data_route_queue_full_streak.remove(&cid);
|
|
reg.unregister(cid).await;
|
|
send_close_conn(&tx, cid).await;
|
|
}
|
|
}
|
|
} else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 {
|
|
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
|
|
let cfm = u32::from_le_bytes(body[8..12].try_into().unwrap());
|
|
trace!(cid, cfm, "RPC_SIMPLE_ACK");
|
|
|
|
let routed = reg.route_nowait(cid, MeResponse::Ack(cfm)).await;
|
|
if !matches!(routed, RouteResult::Routed) {
|
|
match routed {
|
|
RouteResult::NoConn => stats.increment_me_route_drop_no_conn(),
|
|
RouteResult::ChannelClosed => {
|
|
stats.increment_me_route_drop_channel_closed()
|
|
}
|
|
RouteResult::QueueFullBase => {
|
|
stats.increment_me_route_drop_queue_full();
|
|
stats.increment_me_route_drop_queue_full_base();
|
|
}
|
|
RouteResult::QueueFullHigh => {
|
|
stats.increment_me_route_drop_queue_full();
|
|
stats.increment_me_route_drop_queue_full_high();
|
|
}
|
|
RouteResult::Routed => {}
|
|
}
|
|
if should_close_on_route_result_for_ack(routed) {
|
|
reg.unregister(cid).await;
|
|
send_close_conn(&tx, cid).await;
|
|
}
|
|
}
|
|
} else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 {
|
|
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
|
|
debug!(cid, "RPC_CLOSE_EXT from ME");
|
|
let _ = reg.route_nowait(cid, MeResponse::Close).await;
|
|
reg.unregister(cid).await;
|
|
data_route_queue_full_streak.remove(&cid);
|
|
} else if pt == RPC_CLOSE_CONN_U32 && body.len() >= 8 {
|
|
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
|
|
debug!(cid, "RPC_CLOSE_CONN from ME");
|
|
let _ = reg.route_nowait(cid, MeResponse::Close).await;
|
|
reg.unregister(cid).await;
|
|
data_route_queue_full_streak.remove(&cid);
|
|
} else if pt == RPC_PING_U32 && body.len() >= 8 {
|
|
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
|
|
trace!(ping_id, "RPC_PING -> RPC_PONG");
|
|
let mut pong = Vec::with_capacity(12);
|
|
pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes());
|
|
pong.extend_from_slice(&ping_id.to_le_bytes());
|
|
match tx.try_send(WriterCommand::DataAndFlush(Bytes::from(pong))) {
|
|
Ok(()) => {}
|
|
Err(TrySendError::Full(_)) => {
|
|
debug!(ping_id, "PONG dropped: writer command channel is full");
|
|
}
|
|
Err(TrySendError::Closed(_)) => {
|
|
warn!("PONG send failed: writer channel closed");
|
|
break;
|
|
}
|
|
}
|
|
} else if pt == RPC_PONG_U32 && body.len() >= 8 {
|
|
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
|
|
stats.increment_me_keepalive_pong();
|
|
if let Some(sent) = {
|
|
let mut guard = ping_tracker.lock().await;
|
|
guard.remove(&ping_id)
|
|
} {
|
|
let rtt = sent.elapsed().as_secs_f64() * 1000.0;
|
|
let mut stats = rtt_stats.lock().await;
|
|
let entry = stats.entry(writer_id).or_insert((rtt, rtt));
|
|
entry.1 = entry.1 * 0.8 + rtt * 0.2;
|
|
if rtt < entry.0 {
|
|
entry.0 = rtt;
|
|
} else {
|
|
// allow slow baseline drift upward to avoid stale minimum
|
|
entry.0 = entry.0 * 0.99 + rtt * 0.01;
|
|
}
|
|
let degraded_now = entry.1 > entry.0 * 2.0;
|
|
degraded.store(degraded_now, Ordering::Relaxed);
|
|
writer_rtt_ema_ms_x10.store(
|
|
(entry.1 * 10.0).clamp(0.0, u32::MAX as f64) as u32,
|
|
Ordering::Relaxed,
|
|
);
|
|
trace!(
|
|
writer_id,
|
|
rtt_ms = rtt,
|
|
ema_ms = entry.1,
|
|
base_ms = entry.0,
|
|
degraded = degraded_now,
|
|
"ME RTT sample"
|
|
);
|
|
}
|
|
} else {
|
|
debug!(
|
|
rpc_type = format_args!("0x{pt:08x}"),
|
|
len = body.len(),
|
|
"Unknown RPC"
|
|
);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use bytes::Bytes;
|
|
|
|
use crate::transport::middle_proxy::ConnRegistry;
|
|
|
|
use super::{
|
|
MeResponse, RouteResult, is_data_route_queue_full, route_data_with_retry,
|
|
should_close_on_queue_full_streak, should_close_on_route_result_for_ack,
|
|
should_close_on_route_result_for_data,
|
|
};
|
|
|
|
#[test]
|
|
fn data_route_only_fatal_results_close_immediately() {
|
|
assert!(!should_close_on_route_result_for_data(RouteResult::Routed));
|
|
assert!(!should_close_on_route_result_for_data(
|
|
RouteResult::QueueFullBase
|
|
));
|
|
assert!(!should_close_on_route_result_for_data(
|
|
RouteResult::QueueFullHigh
|
|
));
|
|
assert!(should_close_on_route_result_for_data(RouteResult::NoConn));
|
|
assert!(should_close_on_route_result_for_data(
|
|
RouteResult::ChannelClosed
|
|
));
|
|
}
|
|
|
|
#[test]
|
|
fn data_route_queue_full_uses_starvation_threshold() {
|
|
assert!(is_data_route_queue_full(RouteResult::QueueFullBase));
|
|
assert!(is_data_route_queue_full(RouteResult::QueueFullHigh));
|
|
assert!(!is_data_route_queue_full(RouteResult::NoConn));
|
|
assert!(!should_close_on_queue_full_streak(1));
|
|
assert!(!should_close_on_queue_full_streak(2));
|
|
assert!(should_close_on_queue_full_streak(3));
|
|
assert!(should_close_on_queue_full_streak(u8::MAX));
|
|
}
|
|
|
|
#[test]
|
|
fn ack_queue_full_is_soft_dropped_without_forced_close() {
|
|
assert!(!should_close_on_route_result_for_ack(RouteResult::Routed));
|
|
assert!(!should_close_on_route_result_for_ack(
|
|
RouteResult::QueueFullBase
|
|
));
|
|
assert!(!should_close_on_route_result_for_ack(
|
|
RouteResult::QueueFullHigh
|
|
));
|
|
assert!(should_close_on_route_result_for_ack(RouteResult::NoConn));
|
|
assert!(should_close_on_route_result_for_ack(
|
|
RouteResult::ChannelClosed
|
|
));
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn route_data_with_retry_returns_routed_when_channel_has_capacity() {
|
|
let reg = ConnRegistry::with_route_channel_capacity(1);
|
|
let (conn_id, mut rx) = reg.register().await;
|
|
|
|
let routed = route_data_with_retry(®, conn_id, 0, Bytes::from_static(b"a"), 20).await;
|
|
assert!(matches!(routed, RouteResult::Routed));
|
|
match rx.recv().await {
|
|
Some(MeResponse::Data { flags, data }) => {
|
|
assert_eq!(flags, 0);
|
|
assert_eq!(data, Bytes::from_static(b"a"));
|
|
}
|
|
other => panic!("expected routed data response, got {other:?}"),
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn route_data_with_retry_stops_after_bounded_attempts() {
|
|
let reg = ConnRegistry::with_route_channel_capacity(1);
|
|
let (conn_id, _rx) = reg.register().await;
|
|
|
|
assert!(matches!(
|
|
reg.route_nowait(conn_id, MeResponse::Ack(1)).await,
|
|
RouteResult::Routed
|
|
));
|
|
|
|
let routed = route_data_with_retry(®, conn_id, 0, Bytes::from_static(b"a"), 0).await;
|
|
assert!(matches!(
|
|
routed,
|
|
RouteResult::QueueFullBase | RouteResult::QueueFullHigh
|
|
));
|
|
}
|
|
}
|
|
|
|
async fn send_close_conn(tx: &mpsc::Sender<WriterCommand>, conn_id: u64) {
|
|
let mut p = Vec::with_capacity(12);
|
|
p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes());
|
|
p.extend_from_slice(&conn_id.to_le_bytes());
|
|
match tx.try_send(WriterCommand::DataAndFlush(Bytes::from(p))) {
|
|
Ok(()) => {}
|
|
Err(TrySendError::Full(_)) => {
|
|
debug!(
|
|
conn_id,
|
|
"ME close_conn signal skipped: writer command channel is full"
|
|
);
|
|
}
|
|
Err(TrySendError::Closed(_)) => {
|
|
debug!(
|
|
conn_id,
|
|
"ME close_conn signal skipped: writer command channel is closed"
|
|
);
|
|
}
|
|
}
|
|
}
|