diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 851c0b7..8faeabf 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -415,7 +415,6 @@ impl MePool { let degraded = Arc::new(AtomicBool::new(false)); let draining = Arc::new(AtomicBool::new(false)); let (tx, mut rx) = mpsc::channel::(4096); - let tx_for_keepalive = tx.clone(); let mut rpc_writer = RpcWriter { writer: hs.wr, key: hs.write_key, @@ -461,7 +460,6 @@ impl MePool { let rtt_stats = self.rtt_stats.clone(); let stats_reader = self.stats.clone(); let stats_ping = self.stats.clone(); - let stats_keepalive = self.stats.clone(); let pool = Arc::downgrade(self); let cancel_ping = cancel.clone(); let tx_ping = tx.clone(); @@ -474,7 +472,6 @@ impl MePool { let keepalive_jitter = self.me_keepalive_jitter; let cancel_reader_token = cancel.clone(); let cancel_ping_token = cancel_ping.clone(); - let cancel_keepalive_token = cancel.clone(); tokio::spawn(async move { let res = reader_loop( @@ -513,15 +510,40 @@ impl MePool { let pool_ping = Arc::downgrade(self); tokio::spawn(async move { let mut ping_id: i64 = rand::random::(); - loop { + // Per-writer jittered start to avoid phase sync. + let startup_jitter = if keepalive_enabled { + let jitter_cap_ms = keepalive_interval.as_millis() / 2; + let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); + Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64)) + } else { let jitter = rand::rng() .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; + Duration::from_secs(wait) + }; + tokio::select! { + _ = cancel_ping_token.cancelled() => return, + _ = tokio::time::sleep(startup_jitter) => {} + } + loop { + let wait = if keepalive_enabled { + let jitter_cap_ms = keepalive_interval.as_millis() / 2; + let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); + keepalive_interval + + Duration::from_millis( + rand::rng().random_range(0..=effective_jitter_ms as u64) + ) + } else { + let jitter = rand::rng() + .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS); + let secs = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64; + Duration::from_secs(secs) + }; tokio::select! { _ = cancel_ping_token.cancelled() => { break; } - _ = tokio::time::sleep(Duration::from_secs(wait)) => {} + _ = tokio::time::sleep(wait) => {} } let sent_id = ping_id; let mut p = Vec::with_capacity(12); @@ -538,8 +560,10 @@ impl MePool { tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); } ping_id = ping_id.wrapping_add(1); + stats_ping.increment_me_keepalive_sent(); if tx_ping.send(WriterCommand::DataAndFlush(p)).await.is_err() { - debug!("Active ME ping failed, removing dead writer"); + stats_ping.increment_me_keepalive_failed(); + debug!("ME ping failed, removing dead writer"); cancel_ping.cancel(); if let Some(pool) = pool_ping.upgrade() { if cleanup_for_ping @@ -554,46 +578,6 @@ impl MePool { } }); - if keepalive_enabled { - let tx_keepalive = tx_for_keepalive; - let cancel_keepalive = cancel_keepalive_token; - let ping_tracker_keepalive = ping_tracker.clone(); - tokio::spawn(async move { - // Per-writer jittered start to avoid phase sync. - let jitter_cap_ms = keepalive_interval.as_millis() / 2; - let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1); - let initial_jitter_ms = rand::rng().random_range(0..=effective_jitter_ms as u64); - tokio::time::sleep(Duration::from_millis(initial_jitter_ms)).await; - let mut ping_id: i64 = rand::random::(); - loop { - tokio::select! { - _ = cancel_keepalive.cancelled() => break, - _ = tokio::time::sleep(keepalive_interval + Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64))) => {} - } - let sent_id = ping_id; - ping_id = ping_id.wrapping_add(1); - let mut p = Vec::with_capacity(12); - p.extend_from_slice(&RPC_PING_U32.to_le_bytes()); - p.extend_from_slice(&sent_id.to_le_bytes()); - { - let mut tracker = ping_tracker_keepalive.lock().await; - let before = tracker.len(); - tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120)); - let expired = before.saturating_sub(tracker.len()); - if expired > 0 { - stats_keepalive.increment_me_keepalive_timeout_by(expired as u64); - } - tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); - } - stats_keepalive.increment_me_keepalive_sent(); - if tx_keepalive.send(WriterCommand::DataAndFlush(p)).await.is_err() { - stats_keepalive.increment_me_keepalive_failed(); - break; - } - } - }); - } - Ok(()) } @@ -630,15 +614,19 @@ impl MePool { } async fn remove_writer_only(&self, writer_id: u64) -> Vec { + let mut close_tx: Option> = None; { let mut ws = self.writers.write().await; if let Some(pos) = ws.iter().position(|w| w.id == writer_id) { let w = ws.remove(pos); w.cancel.cancel(); - let _ = w.tx.send(WriterCommand::Close).await; + close_tx = Some(w.tx.clone()); self.conn_count.fetch_sub(1, Ordering::Relaxed); } } + if let Some(tx) = close_tx { + let _ = tx.send(WriterCommand::Close).await; + } self.rtt_stats.lock().await.remove(&writer_id); self.registry.writer_lost(writer_id).await }