diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs
index 6f19789..506c354 100644
--- a/src/transport/middle_proxy/pool_writer.rs
+++ b/src/transport/middle_proxy/pool_writer.rs
@@ -38,6 +38,233 @@ fn is_me_peer_closed_error(error: &ProxyError) -> bool {
     matches!(error, ProxyError::Io(ioe) if ioe.kind() == ErrorKind::UnexpectedEof)
 }
 
+enum WriterLifecycleExit {
+    Reader(Result<()>),
+    Writer(Result<()>),
+    Ping,
+    Signal,
+    Cancelled,
+}
+
+async fn writer_command_loop(
+    mut rx: mpsc::Receiver<WriterCommand>,
+    mut rpc_writer: RpcWriter,
+    cancel: CancellationToken,
+) -> Result<()> {
+    loop {
+        tokio::select! {
+            cmd = rx.recv() => {
+                match cmd {
+                    Some(WriterCommand::Data(payload)) => {
+                        rpc_writer.send(&payload).await?;
+                    }
+                    Some(WriterCommand::DataAndFlush(payload)) => {
+                        rpc_writer.send_and_flush(&payload).await?;
+                    }
+                    Some(WriterCommand::Close) | None => return Ok(()),
+                }
+            }
+            _ = cancel.cancelled() => return Ok(()),
+        }
+    }
+}
+
+#[allow(clippy::too_many_arguments)]
+async fn ping_loop(
+    pool_ping: std::sync::Weak<MePool>,
+    writer_id: u64,
+    tx_ping: mpsc::Sender<WriterCommand>,
+    ping_tracker_ping: Arc<tokio::sync::Mutex<HashMap<i64, std::time::Instant>>>,
+    stats_ping: Arc,
+    keepalive_enabled: bool,
+    keepalive_interval: Duration,
+    keepalive_jitter: Duration,
+    cancel_ping_token: CancellationToken,
+) {
+    let mut ping_id: i64 = rand::random::<i64>();
+    let mut cleanup_tick: u32 = 0;
+    let idle_interval_cap = Duration::from_secs(ME_IDLE_KEEPALIVE_MAX_SECS);
+    // Per-writer jittered start to avoid phase sync.
+    let startup_jitter = if keepalive_enabled {
+        let mut interval = keepalive_interval;
+        let Some(pool) = pool_ping.upgrade() else {
+            return;
+        };
+        if pool.registry.is_writer_empty(writer_id).await {
+            interval = interval.min(idle_interval_cap);
+        }
+        let jitter_cap_ms = interval.as_millis() / 2;
+        let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1);
+        Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64))
+    } else {
+        let jitter = rand::rng().random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS);
+        let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64;
+        Duration::from_secs(wait)
+    };
+    tokio::select! {
+        _ = cancel_ping_token.cancelled() => return,
+        _ = tokio::time::sleep(startup_jitter) => {}
+    }
+    loop {
+        let wait = if keepalive_enabled {
+            let mut interval = keepalive_interval;
+            let Some(pool) = pool_ping.upgrade() else {
+                return;
+            };
+            if pool.registry.is_writer_empty(writer_id).await {
+                interval = interval.min(idle_interval_cap);
+            }
+            let jitter_cap_ms = interval.as_millis() / 2;
+            let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1);
+            interval + Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64))
+        } else {
+            let jitter = rand::rng().random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS);
+            let secs = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64;
+            Duration::from_secs(secs)
+        };
+        tokio::select! {
+            _ = cancel_ping_token.cancelled() => return,
+            _ = tokio::time::sleep(wait) => {}
+        }
+        let sent_id = ping_id;
+        let mut p = Vec::with_capacity(12);
+        p.extend_from_slice(&RPC_PING_U32.to_le_bytes());
+        p.extend_from_slice(&sent_id.to_le_bytes());
+        {
+            let mut tracker = ping_tracker_ping.lock().await;
+            cleanup_tick = cleanup_tick.wrapping_add(1);
+            if cleanup_tick.is_multiple_of(ME_PING_TRACKER_CLEANUP_EVERY) {
+                let before = tracker.len();
+                tracker.retain(|_, ts| ts.elapsed() < Duration::from_secs(120));
+                let expired = before.saturating_sub(tracker.len());
+                if expired > 0 {
+                    stats_ping.increment_me_keepalive_timeout_by(expired as u64);
+                }
+            }
+            tracker.insert(sent_id, std::time::Instant::now());
+        }
+        ping_id = ping_id.wrapping_add(1);
+        stats_ping.increment_me_keepalive_sent();
+        if tx_ping
+            .send(WriterCommand::DataAndFlush(Bytes::from(p)))
+            .await
+            .is_err()
+        {
+            stats_ping.increment_me_keepalive_failed();
+            debug!("ME ping failed, removing dead writer");
+            return;
+        }
+    }
+}
+
+#[allow(clippy::too_many_arguments)]
+async fn rpc_proxy_req_signal_loop(
+    pool_signal: std::sync::Weak<MePool>,
+    writer_id: u64,
+    tx_signal: mpsc::Sender<WriterCommand>,
+    stats_signal: Arc,
+    cancel_signal: CancellationToken,
+    keepalive_jitter_signal: Duration,
+    rpc_proxy_req_every_secs: u64,
+) {
+    if rpc_proxy_req_every_secs == 0 {
+        return;
+    }
+
+    let interval = Duration::from_secs(rpc_proxy_req_every_secs);
+    let startup_jitter_ms = {
+        let jitter_cap_ms = interval.as_millis() / 2;
+        let effective_jitter_ms = keepalive_jitter_signal
+            .as_millis()
+            .min(jitter_cap_ms)
+            .max(1);
+        rand::rng().random_range(0..=effective_jitter_ms as u64)
+    };
+
+    tokio::select! {
+        _ = cancel_signal.cancelled() => return,
+        _ = tokio::time::sleep(Duration::from_millis(startup_jitter_ms)) => {}
+    }
+
+    loop {
+        let wait = {
+            let jitter_cap_ms = interval.as_millis() / 2;
+            let effective_jitter_ms = keepalive_jitter_signal
+                .as_millis()
+                .min(jitter_cap_ms)
+                .max(1);
+            interval + Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64))
+        };
+
+        tokio::select! {
+            _ = cancel_signal.cancelled() => return,
+            _ = tokio::time::sleep(wait) => {}
+        }
+
+        let Some(pool) = pool_signal.upgrade() else {
+            return;
+        };
+
+        let Some(meta) = pool.registry.get_last_writer_meta(writer_id).await else {
+            stats_signal.increment_me_rpc_proxy_req_signal_skipped_no_meta_total();
+            continue;
+        };
+
+        let (conn_id, mut service_rx) = pool.registry.register().await;
+        // Service RPC_PROXY_REQ signal path is intentionally route-only:
+        // do not bind synthetic conn_id into regular writer/client accounting.
+
+        let payload = build_proxy_req_payload(
+            conn_id,
+            meta.client_addr,
+            meta.our_addr,
+            &[],
+            pool.proxy_tag.as_deref(),
+            meta.proto_flags,
+        );
+
+        if tx_signal
+            .send(WriterCommand::DataAndFlush(payload))
+            .await
+            .is_err()
+        {
+            stats_signal.increment_me_rpc_proxy_req_signal_failed_total();
+            let _ = pool.registry.unregister(conn_id).await;
+            return;
+        }
+
+        stats_signal.increment_me_rpc_proxy_req_signal_sent_total();
+
+        if matches!(
+            tokio::time::timeout(
+                Duration::from_millis(ME_RPC_PROXY_REQ_RESPONSE_WAIT_MS),
+                service_rx.recv(),
+            )
+            .await,
+            Ok(Some(_))
+        ) {
+            stats_signal.increment_me_rpc_proxy_req_signal_response_total();
+        }
+
+        let mut close_payload = Vec::with_capacity(12);
+        close_payload.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes());
+        close_payload.extend_from_slice(&conn_id.to_le_bytes());
+
+        if tx_signal
+            .send(WriterCommand::DataAndFlush(Bytes::from(close_payload)))
+            .await
+            .is_err()
+        {
+            stats_signal.increment_me_rpc_proxy_req_signal_failed_total();
+            let _ = pool.registry.unregister(conn_id).await;
+            return;
+        }
+
+        stats_signal.increment_me_rpc_proxy_req_signal_close_sent_total();
+        let _ = pool.registry.unregister(conn_id).await;
+    }
+}
+
 impl MePool {
     pub(crate) async fn prune_closed_writers(self: &Arc<Self>) {
         let closed_writer_ids: Vec<u64> = {
@@ -138,46 +365,14 @@ impl MePool {
         let draining_started_at_epoch_secs = Arc::new(AtomicU64::new(0));
         let drain_deadline_epoch_secs = Arc::new(AtomicU64::new(0));
         let allow_drain_fallback = Arc::new(AtomicBool::new(false));
-        let (tx, mut rx) = mpsc::channel::<WriterCommand>(self.writer_cmd_channel_capacity);
-        let mut rpc_writer = RpcWriter {
+        let (tx, rx) = mpsc::channel::<WriterCommand>(self.writer_cmd_channel_capacity);
+        let rpc_writer = RpcWriter {
             writer: hs.wr,
             key: hs.write_key,
             iv: hs.write_iv,
             seq_no: 0,
             crc_mode: hs.crc_mode,
         };
-        let cancel_wr = cancel.clone();
-        let cleanup_done = Arc::new(AtomicBool::new(false));
-        let cleanup_for_writer = cleanup_done.clone();
-        let pool_writer_task = Arc::downgrade(self);
-        tokio::spawn(async move {
-            loop {
-                tokio::select! {
-                    cmd = rx.recv() => {
-                        match cmd {
-                            Some(WriterCommand::Data(payload)) => {
-                                if rpc_writer.send(&payload).await.is_err() { break; }
-                            }
-                            Some(WriterCommand::DataAndFlush(payload)) => {
-                                if rpc_writer.send_and_flush(&payload).await.is_err() { break; }
-                            }
-                            Some(WriterCommand::Close) | None => break,
-                        }
-                    }
-                    _ = cancel_wr.cancelled() => break,
-                }
-            }
-            if cleanup_for_writer
-                .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
-                .is_ok()
-            {
-                if let Some(pool) = pool_writer_task.upgrade() {
-                    pool.remove_writer_and_close_clients(writer_id).await;
-                } else {
-                    cancel_wr.cancel();
-                }
-            }
-        });
         let writer = MeWriter {
             id: writer_id,
             addr,
@@ -207,290 +402,120 @@ impl MePool {
         let writers_arc = self.writers_arc();
         let ping_tracker = Arc::new(tokio::sync::Mutex::new(HashMap::<i64, std::time::Instant>::new()));
         let ping_tracker_reader = ping_tracker.clone();
+        let ping_tracker_ping = ping_tracker.clone();
         let rtt_stats = self.rtt_stats.clone();
         let stats_reader = self.stats.clone();
         let stats_reader_close = self.stats.clone();
         let stats_ping = self.stats.clone();
-        let pool = Arc::downgrade(self);
-        let cancel_ping = cancel.clone();
+        let stats_signal = self.stats.clone();
+        let pool_lifecycle = Arc::downgrade(self);
+        let pool_ping = Arc::downgrade(self);
+        let pool_signal = Arc::downgrade(self);
+        let tx_reader = tx.clone();
         let tx_ping = tx.clone();
-        let ping_tracker_ping = ping_tracker.clone();
-        let cleanup_for_reader = cleanup_done.clone();
-        let cleanup_for_ping = cleanup_done.clone();
+        let tx_signal = tx.clone();
         let keepalive_enabled = self.me_keepalive_enabled;
         let keepalive_interval = self.me_keepalive_interval;
         let keepalive_jitter = self.me_keepalive_jitter;
-        let rpc_proxy_req_every_secs = self.rpc_proxy_req_every_secs.load(Ordering::Relaxed);
-        let tx_signal = tx.clone();
-        let stats_signal = self.stats.clone();
-        let cancel_signal = cancel.clone();
-        let cleanup_for_signal = cleanup_done.clone();
-        let pool_signal = Arc::downgrade(self);
         let keepalive_jitter_signal = self.me_keepalive_jitter;
-        let cancel_reader_token = cancel.clone();
-        let cancel_ping_token = cancel_ping.clone();
+        let rpc_proxy_req_every_secs = self.rpc_proxy_req_every_secs.load(Ordering::Relaxed);
+        let cancel_reader = cancel.clone();
+        let cancel_writer = cancel.clone();
+        let cancel_ping = cancel.clone();
+        let cancel_signal = cancel.clone();
+        let cancel_select = cancel.clone();
+        let cancel_cleanup = cancel.clone();
         let reader_route_data_wait_ms = self.me_reader_route_data_wait_ms.clone();
         tokio::spawn(async move {
-            let res = reader_loop(
-                hs.rd,
-                hs.read_key,
-                hs.read_iv,
-                hs.crc_mode,
-                reg.clone(),
-                BytesMut::new(),
-                BytesMut::new(),
-                tx.clone(),
-                ping_tracker_reader,
-                rtt_stats.clone(),
-                stats_reader,
-                writer_id,
-                degraded.clone(),
-                rtt_ema_ms_x10.clone(),
-                reader_route_data_wait_ms,
-                cancel_reader_token.clone(),
-            )
-            .await;
-            let idle_close_by_peer = if let Err(e) = res.as_ref() {
-                is_me_peer_closed_error(e) && reg.is_writer_empty(writer_id).await
-            } else {
-                false
-            };
-            if idle_close_by_peer {
-                stats_reader_close.increment_me_idle_close_by_peer_total();
-                info!(writer_id, "ME socket closed by peer on idle writer");
-            }
-            if cleanup_for_reader
-                .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
-                .is_ok()
-            {
-                if let Some(pool) = pool.upgrade() {
-                    pool.remove_writer_and_close_clients(writer_id).await;
-                } else {
-                    // Fallback for shutdown races: make writer task exit quickly so stale
-                    // channels are observable by periodic prune.
-                    cancel_reader_token.cancel();
-                }
-            }
-            if let Err(e) = res
-                && !idle_close_by_peer
-            {
-                warn!(error = %e, "ME reader ended");
-            }
-            let remaining = writers_arc.read().await.len();
-            debug!(writer_id, remaining, "ME reader task finished");
-        });
+            // Reader MUST be the first branch in biased select! to avoid read starvation.
+            let exit = tokio::select! {
+                biased;
-        let pool_ping = Arc::downgrade(self);
-        tokio::spawn(async move {
-            let mut ping_id: i64 = rand::random::<i64>();
-            let mut cleanup_tick: u32 = 0;
-            let idle_interval_cap = Duration::from_secs(ME_IDLE_KEEPALIVE_MAX_SECS);
-            // Per-writer jittered start to avoid phase sync.
-            let startup_jitter = if keepalive_enabled {
-                let mut interval = keepalive_interval;
-                if let Some(pool) = pool_ping.upgrade() {
-                    if pool.registry.is_writer_empty(writer_id).await {
-                        interval = interval.min(idle_interval_cap);
-                    }
-                } else {
-                    return;
+                reader_res = reader_loop(
+                    hs.rd,
+                    hs.read_key,
+                    hs.read_iv,
+                    hs.crc_mode,
+                    reg.clone(),
+                    BytesMut::new(),
+                    BytesMut::new(),
+                    tx_reader,
+                    ping_tracker_reader,
+                    rtt_stats,
+                    stats_reader,
+                    writer_id,
+                    degraded,
+                    rtt_ema_ms_x10,
+                    reader_route_data_wait_ms,
+                    cancel_reader,
+                ) => WriterLifecycleExit::Reader(reader_res),
+                writer_res = writer_command_loop(rx, rpc_writer, cancel_writer) => {
+                    WriterLifecycleExit::Writer(writer_res)
                 }
-                let jitter_cap_ms = interval.as_millis() / 2;
-                let effective_jitter_ms = keepalive_jitter.as_millis().min(jitter_cap_ms).max(1);
-                Duration::from_millis(rand::rng().random_range(0..=effective_jitter_ms as u64))
-            } else {
-                let jitter = rand::rng()
-                    .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS);
-                let wait = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64;
-                Duration::from_secs(wait)
+                _ = ping_loop(
+                    pool_ping,
+                    writer_id,
+                    tx_ping,
+                    ping_tracker_ping,
+                    stats_ping,
+                    keepalive_enabled,
+                    keepalive_interval,
+                    keepalive_jitter,
+                    cancel_ping,
+                ) => WriterLifecycleExit::Ping,
+                _ = rpc_proxy_req_signal_loop(
+                    pool_signal,
+                    writer_id,
+                    tx_signal,
+                    stats_signal,
+                    cancel_signal,
+                    keepalive_jitter_signal,
+                    rpc_proxy_req_every_secs,
+                ) => WriterLifecycleExit::Signal,
+                _ = cancel_select.cancelled() => WriterLifecycleExit::Cancelled,
             };
-            tokio::select! {
-                _ = cancel_ping_token.cancelled() => return,
-                _ = tokio::time::sleep(startup_jitter) => {}
-            }
-            loop {
-                let wait = if keepalive_enabled {
-                    let mut interval = keepalive_interval;
-                    if let Some(pool) = pool_ping.upgrade() {
-                        if pool.registry.is_writer_empty(writer_id).await {
-                            interval = interval.min(idle_interval_cap);
-                        }
+
+            match exit {
+                WriterLifecycleExit::Reader(res) => {
+                    let idle_close_by_peer = if let Err(e) = res.as_ref() {
+                        is_me_peer_closed_error(e) && reg.is_writer_empty(writer_id).await
                     } else {
-                        break;
+                        false
+                    };
+                    if idle_close_by_peer {
+                        stats_reader_close.increment_me_idle_close_by_peer_total();
+                        info!(writer_id, "ME socket closed by peer on idle writer");
                     }
-                    let jitter_cap_ms = interval.as_millis() / 2;
-                    let effective_jitter_ms =
-                        keepalive_jitter.as_millis().min(jitter_cap_ms).max(1);
-                    interval
-                        + Duration::from_millis(
-                            rand::rng().random_range(0..=effective_jitter_ms as u64),
-                        )
-                } else {
-                    let jitter = rand::rng()
-                        .random_range(-ME_ACTIVE_PING_JITTER_SECS..=ME_ACTIVE_PING_JITTER_SECS);
-                    let secs = (ME_ACTIVE_PING_SECS as i64 + jitter).max(5) as u64;
-                    Duration::from_secs(secs)
-                };
-                tokio::select! {
-                    _ = cancel_ping_token.cancelled() => {
-                        break;
-                    }
-                    _ = tokio::time::sleep(wait) => {}
-                }
-                let sent_id = ping_id;
-                let mut p = Vec::with_capacity(12);
-                p.extend_from_slice(&RPC_PING_U32.to_le_bytes());
-                p.extend_from_slice(&sent_id.to_le_bytes());
-                {
-                    let mut tracker = ping_tracker_ping.lock().await;
-                    cleanup_tick = cleanup_tick.wrapping_add(1);
-                    if cleanup_tick.is_multiple_of(ME_PING_TRACKER_CLEANUP_EVERY) {
-                        let before = tracker.len();
-                        tracker.retain(|_, ts| ts.elapsed() < Duration::from_secs(120));
-                        let expired = before.saturating_sub(tracker.len());
-                        if expired > 0 {
-                            stats_ping.increment_me_keepalive_timeout_by(expired as u64);
-                        }
-                    }
-                    tracker.insert(sent_id, std::time::Instant::now());
-                }
-                ping_id = ping_id.wrapping_add(1);
-                stats_ping.increment_me_keepalive_sent();
-                if tx_ping
-                    .send(WriterCommand::DataAndFlush(Bytes::from(p)))
-                    .await
-                    .is_err()
-                {
-                    stats_ping.increment_me_keepalive_failed();
-                    debug!("ME ping failed, removing dead writer");
-                    cancel_ping.cancel();
-                    if cleanup_for_ping
-                        .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
-                        .is_ok()
-                        && let Some(pool) = pool_ping.upgrade()
+                    if let Err(e) = res
+                        && !idle_close_by_peer
                     {
-                        pool.remove_writer_and_close_clients(writer_id).await;
+                        warn!(error = %e, "ME reader ended");
                    }
-                    break;
                 }
-            }
-        });
-
-        tokio::spawn(async move {
-            if rpc_proxy_req_every_secs == 0 {
-                return;
-            }
-
-            let interval = Duration::from_secs(rpc_proxy_req_every_secs);
-            let startup_jitter_ms = {
-                let jitter_cap_ms = interval.as_millis() / 2;
-                let effective_jitter_ms = keepalive_jitter_signal
-                    .as_millis()
-                    .min(jitter_cap_ms)
-                    .max(1);
-                rand::rng().random_range(0..=effective_jitter_ms as u64)
-            };
-
-            tokio::select! {
-                _ = cancel_signal.cancelled() => return,
-                _ = tokio::time::sleep(Duration::from_millis(startup_jitter_ms)) => {}
-            }
-
-            loop {
-                let wait = {
-                    let jitter_cap_ms = interval.as_millis() / 2;
-                    let effective_jitter_ms = keepalive_jitter_signal
-                        .as_millis()
-                        .min(jitter_cap_ms)
-                        .max(1);
-                    interval
-                        + Duration::from_millis(
-                            rand::rng().random_range(0..=effective_jitter_ms as u64),
-                        )
-                };
-
-                tokio::select! {
-                    _ = cancel_signal.cancelled() => break,
-                    _ = tokio::time::sleep(wait) => {}
-                }
-
-                let Some(pool) = pool_signal.upgrade() else {
-                    break;
-                };
-
-                let Some(meta) = pool.registry.get_last_writer_meta(writer_id).await else {
-                    stats_signal.increment_me_rpc_proxy_req_signal_skipped_no_meta_total();
-                    continue;
-                };
-
-                let (conn_id, mut service_rx) = pool.registry.register().await;
-                // Service RPC_PROXY_REQ signal path is intentionally route-only:
-                // do not bind synthetic conn_id into regular writer/client accounting.
-
-                let payload = build_proxy_req_payload(
-                    conn_id,
-                    meta.client_addr,
-                    meta.our_addr,
-                    &[],
-                    pool.proxy_tag.as_deref(),
-                    meta.proto_flags,
-                );
-
-                if tx_signal
-                    .send(WriterCommand::DataAndFlush(payload))
-                    .await
-                    .is_err()
-                {
-                    stats_signal.increment_me_rpc_proxy_req_signal_failed_total();
-                    let _ = pool.registry.unregister(conn_id).await;
-                    cancel_signal.cancel();
-                    if cleanup_for_signal
-                        .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
-                        .is_ok()
-                    {
-                        pool.remove_writer_and_close_clients(writer_id).await;
+                WriterLifecycleExit::Writer(res) => {
+                    if let Err(e) = res {
+                        warn!(error = %e, "ME writer command loop ended");
                    }
-                    break;
                 }
-
-                stats_signal.increment_me_rpc_proxy_req_signal_sent_total();
-
-                if matches!(
-                    tokio::time::timeout(
-                        Duration::from_millis(ME_RPC_PROXY_REQ_RESPONSE_WAIT_MS),
-                        service_rx.recv(),
-                    )
-                    .await,
-                    Ok(Some(_))
-                ) {
-                    stats_signal.increment_me_rpc_proxy_req_signal_response_total();
+                WriterLifecycleExit::Ping => {
+                    debug!(writer_id, "ME ping loop finished");
                }
-
-                let mut close_payload = Vec::with_capacity(12);
-                close_payload.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes());
-                close_payload.extend_from_slice(&conn_id.to_le_bytes());
-
-                if tx_signal
-                    .send(WriterCommand::DataAndFlush(Bytes::from(close_payload)))
-                    .await
-                    .is_err()
-                {
-                    stats_signal.increment_me_rpc_proxy_req_signal_failed_total();
-                    let _ = pool.registry.unregister(conn_id).await;
-                    cancel_signal.cancel();
-                    if cleanup_for_signal
-                        .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
-                        .is_ok()
-                    {
-                        pool.remove_writer_and_close_clients(writer_id).await;
-                    }
-                    break;
+                WriterLifecycleExit::Signal => {
+                    debug!(writer_id, "ME rpc_proxy_req signal loop finished");
                 }
-
-                stats_signal.increment_me_rpc_proxy_req_signal_close_sent_total();
-                let _ = pool.registry.unregister(conn_id).await;
+                WriterLifecycleExit::Cancelled => {}
             }
+
+            if let Some(pool) = pool_lifecycle.upgrade() {
+                pool.remove_writer_and_close_clients(writer_id).await;
+            } else {
+                // Fallback for shutdown races: make lifecycle exit observable by prune.
+                cancel_cleanup.cancel();
+            }
+
+            let remaining = writers_arc.read().await.len();
+            debug!(writer_id, remaining, "ME writer lifecycle task finished");
         });
 
         Ok(())