diff --git a/src/config/load.rs b/src/config/load.rs index 0635f80..c296697 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -612,6 +612,11 @@ impl ProxyConfig { "general.me_route_backpressure_base_timeout_ms must be > 0".to_string(), )); } + if config.general.me_route_backpressure_base_timeout_ms > 5000 { + return Err(ProxyError::Config( + "general.me_route_backpressure_base_timeout_ms must be within [1, 5000]".to_string(), + )); + } if config.general.me_route_backpressure_high_timeout_ms < config.general.me_route_backpressure_base_timeout_ms @@ -620,6 +625,11 @@ impl ProxyConfig { "general.me_route_backpressure_high_timeout_ms must be >= general.me_route_backpressure_base_timeout_ms".to_string(), )); } + if config.general.me_route_backpressure_high_timeout_ms > 5000 { + return Err(ProxyError::Config( + "general.me_route_backpressure_high_timeout_ms must be within [1, 5000]".to_string(), + )); + } if !(1..=100).contains(&config.general.me_route_backpressure_high_watermark_pct) { return Err(ProxyError::Config( @@ -1624,6 +1634,47 @@ mod tests { let _ = std::fs::remove_file(path_valid); } + #[test] + fn me_route_backpressure_base_timeout_ms_out_of_range_is_rejected() { + let toml = r#" + [general] + me_route_backpressure_base_timeout_ms = 5001 + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_me_route_backpressure_base_timeout_ms_out_of_range_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("general.me_route_backpressure_base_timeout_ms must be within [1, 5000]")); + let _ = std::fs::remove_file(path); + } + + #[test] + fn me_route_backpressure_high_timeout_ms_out_of_range_is_rejected() { + let toml = r#" + [general] + me_route_backpressure_base_timeout_ms = 100 + me_route_backpressure_high_timeout_ms = 5001 + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_me_route_backpressure_high_timeout_ms_out_of_range_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("general.me_route_backpressure_high_timeout_ms must be within [1, 5000]")); + let _ = std::fs::remove_file(path); + } + #[test] fn me_route_no_writer_wait_ms_out_of_range_is_rejected() { let toml = r#" diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index a6186b6..7d78b84 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -492,11 +492,9 @@ impl MePool { } pub(crate) async fn remove_writer_and_close_clients(self: &Arc, writer_id: u64) { - let conns = self.remove_writer_only(writer_id).await; - for bound in conns { - let _ = self.registry.route(bound.conn_id, super::MeResponse::Close).await; - let _ = self.registry.unregister(bound.conn_id).await; - } + // Full client cleanup now happens inside `registry.writer_lost` to keep + // writer reap/remove paths strictly non-blocking per connection. + let _ = self.remove_writer_only(writer_id).await; } async fn remove_writer_only(self: &Arc, writer_id: u64) -> Vec { diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index 785bc2c..8b15fc1 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -8,6 +8,7 @@ use bytes::{Bytes, BytesMut}; use tokio::io::AsyncReadExt; use tokio::net::TcpStream; use tokio::sync::{Mutex, mpsc}; +use tokio::sync::mpsc::error::TrySendError; use tokio_util::sync::CancellationToken; use tracing::{debug, trace, warn}; @@ -173,12 +174,12 @@ pub(crate) async fn reader_loop( } else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 { let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); debug!(cid, "RPC_CLOSE_EXT from ME"); - reg.route(cid, MeResponse::Close).await; + let _ = reg.route_nowait(cid, MeResponse::Close).await; reg.unregister(cid).await; } else if pt == RPC_CLOSE_CONN_U32 && body.len() >= 8 { let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); debug!(cid, "RPC_CLOSE_CONN from ME"); - reg.route(cid, MeResponse::Close).await; + let _ = reg.route_nowait(cid, MeResponse::Close).await; reg.unregister(cid).await; } else if pt == RPC_PING_U32 && body.len() >= 8 { let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap()); @@ -186,13 +187,15 @@ pub(crate) async fn reader_loop( let mut pong = Vec::with_capacity(12); pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes()); pong.extend_from_slice(&ping_id.to_le_bytes()); - if tx - .send(WriterCommand::DataAndFlush(Bytes::from(pong))) - .await - .is_err() - { - warn!("PONG send failed"); - break; + match tx.try_send(WriterCommand::DataAndFlush(Bytes::from(pong))) { + Ok(()) => {} + Err(TrySendError::Full(_)) => { + debug!(ping_id, "PONG dropped: writer command channel is full"); + } + Err(TrySendError::Closed(_)) => { + warn!("PONG send failed: writer channel closed"); + break; + } } } else if pt == RPC_PONG_U32 && body.len() >= 8 { let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap()); @@ -232,6 +235,13 @@ async fn send_close_conn(tx: &mpsc::Sender, conn_id: u64) { let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes()); p.extend_from_slice(&conn_id.to_le_bytes()); - - let _ = tx.send(WriterCommand::DataAndFlush(Bytes::from(p))).await; + match tx.try_send(WriterCommand::DataAndFlush(Bytes::from(p))) { + Ok(()) => {} + Err(TrySendError::Full(_)) => { + debug!(conn_id, "ME close_conn signal skipped: writer command channel is full"); + } + Err(TrySendError::Closed(_)) => { + debug!(conn_id, "ME close_conn signal skipped: writer command channel is closed"); + } + } } diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index b8a926e..2ee55c1 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -169,6 +169,7 @@ impl ConnRegistry { None } + #[allow(dead_code)] pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult { let tx = { let inner = self.inner.read().await; @@ -445,30 +446,38 @@ impl ConnRegistry { } pub async fn writer_lost(&self, writer_id: u64) -> Vec { - let mut inner = self.inner.write().await; - inner.writers.remove(&writer_id); - inner.last_meta_for_writer.remove(&writer_id); - inner.writer_idle_since_epoch_secs.remove(&writer_id); - let conns = inner - .conns_for_writer - .remove(&writer_id) - .unwrap_or_default() - .into_iter() - .collect::>(); - + let mut close_txs = Vec::>::new(); let mut out = Vec::new(); - for conn_id in conns { - if inner.writer_for_conn.get(&conn_id).copied() != Some(writer_id) { - continue; - } - inner.writer_for_conn.remove(&conn_id); - if let Some(m) = inner.meta.get(&conn_id) { - out.push(BoundConn { - conn_id, - meta: m.clone(), - }); + { + let mut inner = self.inner.write().await; + inner.writers.remove(&writer_id); + inner.last_meta_for_writer.remove(&writer_id); + inner.writer_idle_since_epoch_secs.remove(&writer_id); + let conns = inner + .conns_for_writer + .remove(&writer_id) + .unwrap_or_default() + .into_iter() + .collect::>(); + + for conn_id in conns { + if inner.writer_for_conn.get(&conn_id).copied() != Some(writer_id) { + continue; + } + inner.writer_for_conn.remove(&conn_id); + if let Some(client_tx) = inner.map.remove(&conn_id) { + close_txs.push(client_tx); + } + if let Some(meta) = inner.meta.remove(&conn_id) { + out.push(BoundConn { conn_id, meta }); + } } } + + for client_tx in close_txs { + let _ = client_tx.try_send(MeResponse::Close); + } + out } @@ -491,6 +500,7 @@ impl ConnRegistry { #[cfg(test)] mod tests { use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + use std::time::Duration; use super::ConnMeta; use super::ConnRegistry; @@ -663,6 +673,39 @@ mod tests { assert!(registry.is_writer_empty(20).await); } + #[tokio::test] + async fn writer_lost_removes_bound_conn_from_registry_and_signals_close() { + let registry = ConnRegistry::new(); + let (conn_id, mut rx) = registry.register().await; + let (writer_tx, _writer_rx) = tokio::sync::mpsc::channel(8); + registry.register_writer(10, writer_tx).await; + let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443); + + assert!( + registry + .bind_writer( + conn_id, + 10, + ConnMeta { + target_dc: 2, + client_addr: addr, + our_addr: addr, + proto_flags: 0, + }, + ) + .await + ); + + let lost = registry.writer_lost(10).await; + assert_eq!(lost.len(), 1); + assert_eq!(lost[0].conn_id, conn_id); + assert!(registry.get_writer(conn_id).await.is_none()); + assert!(registry.get_meta(conn_id).await.is_none()); + assert_eq!(registry.unregister(conn_id).await, None); + let close = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await; + assert!(matches!(close, Ok(Some(MeResponse::Close)))); + } + #[tokio::test] async fn bind_writer_rejects_unregistered_writer() { let registry = ConnRegistry::new(); diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index 1c255ef..6791064 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -643,13 +643,19 @@ impl MePool { let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes()); p.extend_from_slice(&conn_id.to_le_bytes()); - if w.tx - .send(WriterCommand::DataAndFlush(Bytes::from(p))) - .await - .is_err() - { - debug!("ME close write failed"); - self.remove_writer_and_close_clients(w.writer_id).await; + match w.tx.try_send(WriterCommand::DataAndFlush(Bytes::from(p))) { + Ok(()) => {} + Err(TrySendError::Full(_)) => { + debug!( + conn_id, + writer_id = w.writer_id, + "ME close skipped: writer command channel is full" + ); + } + Err(TrySendError::Closed(_)) => { + debug!("ME close write failed"); + self.remove_writer_and_close_clients(w.writer_id).await; + } } } else { debug!(conn_id, "ME close skipped (writer missing)"); @@ -666,8 +672,12 @@ impl MePool { p.extend_from_slice(&conn_id.to_le_bytes()); match w.tx.try_send(WriterCommand::DataAndFlush(Bytes::from(p))) { Ok(()) => {} - Err(TrySendError::Full(cmd)) => { - let _ = tokio::time::timeout(Duration::from_millis(50), w.tx.send(cmd)).await; + Err(TrySendError::Full(_)) => { + debug!( + conn_id, + writer_id = w.writer_id, + "ME close_conn skipped: writer command channel is full" + ); } Err(TrySendError::Closed(_)) => { debug!(conn_id, "ME close_conn skipped: writer channel closed");