mirror of https://github.com/telemt/telemt.git
191 lines
7.3 KiB
Rust
191 lines
7.3 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::sync::atomic::{AtomicBool, Ordering};
|
|
use std::time::Instant;
|
|
|
|
use bytes::{Bytes, BytesMut};
|
|
use tokio::io::AsyncReadExt;
|
|
use tokio::net::TcpStream;
|
|
use tokio::sync::{Mutex, mpsc};
|
|
use tokio_util::sync::CancellationToken;
|
|
use tracing::{debug, trace, warn};
|
|
|
|
use crate::crypto::{AesCbc, crc32};
|
|
use crate::error::{ProxyError, Result};
|
|
use crate::protocol::constants::*;
|
|
|
|
use super::codec::WriterCommand;
|
|
use super::{ConnRegistry, MeResponse};
|
|
|
|
pub(crate) async fn reader_loop(
|
|
mut rd: tokio::io::ReadHalf<TcpStream>,
|
|
dk: [u8; 32],
|
|
mut div: [u8; 16],
|
|
reg: Arc<ConnRegistry>,
|
|
enc_leftover: BytesMut,
|
|
mut dec: BytesMut,
|
|
tx: mpsc::Sender<WriterCommand>,
|
|
ping_tracker: Arc<Mutex<HashMap<i64, (Instant, u64)>>>,
|
|
rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
|
|
_writer_id: u64,
|
|
degraded: Arc<AtomicBool>,
|
|
cancel: CancellationToken,
|
|
) -> Result<()> {
|
|
let mut raw = enc_leftover;
|
|
let mut expected_seq: i32 = 0;
|
|
let mut crc_errors = 0u32;
|
|
let mut seq_mismatch = 0u32;
|
|
|
|
loop {
|
|
let mut tmp = [0u8; 16_384];
|
|
let n = tokio::select! {
|
|
res = rd.read(&mut tmp) => res.map_err(ProxyError::Io)?,
|
|
_ = cancel.cancelled() => return Ok(()),
|
|
};
|
|
if n == 0 {
|
|
return Ok(());
|
|
}
|
|
raw.extend_from_slice(&tmp[..n]);
|
|
|
|
let blocks = raw.len() / 16 * 16;
|
|
if blocks > 0 {
|
|
let mut new_iv = [0u8; 16];
|
|
new_iv.copy_from_slice(&raw[blocks - 16..blocks]);
|
|
|
|
let mut chunk = vec![0u8; blocks];
|
|
chunk.copy_from_slice(&raw[..blocks]);
|
|
AesCbc::new(dk, div)
|
|
.decrypt_in_place(&mut chunk)
|
|
.map_err(|e| ProxyError::Crypto(format!("{e}")))?;
|
|
div = new_iv;
|
|
dec.extend_from_slice(&chunk);
|
|
let _ = raw.split_to(blocks);
|
|
}
|
|
|
|
while dec.len() >= 12 {
|
|
let fl = u32::from_le_bytes(dec[0..4].try_into().unwrap()) as usize;
|
|
if fl == 4 {
|
|
let _ = dec.split_to(4);
|
|
continue;
|
|
}
|
|
if !(12..=(1 << 24)).contains(&fl) {
|
|
warn!(frame_len = fl, "Invalid RPC frame len");
|
|
dec.clear();
|
|
break;
|
|
}
|
|
if dec.len() < fl {
|
|
break;
|
|
}
|
|
|
|
let frame = dec.split_to(fl);
|
|
let pe = fl - 4;
|
|
let ec = u32::from_le_bytes(frame[pe..pe + 4].try_into().unwrap());
|
|
if crc32(&frame[..pe]) != ec {
|
|
warn!("CRC mismatch in data frame");
|
|
crc_errors += 1;
|
|
if crc_errors > 3 {
|
|
return Err(ProxyError::Proxy("Too many CRC mismatches".into()));
|
|
}
|
|
continue;
|
|
}
|
|
|
|
let seq_no = i32::from_le_bytes(frame[4..8].try_into().unwrap());
|
|
if seq_no != expected_seq {
|
|
warn!(seq_no, expected = expected_seq, "ME RPC seq mismatch");
|
|
seq_mismatch += 1;
|
|
if seq_mismatch > 10 {
|
|
return Err(ProxyError::Proxy("Too many seq mismatches".into()));
|
|
}
|
|
expected_seq = seq_no.wrapping_add(1);
|
|
} else {
|
|
expected_seq = expected_seq.wrapping_add(1);
|
|
}
|
|
|
|
let payload = &frame[8..pe];
|
|
if payload.len() < 4 {
|
|
continue;
|
|
}
|
|
|
|
let pt = u32::from_le_bytes(payload[0..4].try_into().unwrap());
|
|
let body = &payload[4..];
|
|
|
|
if pt == RPC_PROXY_ANS_U32 && body.len() >= 12 {
|
|
let flags = u32::from_le_bytes(body[0..4].try_into().unwrap());
|
|
let cid = u64::from_le_bytes(body[4..12].try_into().unwrap());
|
|
let data = Bytes::copy_from_slice(&body[12..]);
|
|
trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS");
|
|
|
|
let routed = reg.route(cid, MeResponse::Data { flags, data }).await;
|
|
if !routed {
|
|
reg.unregister(cid).await;
|
|
send_close_conn(&tx, cid).await;
|
|
}
|
|
} else if pt == RPC_SIMPLE_ACK_U32 && body.len() >= 12 {
|
|
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
|
|
let cfm = u32::from_le_bytes(body[8..12].try_into().unwrap());
|
|
trace!(cid, cfm, "RPC_SIMPLE_ACK");
|
|
|
|
let routed = reg.route(cid, MeResponse::Ack(cfm)).await;
|
|
if !routed {
|
|
reg.unregister(cid).await;
|
|
send_close_conn(&tx, cid).await;
|
|
}
|
|
} else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 {
|
|
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
|
|
debug!(cid, "RPC_CLOSE_EXT from ME");
|
|
reg.route(cid, MeResponse::Close).await;
|
|
reg.unregister(cid).await;
|
|
} else if pt == RPC_CLOSE_CONN_U32 && body.len() >= 8 {
|
|
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
|
|
debug!(cid, "RPC_CLOSE_CONN from ME");
|
|
reg.route(cid, MeResponse::Close).await;
|
|
reg.unregister(cid).await;
|
|
} else if pt == RPC_PING_U32 && body.len() >= 8 {
|
|
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
|
|
trace!(ping_id, "RPC_PING -> RPC_PONG");
|
|
let mut pong = Vec::with_capacity(12);
|
|
pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes());
|
|
pong.extend_from_slice(&ping_id.to_le_bytes());
|
|
if tx.send(WriterCommand::DataAndFlush(pong)).await.is_err() {
|
|
warn!("PONG send failed");
|
|
break;
|
|
}
|
|
} else if pt == RPC_PONG_U32 && body.len() >= 8 {
|
|
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
|
|
if let Some((sent, wid)) = {
|
|
let mut guard = ping_tracker.lock().await;
|
|
guard.remove(&ping_id)
|
|
} {
|
|
let rtt = sent.elapsed().as_secs_f64() * 1000.0;
|
|
let mut stats = rtt_stats.lock().await;
|
|
let entry = stats.entry(wid).or_insert((rtt, rtt));
|
|
entry.1 = entry.1 * 0.8 + rtt * 0.2;
|
|
if rtt < entry.0 {
|
|
entry.0 = rtt;
|
|
} else {
|
|
// allow slow baseline drift upward to avoid stale minimum
|
|
entry.0 = entry.0 * 0.99 + rtt * 0.01;
|
|
}
|
|
let degraded_now = entry.1 > entry.0 * 2.0;
|
|
degraded.store(degraded_now, Ordering::Relaxed);
|
|
trace!(writer_id = wid, rtt_ms = rtt, ema_ms = entry.1, base_ms = entry.0, degraded = degraded_now, "ME RTT sample");
|
|
}
|
|
} else {
|
|
debug!(
|
|
rpc_type = format_args!("0x{pt:08x}"),
|
|
len = body.len(),
|
|
"Unknown RPC"
|
|
);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn send_close_conn(tx: &mpsc::Sender<WriterCommand>, conn_id: u64) {
|
|
let mut p = Vec::with_capacity(12);
|
|
p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes());
|
|
p.extend_from_slice(&conn_id.to_le_bytes());
|
|
|
|
let _ = tx.send(WriterCommand::DataAndFlush(p)).await;
|
|
}
|