Decomposing hot-path modules into focused submodules

Signed-off-by: Alexey <247128645+axkurcom@users.noreply.github.com>
This commit is contained in:
Alexey
2026-05-21 13:28:40 +03:00
parent c02c7fbe43
commit 98c985091c
46 changed files with 9297 additions and 8488 deletions

View File

@@ -4,6 +4,9 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use crate::crypto::{AesCbc, crc32, crc32c};
use crate::error::{ProxyError, Result};
use crate::protocol::constants::*;
use crate::stream::PooledBuffer;
use super::wire::{append_proxy_req_payload_into, proxy_req_payload_len};
const RPC_WRITER_FRAME_BUF_SHRINK_THRESHOLD: usize = 256 * 1024;
const RPC_WRITER_FRAME_BUF_RETAIN: usize = 64 * 1024;
@@ -12,10 +15,21 @@ const RPC_WRITER_FRAME_BUF_RETAIN: usize = 64 * 1024;
/// Commands sent to a writer over its `mpsc::Sender<WriterCommand>` channel.
pub(crate) enum WriterCommand {
/// Raw payload bytes for one frame.
/// NOTE(review): the handling arm is outside this view; presumably sent without an explicit flush — confirm.
Data(Bytes),
/// Raw payload bytes that the writer loop passes to `send_and_flush`.
DataAndFlush(Bytes),
/// Structured proxy request that the writer loop passes to `send_proxy_req`,
/// which encodes it directly into the writer's frame buffer.
ProxyReq(ProxyReqCommand),
/// Fixed 12-byte control payload that the writer loop passes to `send_and_flush`.
ControlAndFlush([u8; 12]),
/// NOTE(review): presumably asks the writer loop to stop; the handling arm is not visible here — confirm.
Close,
}
/// Structured proxy request command that lets the writer encode directly into its frame buffer.
///
/// All fields are forwarded as-is to `append_proxy_req_payload_into` by
/// `RpcWriter::send_proxy_req`; the exact wire encoding lives in that helper.
pub(crate) struct ProxyReqCommand {
/// Connection identifier written into the request.
pub(crate) conn_id: u64,
/// Socket address forwarded as the `client_addr` argument of the encoder.
pub(crate) client_addr: std::net::SocketAddr,
/// Socket address forwarded as the `our_addr` argument of the encoder.
pub(crate) our_addr: std::net::SocketAddr,
/// Flag word; it also takes part in the payload length computation
/// (`proxy_req_payload_len`).
pub(crate) proto_flags: u32,
/// Optional 16-byte tag, handed to the length and encode helpers as a byte slice.
pub(crate) proxy_tag: Option<[u8; 16]>,
/// Request body bytes, held in a pooled buffer.
pub(crate) payload: PooledBuffer,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RpcChecksumMode {
Crc32,
@@ -249,7 +263,37 @@ impl RpcWriter {
/// Builds an RPC frame for `payload` in the reusable frame buffer using the
/// current sequence number and checksum mode, then encrypts and writes it.
///
/// The sequence number is advanced (wrapping) before the write is attempted,
/// so it moves forward even when the write returns an error.
pub(crate) async fn send(&mut self, payload: &[u8]) -> Result<()> {
    let seq = self.seq_no;
    build_rpc_frame_into(&mut self.frame_buf, seq, payload, self.crc_mode);
    self.seq_no = seq.wrapping_add(1);
    self.encrypt_and_write_frame().await
}
/// Encodes a proxy request straight into the reusable frame buffer and
/// writes it out as one encrypted frame.
///
/// The frame is laid out as: little-endian `u32` total length, the
/// little-endian sequence number, the proxy-request payload produced by
/// `append_proxy_req_payload_into`, and finally the checksum (per
/// `self.crc_mode`) of everything written before it.
///
/// The sequence number is advanced (wrapping) before the write is attempted.
pub(crate) async fn send_proxy_req(&mut self, command: &ProxyReqCommand) -> Result<()> {
    // Borrow the optional tag once; both the length and the encode helper need it.
    let tag: Option<&[u8]> = command.proxy_tag.as_ref().map(|bytes| bytes.as_slice());
    let body_len = proxy_req_payload_len(command.payload.len(), tag, command.proto_flags);
    // Length word + sequence number + body + checksum.
    let frame_len = 4 + 4 + body_len + 4;
    let seq_bytes = self.seq_no.to_le_bytes();

    self.frame_buf.clear();
    // The extra 15 bytes leave room for the padding that
    // `encrypt_and_write_frame` appends to reach a 16-byte multiple.
    self.frame_buf.reserve(frame_len + 15);
    self.frame_buf
        .extend_from_slice(&(frame_len as u32).to_le_bytes());
    self.frame_buf.extend_from_slice(&seq_bytes);
    append_proxy_req_payload_into(
        &mut self.frame_buf,
        command.conn_id,
        command.client_addr,
        command.our_addr,
        command.payload.as_ref(),
        tag,
        command.proto_flags,
    );

    let checksum = rpc_crc(self.crc_mode, &self.frame_buf);
    self.frame_buf.extend_from_slice(&checksum.to_le_bytes());

    self.seq_no = self.seq_no.wrapping_add(1);
    self.encrypt_and_write_frame().await
}
async fn encrypt_and_write_frame(&mut self) -> Result<()> {
let pad = (16 - (self.frame_buf.len() % 16)) % 16;
let pad_pattern: [u8; 4] = [0x04, 0x00, 0x00, 0x00];
for i in 0..pad {

View File

@@ -60,6 +60,9 @@ async fn writer_command_loop(
Some(WriterCommand::DataAndFlush(payload)) => {
rpc_writer.send_and_flush(&payload).await?;
}
Some(WriterCommand::ProxyReq(command)) => {
rpc_writer.send_proxy_req(&command).await?;
}
Some(WriterCommand::ControlAndFlush(payload)) => {
rpc_writer.send_and_flush(&payload).await?;
}

View File

@@ -2,14 +2,13 @@ use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use std::time::{SystemTime, UNIX_EPOCH};
use dashmap::DashMap;
use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::{Mutex, Semaphore, mpsc};
use super::MeResponse;
use super::codec::WriterCommand;
use super::{MeResponse, RouteBytePermit};
const ROUTE_BACKPRESSURE_BASE_TIMEOUT_MS: u64 = 25;
const ROUTE_BACKPRESSURE_HIGH_TIMEOUT_MS: u64 = 120;
@@ -18,6 +17,8 @@ const ROUTE_QUEUED_BYTE_PERMIT_UNIT: usize = 16 * 1024;
const ROUTE_QUEUED_PERMITS_PER_SLOT: usize = 4;
const ROUTE_QUEUED_MAX_FRAME_PERMITS: usize = 1024;
mod writer;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouteResult {
Routed,
@@ -218,760 +219,7 @@ impl ConnRegistry {
);
(id, rx)
}
pub async fn register_writer(&self, writer_id: u64, tx: mpsc::Sender<WriterCommand>) {
let mut binding = self.binding.inner.lock().await;
binding.writers.insert(writer_id, tx.clone());
binding
.conns_for_writer
.entry(writer_id)
.or_insert_with(HashSet::new);
self.writers.map.insert(writer_id, tx);
}
/// Unregister connection, returning associated writer_id if any.
pub async fn unregister(&self, id: u64) -> Option<u64> {
self.routing.map.remove(&id);
self.routing.byte_budget.remove(&id);
self.hot_binding.map.remove(&id);
let mut binding = self.binding.inner.lock().await;
binding.meta.remove(&id);
if let Some(writer_id) = binding.writer_for_conn.remove(&id) {
let became_empty = if let Some(set) = binding.conns_for_writer.get_mut(&writer_id) {
set.remove(&id);
set.is_empty()
} else {
false
};
if became_empty {
binding
.writer_idle_since_epoch_secs
.insert(writer_id, Self::now_epoch_secs());
}
return Some(writer_id);
}
None
}
async fn attach_route_byte_permit(
&self,
id: u64,
resp: MeResponse,
timeout_ms: Option<u64>,
) -> std::result::Result<MeResponse, RouteResult> {
let MeResponse::Data {
flags,
data,
route_permit,
} = resp
else {
return Ok(resp);
};
if route_permit.is_some() {
return Ok(MeResponse::Data {
flags,
data,
route_permit,
});
}
let Some(semaphore) = self
.routing
.byte_budget
.get(&id)
.map(|entry| entry.value().clone())
else {
return Err(RouteResult::NoConn);
};
let permits = Self::route_data_permits(data.len());
let permit = match timeout_ms {
Some(0) => semaphore
.try_acquire_many_owned(permits)
.map_err(|_| RouteResult::QueueFullHigh)?,
Some(timeout_ms) => {
let acquire = semaphore.acquire_many_owned(permits);
match tokio::time::timeout(Duration::from_millis(timeout_ms.max(1)), acquire).await
{
Ok(Ok(permit)) => permit,
Ok(Err(_)) => return Err(RouteResult::ChannelClosed),
Err(_) => return Err(RouteResult::QueueFullHigh),
}
}
None => semaphore
.acquire_many_owned(permits)
.await
.map_err(|_| RouteResult::ChannelClosed)?,
};
Ok(MeResponse::Data {
flags,
data,
route_permit: Some(RouteBytePermit::new(permit)),
})
}
#[allow(dead_code)]
pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult {
let tx = self.routing.map.get(&id).map(|entry| entry.value().clone());
let Some(tx) = tx else {
return RouteResult::NoConn;
};
let base_timeout_ms = self
.route_backpressure_base_timeout_ms
.load(Ordering::Relaxed)
.max(1);
let resp = match self
.attach_route_byte_permit(id, resp, Some(base_timeout_ms))
.await
{
Ok(resp) => resp,
Err(result) => return result,
};
match tx.try_send(resp) {
Ok(()) => RouteResult::Routed,
Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed,
Err(TrySendError::Full(resp)) => {
// Absorb short bursts without dropping/closing the session immediately.
let high_timeout_ms = self
.route_backpressure_high_timeout_ms
.load(Ordering::Relaxed)
.max(base_timeout_ms);
let high_watermark_pct = self
.route_backpressure_high_watermark_pct
.load(Ordering::Relaxed)
.clamp(1, 100);
let used = self.route_channel_capacity.saturating_sub(tx.capacity());
let used_pct = if self.route_channel_capacity == 0 {
100
} else {
(used.saturating_mul(100) / self.route_channel_capacity) as u8
};
let high_profile = used_pct >= high_watermark_pct;
let timeout_ms = if high_profile {
high_timeout_ms
} else {
base_timeout_ms
};
let timeout_dur = Duration::from_millis(timeout_ms);
match tokio::time::timeout(timeout_dur, tx.send(resp)).await {
Ok(Ok(())) => RouteResult::Routed,
Ok(Err(_)) => RouteResult::ChannelClosed,
Err(_) => {
if high_profile {
RouteResult::QueueFullHigh
} else {
RouteResult::QueueFullBase
}
}
}
}
}
}
pub async fn route_nowait(&self, id: u64, resp: MeResponse) -> RouteResult {
let tx = self.routing.map.get(&id).map(|entry| entry.value().clone());
let Some(tx) = tx else {
return RouteResult::NoConn;
};
let resp = match self.attach_route_byte_permit(id, resp, Some(0)).await {
Ok(resp) => resp,
Err(result) => return result,
};
match tx.try_send(resp) {
Ok(()) => RouteResult::Routed,
Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed,
Err(TrySendError::Full(_)) => RouteResult::QueueFullBase,
}
}
pub async fn route_with_timeout(
&self,
id: u64,
resp: MeResponse,
timeout_ms: u64,
) -> RouteResult {
if timeout_ms == 0 {
return self.route_nowait(id, resp).await;
}
let tx = self.routing.map.get(&id).map(|entry| entry.value().clone());
let Some(tx) = tx else {
return RouteResult::NoConn;
};
let resp = match self
.attach_route_byte_permit(id, resp, Some(timeout_ms))
.await
{
Ok(resp) => resp,
Err(result) => return result,
};
match tx.try_send(resp) {
Ok(()) => RouteResult::Routed,
Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed,
Err(TrySendError::Full(resp)) => {
let high_watermark_pct = self
.route_backpressure_high_watermark_pct
.load(Ordering::Relaxed)
.clamp(1, 100);
let used = self.route_channel_capacity.saturating_sub(tx.capacity());
let used_pct = if self.route_channel_capacity == 0 {
100
} else {
(used.saturating_mul(100) / self.route_channel_capacity) as u8
};
let high_profile = used_pct >= high_watermark_pct;
let timeout_dur = Duration::from_millis(timeout_ms.max(1));
match tokio::time::timeout(timeout_dur, tx.send(resp)).await {
Ok(Ok(())) => RouteResult::Routed,
Ok(Err(_)) => RouteResult::ChannelClosed,
Err(_) => {
if high_profile {
RouteResult::QueueFullHigh
} else {
RouteResult::QueueFullBase
}
}
}
}
}
}
pub async fn bind_writer(&self, conn_id: u64, writer_id: u64, meta: ConnMeta) -> bool {
let mut binding = self.binding.inner.lock().await;
// ROUTING IS THE SOURCE OF TRUTH:
// never keep/attach writer binding for a connection that is already
// absent from the routing table.
if !self.routing.map.contains_key(&conn_id) {
return false;
}
if !binding.writers.contains_key(&writer_id) {
return false;
}
let previous_writer_id = binding.writer_for_conn.insert(conn_id, writer_id);
if let Some(previous_writer_id) = previous_writer_id
&& previous_writer_id != writer_id
{
let became_empty =
if let Some(set) = binding.conns_for_writer.get_mut(&previous_writer_id) {
set.remove(&conn_id);
set.is_empty()
} else {
false
};
if became_empty {
binding
.writer_idle_since_epoch_secs
.insert(previous_writer_id, Self::now_epoch_secs());
}
}
binding.meta.insert(conn_id, meta.clone());
binding.last_meta_for_writer.insert(writer_id, meta.clone());
binding.writer_idle_since_epoch_secs.remove(&writer_id);
binding
.conns_for_writer
.entry(writer_id)
.or_insert_with(HashSet::new)
.insert(conn_id);
self.hot_binding
.map
.insert(conn_id, HotConnBinding { writer_id, meta });
true
}
pub async fn mark_writer_idle(&self, writer_id: u64) {
let mut binding = self.binding.inner.lock().await;
binding
.conns_for_writer
.entry(writer_id)
.or_insert_with(HashSet::new);
binding
.writer_idle_since_epoch_secs
.entry(writer_id)
.or_insert(Self::now_epoch_secs());
}
pub async fn get_last_writer_meta(&self, writer_id: u64) -> Option<ConnMeta> {
let binding = self.binding.inner.lock().await;
binding.last_meta_for_writer.get(&writer_id).cloned()
}
pub async fn writer_idle_since_snapshot(&self) -> HashMap<u64, u64> {
let binding = self.binding.inner.lock().await;
binding.writer_idle_since_epoch_secs.clone()
}
pub async fn writer_idle_since_for_writer_ids(&self, writer_ids: &[u64]) -> HashMap<u64, u64> {
let binding = self.binding.inner.lock().await;
let mut out = HashMap::<u64, u64>::with_capacity(writer_ids.len());
for writer_id in writer_ids {
if let Some(idle_since) = binding.writer_idle_since_epoch_secs.get(writer_id).copied() {
out.insert(*writer_id, idle_since);
}
}
out
}
pub(super) async fn writer_activity_snapshot(&self) -> WriterActivitySnapshot {
let binding = self.binding.inner.lock().await;
let mut bound_clients_by_writer = HashMap::<u64, usize>::new();
let mut active_sessions_by_target_dc = HashMap::<i16, usize>::new();
for (writer_id, conn_ids) in &binding.conns_for_writer {
bound_clients_by_writer.insert(*writer_id, conn_ids.len());
}
for conn_meta in binding.meta.values() {
if conn_meta.target_dc == 0 {
continue;
}
*active_sessions_by_target_dc
.entry(conn_meta.target_dc)
.or_insert(0) += 1;
}
WriterActivitySnapshot {
bound_clients_by_writer,
active_sessions_by_target_dc,
}
}
pub async fn get_writer(&self, conn_id: u64) -> Option<ConnWriter> {
if !self.routing.map.contains_key(&conn_id) {
return None;
}
let writer_id = self
.hot_binding
.map
.get(&conn_id)
.map(|entry| entry.writer_id)?;
let writer = self
.writers
.map
.get(&writer_id)
.map(|entry| entry.value().clone())?;
Some(ConnWriter {
writer_id,
tx: writer,
})
}
/// Returns the active writer and routing metadata from one hot-binding lookup.
pub async fn get_writer_with_meta(&self, conn_id: u64) -> Option<(ConnWriter, ConnMeta)> {
if !self.routing.map.contains_key(&conn_id) {
return None;
}
let hot = self.hot_binding.map.get(&conn_id)?;
let writer_id = hot.writer_id;
let meta = hot.meta.clone();
let writer = self
.writers
.map
.get(&writer_id)
.map(|entry| entry.value().clone())?;
Some((
ConnWriter {
writer_id,
tx: writer,
},
meta,
))
}
pub async fn active_conn_ids(&self) -> Vec<u64> {
let binding = self.binding.inner.lock().await;
binding.writer_for_conn.keys().copied().collect()
}
pub async fn writer_lost(&self, writer_id: u64) -> Vec<BoundConn> {
let mut binding = self.binding.inner.lock().await;
binding.writers.remove(&writer_id);
self.writers.map.remove(&writer_id);
binding.last_meta_for_writer.remove(&writer_id);
binding.writer_idle_since_epoch_secs.remove(&writer_id);
let conns = binding
.conns_for_writer
.remove(&writer_id)
.unwrap_or_default()
.into_iter()
.collect::<Vec<_>>();
let mut out = Vec::new();
for conn_id in conns {
if binding.writer_for_conn.get(&conn_id).copied() != Some(writer_id) {
continue;
}
binding.writer_for_conn.remove(&conn_id);
let remove_hot = self
.hot_binding
.map
.get(&conn_id)
.map(|hot| hot.writer_id == writer_id)
.unwrap_or(false);
if remove_hot {
self.hot_binding.map.remove(&conn_id);
}
if let Some(m) = binding.meta.get(&conn_id) {
out.push(BoundConn {
conn_id,
meta: m.clone(),
});
}
}
out
}
#[allow(dead_code)]
pub async fn get_meta(&self, conn_id: u64) -> Option<ConnMeta> {
self.hot_binding
.map
.get(&conn_id)
.map(|entry| entry.meta.clone())
}
pub async fn is_writer_empty(&self, writer_id: u64) -> bool {
let binding = self.binding.inner.lock().await;
binding
.conns_for_writer
.get(&writer_id)
.map(|s| s.is_empty())
.unwrap_or(true)
}
#[allow(dead_code)]
pub async fn unregister_writer_if_empty(&self, writer_id: u64) -> bool {
let mut binding = self.binding.inner.lock().await;
let Some(conn_ids) = binding.conns_for_writer.get(&writer_id) else {
// Writer is already absent from the registry.
return true;
};
if !conn_ids.is_empty() {
return false;
}
binding.writers.remove(&writer_id);
self.writers.map.remove(&writer_id);
binding.last_meta_for_writer.remove(&writer_id);
binding.writer_idle_since_epoch_secs.remove(&writer_id);
binding.conns_for_writer.remove(&writer_id);
true
}
#[allow(dead_code)]
pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet<u64> {
let binding = self.binding.inner.lock().await;
let mut out = HashSet::<u64>::with_capacity(writer_ids.len());
for writer_id in writer_ids {
if let Some(conns) = binding.conns_for_writer.get(writer_id)
&& !conns.is_empty()
{
out.insert(*writer_id);
}
}
out
}
}
#[cfg(test)]
mod tests {
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use bytes::Bytes;
use super::{ConnMeta, ConnRegistry, RouteResult};
use crate::transport::middle_proxy::MeResponse;
#[tokio::test]
async fn writer_activity_snapshot_tracks_writer_and_dc_load() {
let registry = ConnRegistry::new();
let (conn_a, _rx_a) = registry.register().await;
let (conn_b, _rx_b) = registry.register().await;
let (conn_c, _rx_c) = registry.register().await;
let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8);
let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8);
registry.register_writer(10, writer_tx_a.clone()).await;
registry.register_writer(20, writer_tx_b.clone()).await;
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443);
assert!(
registry
.bind_writer(
conn_a,
10,
ConnMeta {
target_dc: 2,
client_addr: addr,
our_addr: addr,
proto_flags: 0,
},
)
.await
);
assert!(
registry
.bind_writer(
conn_b,
10,
ConnMeta {
target_dc: -2,
client_addr: addr,
our_addr: addr,
proto_flags: 0,
},
)
.await
);
assert!(
registry
.bind_writer(
conn_c,
20,
ConnMeta {
target_dc: 4,
client_addr: addr,
our_addr: addr,
proto_flags: 0,
},
)
.await
);
let snapshot = registry.writer_activity_snapshot().await;
assert_eq!(snapshot.bound_clients_by_writer.get(&10), Some(&2));
assert_eq!(snapshot.bound_clients_by_writer.get(&20), Some(&1));
assert_eq!(snapshot.active_sessions_by_target_dc.get(&2), Some(&1));
assert_eq!(snapshot.active_sessions_by_target_dc.get(&-2), Some(&1));
assert_eq!(snapshot.active_sessions_by_target_dc.get(&4), Some(&1));
}
#[tokio::test]
async fn route_data_is_bounded_by_byte_permits_before_channel_capacity() {
let registry = ConnRegistry::with_route_byte_permits_for_tests(4, 1);
let (conn_id, mut rx) = registry.register().await;
let routed = registry
.route_nowait(
conn_id,
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xAA]),
route_permit: None,
},
)
.await;
assert!(matches!(routed, RouteResult::Routed));
let blocked = registry
.route_nowait(
conn_id,
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xBB]),
route_permit: None,
},
)
.await;
assert!(
matches!(blocked, RouteResult::QueueFullHigh),
"byte budget must reject data before count capacity is exhausted"
);
drop(rx.recv().await);
let routed_after_drain = registry
.route_nowait(
conn_id,
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xCC]),
route_permit: None,
},
)
.await;
assert!(
matches!(routed_after_drain, RouteResult::Routed),
"receiving queued data must release byte permits"
);
}
#[tokio::test]
async fn bind_writer_rebinds_conn_atomically() {
let registry = ConnRegistry::new();
let (conn_id, _rx) = registry.register().await;
let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8);
let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8);
registry.register_writer(10, writer_tx_a).await;
registry.register_writer(20, writer_tx_b).await;
let client_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443);
let first_our_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 443);
let second_our_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(2, 2, 2, 2)), 443);
assert!(
registry
.bind_writer(
conn_id,
10,
ConnMeta {
target_dc: 2,
client_addr,
our_addr: first_our_addr,
proto_flags: 1,
},
)
.await
);
assert!(
registry
.bind_writer(
conn_id,
20,
ConnMeta {
target_dc: 2,
client_addr,
our_addr: second_our_addr,
proto_flags: 2,
},
)
.await
);
let writer = registry.get_writer(conn_id).await.expect("writer binding");
assert_eq!(writer.writer_id, 20);
let meta = registry.get_meta(conn_id).await.expect("conn meta");
assert_eq!(meta.our_addr, second_our_addr);
assert_eq!(meta.proto_flags, 2);
let snapshot = registry.writer_activity_snapshot().await;
assert_eq!(snapshot.bound_clients_by_writer.get(&10), Some(&0));
assert_eq!(snapshot.bound_clients_by_writer.get(&20), Some(&1));
assert!(
registry
.writer_idle_since_snapshot()
.await
.contains_key(&10)
);
}
#[tokio::test]
async fn writer_lost_does_not_drop_rebound_conn() {
let registry = ConnRegistry::new();
let (conn_id, _rx) = registry.register().await;
let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8);
let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8);
registry.register_writer(10, writer_tx_a).await;
registry.register_writer(20, writer_tx_b).await;
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443);
assert!(
registry
.bind_writer(
conn_id,
10,
ConnMeta {
target_dc: 2,
client_addr: addr,
our_addr: addr,
proto_flags: 0,
},
)
.await
);
assert!(
registry
.bind_writer(
conn_id,
20,
ConnMeta {
target_dc: 2,
client_addr: addr,
our_addr: addr,
proto_flags: 1,
},
)
.await
);
let lost = registry.writer_lost(10).await;
assert!(lost.is_empty());
assert_eq!(
registry
.get_writer(conn_id)
.await
.expect("writer")
.writer_id,
20
);
let removed_writer = registry.unregister(conn_id).await;
assert_eq!(removed_writer, Some(20));
assert!(registry.is_writer_empty(20).await);
}
#[tokio::test]
async fn bind_writer_rejects_unregistered_writer() {
let registry = ConnRegistry::new();
let (conn_id, _rx) = registry.register().await;
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443);
assert!(
!registry
.bind_writer(
conn_id,
10,
ConnMeta {
target_dc: 2,
client_addr: addr,
our_addr: addr,
proto_flags: 0,
},
)
.await
);
assert!(registry.get_writer(conn_id).await.is_none());
}
#[tokio::test]
async fn non_empty_writer_ids_returns_only_writers_with_bound_clients() {
let registry = ConnRegistry::new();
let (conn_id, _rx) = registry.register().await;
let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8);
let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8);
registry.register_writer(10, writer_tx_a).await;
registry.register_writer(20, writer_tx_b).await;
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443);
assert!(
registry
.bind_writer(
conn_id,
10,
ConnMeta {
target_dc: 2,
client_addr: addr,
our_addr: addr,
proto_flags: 0,
},
)
.await
);
let non_empty = registry.non_empty_writer_ids(&[10, 20, 30]).await;
assert!(non_empty.contains(&10));
assert!(!non_empty.contains(&20));
assert!(!non_empty.contains(&30));
}
}
mod tests;

View File

@@ -0,0 +1,288 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use bytes::Bytes;
use super::{ConnMeta, ConnRegistry, RouteResult};
use crate::transport::middle_proxy::MeResponse;
/// Binds three connections across two writers and checks that the activity
/// snapshot reports the per-writer client counts and per-DC session counts.
#[tokio::test]
async fn writer_activity_snapshot_tracks_writer_and_dc_load() {
    let registry = ConnRegistry::new();
    let (conn_a, _rx_a) = registry.register().await;
    let (conn_b, _rx_b) = registry.register().await;
    let (conn_c, _rx_c) = registry.register().await;

    let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8);
    let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8);
    registry.register_writer(10, writer_tx_a.clone()).await;
    registry.register_writer(20, writer_tx_b.clone()).await;

    let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443);
    // (connection, writer, target DC) — two clients on writer 10, one on writer 20.
    let bindings = [(conn_a, 10, 2), (conn_b, 10, -2), (conn_c, 20, 4)];
    for (conn_id, writer_id, target_dc) in bindings {
        let meta = ConnMeta {
            target_dc,
            client_addr: addr,
            our_addr: addr,
            proto_flags: 0,
        };
        let bound = registry.bind_writer(conn_id, writer_id, meta).await;
        assert!(bound);
    }

    let snapshot = registry.writer_activity_snapshot().await;
    let by_writer = &snapshot.bound_clients_by_writer;
    assert_eq!(by_writer.get(&10), Some(&2));
    assert_eq!(by_writer.get(&20), Some(&1));
    let by_dc = &snapshot.active_sessions_by_target_dc;
    assert_eq!(by_dc.get(&2), Some(&1));
    assert_eq!(by_dc.get(&-2), Some(&1));
    assert_eq!(by_dc.get(&4), Some(&1));
}
/// With a byte budget smaller than the channel capacity, a second data frame
/// must be rejected until the first one has been received, and accepted
/// again afterwards.
#[tokio::test]
async fn route_data_is_bounded_by_byte_permits_before_channel_capacity() {
    let registry = ConnRegistry::with_route_byte_permits_for_tests(4, 1);
    let (conn_id, mut rx) = registry.register().await;

    // Builds a data response that does not yet carry a route permit.
    let data_frame = |bytes: &'static [u8]| MeResponse::Data {
        flags: 0,
        data: Bytes::from_static(bytes),
        route_permit: None,
    };

    let routed = registry.route_nowait(conn_id, data_frame(&[0xAA])).await;
    assert!(matches!(routed, RouteResult::Routed));

    let blocked = registry.route_nowait(conn_id, data_frame(&[0xBB])).await;
    assert!(
        matches!(blocked, RouteResult::QueueFullHigh),
        "byte budget must reject data before count capacity is exhausted"
    );

    // Receiving and dropping the queued frame lets the next frame through.
    let drained = rx.recv().await;
    drop(drained);

    let routed_after_drain = registry.route_nowait(conn_id, data_frame(&[0xCC])).await;
    assert!(
        matches!(routed_after_drain, RouteResult::Routed),
        "receiving queued data must release byte permits"
    );
}
/// Rebinding a connection from writer 10 to writer 20 must move the binding
/// and metadata to the new writer and mark the old writer as idle.
#[tokio::test]
async fn bind_writer_rebinds_conn_atomically() {
    let registry = ConnRegistry::new();
    let (conn_id, _rx) = registry.register().await;

    let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8);
    let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8);
    registry.register_writer(10, writer_tx_a).await;
    registry.register_writer(20, writer_tx_b).await;

    let client_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443);
    let first_our_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 443);
    let second_our_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(2, 2, 2, 2)), 443);
    let meta_with = |our_addr: SocketAddr, proto_flags| ConnMeta {
        target_dc: 2,
        client_addr,
        our_addr,
        proto_flags,
    };

    let first_bind = registry
        .bind_writer(conn_id, 10, meta_with(first_our_addr, 1))
        .await;
    assert!(first_bind);
    let second_bind = registry
        .bind_writer(conn_id, 20, meta_with(second_our_addr, 2))
        .await;
    assert!(second_bind);

    let writer = registry.get_writer(conn_id).await.expect("writer binding");
    assert_eq!(writer.writer_id, 20);

    let meta = registry.get_meta(conn_id).await.expect("conn meta");
    assert_eq!(meta.our_addr, second_our_addr);
    assert_eq!(meta.proto_flags, 2);

    let snapshot = registry.writer_activity_snapshot().await;
    assert_eq!(snapshot.bound_clients_by_writer.get(&10), Some(&0));
    assert_eq!(snapshot.bound_clients_by_writer.get(&20), Some(&1));

    let idle_since = registry.writer_idle_since_snapshot().await;
    assert!(idle_since.contains_key(&10));
}
/// Losing the previous writer of a connection that was already rebound must
/// not report or detach that connection.
#[tokio::test]
async fn writer_lost_does_not_drop_rebound_conn() {
    let registry = ConnRegistry::new();
    let (conn_id, _rx) = registry.register().await;

    let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8);
    let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8);
    registry.register_writer(10, writer_tx_a).await;
    registry.register_writer(20, writer_tx_b).await;

    let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443);
    // Bind to writer 10 first, then rebind the same connection to writer 20.
    for (writer_id, proto_flags) in [(10, 0), (20, 1)] {
        let meta = ConnMeta {
            target_dc: 2,
            client_addr: addr,
            our_addr: addr,
            proto_flags,
        };
        let bound = registry.bind_writer(conn_id, writer_id, meta).await;
        assert!(bound);
    }

    let lost = registry.writer_lost(10).await;
    assert!(lost.is_empty());

    let current = registry.get_writer(conn_id).await.expect("writer");
    assert_eq!(current.writer_id, 20);

    let removed_writer = registry.unregister(conn_id).await;
    assert_eq!(removed_writer, Some(20));
    assert!(registry.is_writer_empty(20).await);
}
/// Binding to a writer id that was never registered must fail and leave the
/// connection without a writer.
#[tokio::test]
async fn bind_writer_rejects_unregistered_writer() {
    let registry = ConnRegistry::new();
    let (conn_id, _rx) = registry.register().await;

    let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443);
    let meta = ConnMeta {
        target_dc: 2,
        client_addr: addr,
        our_addr: addr,
        proto_flags: 0,
    };

    let bound = registry.bind_writer(conn_id, 10, meta).await;
    assert!(!bound);

    let writer = registry.get_writer(conn_id).await;
    assert!(writer.is_none());
}
/// Only writers that currently have at least one bound connection are
/// reported; registered-but-empty and unknown writers are not.
#[tokio::test]
async fn non_empty_writer_ids_returns_only_writers_with_bound_clients() {
    let registry = ConnRegistry::new();
    let (conn_id, _rx) = registry.register().await;

    let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8);
    let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8);
    registry.register_writer(10, writer_tx_a).await;
    registry.register_writer(20, writer_tx_b).await;

    let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443);
    let meta = ConnMeta {
        target_dc: 2,
        client_addr: addr,
        our_addr: addr,
        proto_flags: 0,
    };
    let bound = registry.bind_writer(conn_id, 10, meta).await;
    assert!(bound);

    // 10 has a client, 20 is registered but empty, 30 was never registered.
    let non_empty = registry.non_empty_writer_ids(&[10, 20, 30]).await;
    for (writer_id, expected) in [(10_u64, true), (20_u64, false), (30_u64, false)] {
        assert_eq!(non_empty.contains(&writer_id), expected);
    }
}

View File

@@ -0,0 +1,481 @@
use std::collections::{HashMap, HashSet};
use std::sync::atomic::Ordering;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;
use super::super::codec::WriterCommand;
use super::super::{MeResponse, RouteBytePermit};
use super::{
BoundConn, ConnMeta, ConnRegistry, ConnWriter, HotConnBinding, RouteResult,
WriterActivitySnapshot,
};
impl ConnRegistry {
/// Registers a writer's command channel under `writer_id`.
///
/// The sender is stored both in the lock-protected binding state and in the
/// concurrent writers map, and an (initially empty) connection set is
/// ensured for the writer. An existing connection set is left untouched.
pub async fn register_writer(&self, writer_id: u64, tx: mpsc::Sender<WriterCommand>) {
    let mut state = self.binding.inner.lock().await;
    state.conns_for_writer.entry(writer_id).or_default();
    state.writers.insert(writer_id, tx.clone());
    self.writers.map.insert(writer_id, tx);
}
/// Unregister connection, returning associated writer_id if any.
///
/// The routing entry, byte budget and hot binding are dropped first so that
/// lookups stop seeing the connection immediately. The lock-protected
/// binding state is then cleaned up, and if the connection was the last one
/// bound to its writer, that writer's idle timestamp is recorded.
pub async fn unregister(&self, id: u64) -> Option<u64> {
    self.routing.map.remove(&id);
    self.routing.byte_budget.remove(&id);
    self.hot_binding.map.remove(&id);

    let mut binding = self.binding.inner.lock().await;
    // `bind_writer` inserts into `hot_binding` while holding this lock. A call
    // that passed its routing check just before the removals above may have
    // re-created the entry after we removed it, so remove it again now that
    // we hold the lock; otherwise a stale hot binding would be left behind.
    self.hot_binding.map.remove(&id);
    binding.meta.remove(&id);

    let writer_id = binding.writer_for_conn.remove(&id)?;
    let became_empty = binding
        .conns_for_writer
        .get_mut(&writer_id)
        .is_some_and(|conns| {
            conns.remove(&id);
            conns.is_empty()
        });
    if became_empty {
        binding
            .writer_idle_since_epoch_secs
            .insert(writer_id, Self::now_epoch_secs());
    }
    Some(writer_id)
}
/// Attaches a byte-budget permit to a `MeResponse::Data` frame before it is
/// queued for the connection `id`.
///
/// Responses that are not `Data`, and `Data` responses that already carry a
/// permit, are returned unchanged. Otherwise
/// `route_data_permits(data.len())` permits are acquired from the
/// connection's byte-budget semaphore:
///
/// * `Some(0)` — a single non-blocking attempt;
/// * `Some(ms)` — wait up to `ms` milliseconds (at least 1 ms);
/// * `None` — wait without a deadline.
///
/// # Errors
///
/// * `RouteResult::NoConn` — no byte budget is registered for `id`.
/// * `RouteResult::QueueFullHigh` — the budget is exhausted (immediately for
///   `Some(0)`, or after the timeout elapsed).
/// * `RouteResult::ChannelClosed` — the semaphore has been closed.
async fn attach_route_byte_permit(
    &self,
    id: u64,
    resp: MeResponse,
    timeout_ms: Option<u64>,
) -> std::result::Result<MeResponse, RouteResult> {
    let MeResponse::Data {
        flags,
        data,
        route_permit,
    } = resp
    else {
        return Ok(resp);
    };
    if route_permit.is_some() {
        return Ok(MeResponse::Data {
            flags,
            data,
            route_permit,
        });
    }

    // Clone the semaphore handle out of the map so no map guard is held
    // across the awaits below.
    let Some(semaphore) = self
        .routing
        .byte_budget
        .get(&id)
        .map(|entry| entry.value().clone())
    else {
        return Err(RouteResult::NoConn);
    };

    let permits = Self::route_data_permits(data.len());
    let permit = match timeout_ms {
        Some(0) => semaphore
            .try_acquire_many_owned(permits)
            // Report a closed semaphore the same way the waiting paths do,
            // instead of disguising it as back-pressure.
            .map_err(|err| match err {
                tokio::sync::TryAcquireError::Closed => RouteResult::ChannelClosed,
                tokio::sync::TryAcquireError::NoPermits => RouteResult::QueueFullHigh,
            })?,
        Some(timeout_ms) => {
            let acquire = semaphore.acquire_many_owned(permits);
            match tokio::time::timeout(Duration::from_millis(timeout_ms.max(1)), acquire).await
            {
                Ok(Ok(permit)) => permit,
                Ok(Err(_)) => return Err(RouteResult::ChannelClosed),
                Err(_) => return Err(RouteResult::QueueFullHigh),
            }
        }
        None => semaphore
            .acquire_many_owned(permits)
            .await
            .map_err(|_| RouteResult::ChannelClosed)?,
    };

    Ok(MeResponse::Data {
        flags,
        data,
        route_permit: Some(RouteBytePermit::new(permit)),
    })
}
/// Routes `resp` to connection `id` using the configured back-pressure
/// timeouts.
///
/// A byte permit is attached first (waiting up to the base timeout). The
/// response is then offered to the channel without waiting; if the channel
/// is full, the send is retried with a deadline — the high timeout when
/// channel usage is at or above the high watermark, the base timeout
/// otherwise. A timed-out send yields `QueueFullHigh` or `QueueFullBase`
/// according to the same distinction.
#[allow(dead_code)]
pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult {
    let sender = match self.routing.map.get(&id) {
        Some(entry) => entry.value().clone(),
        None => return RouteResult::NoConn,
    };

    let base_ms = self
        .route_backpressure_base_timeout_ms
        .load(Ordering::Relaxed)
        .max(1);
    let permitted = match self.attach_route_byte_permit(id, resp, Some(base_ms)).await {
        Ok(frame) => frame,
        Err(outcome) => return outcome,
    };

    let pending = match sender.try_send(permitted) {
        Ok(()) => return RouteResult::Routed,
        Err(TrySendError::Closed(_)) => return RouteResult::ChannelClosed,
        Err(TrySendError::Full(pending)) => pending,
    };

    // Absorb short bursts without dropping/closing the session immediately.
    let high_ms = self
        .route_backpressure_high_timeout_ms
        .load(Ordering::Relaxed)
        .max(base_ms);
    let watermark_pct = self
        .route_backpressure_high_watermark_pct
        .load(Ordering::Relaxed)
        .clamp(1, 100);
    let in_use = self.route_channel_capacity.saturating_sub(sender.capacity());
    let in_use_pct = match self.route_channel_capacity {
        0 => 100,
        total => (in_use.saturating_mul(100) / total) as u8,
    };
    let congested = in_use_pct >= watermark_pct;

    let wait = Duration::from_millis(if congested { high_ms } else { base_ms });
    match tokio::time::timeout(wait, sender.send(pending)).await {
        Ok(Ok(())) => RouteResult::Routed,
        Ok(Err(_)) => RouteResult::ChannelClosed,
        Err(_) if congested => RouteResult::QueueFullHigh,
        Err(_) => RouteResult::QueueFullBase,
    }
}
/// Routes `resp` to connection `id` without waiting.
///
/// The byte permit is requested with a zero timeout and the channel send is
/// a single `try_send`; a full channel yields `QueueFullBase`.
pub async fn route_nowait(&self, id: u64, resp: MeResponse) -> RouteResult {
    let sender = match self.routing.map.get(&id) {
        Some(entry) => entry.value().clone(),
        None => return RouteResult::NoConn,
    };

    let permitted = match self.attach_route_byte_permit(id, resp, Some(0)).await {
        Ok(frame) => frame,
        Err(outcome) => return outcome,
    };

    if let Err(failure) = sender.try_send(permitted) {
        return match failure {
            TrySendError::Closed(_) => RouteResult::ChannelClosed,
            TrySendError::Full(_) => RouteResult::QueueFullBase,
        };
    }
    RouteResult::Routed
}
/// Routes `resp` to connection `id`, waiting at most `timeout_ms`
/// milliseconds for the byte permit and, separately, for channel space.
///
/// A `timeout_ms` of zero delegates to `route_nowait`. When the channel is
/// full and the timed send also expires, the result is `QueueFullHigh` if
/// channel usage was at or above the high watermark, else `QueueFullBase`.
pub async fn route_with_timeout(
    &self,
    id: u64,
    resp: MeResponse,
    timeout_ms: u64,
) -> RouteResult {
    if timeout_ms == 0 {
        return self.route_nowait(id, resp).await;
    }

    let sender = match self.routing.map.get(&id) {
        Some(entry) => entry.value().clone(),
        None => return RouteResult::NoConn,
    };

    let permitted = match self
        .attach_route_byte_permit(id, resp, Some(timeout_ms))
        .await
    {
        Ok(frame) => frame,
        Err(outcome) => return outcome,
    };

    let pending = match sender.try_send(permitted) {
        Ok(()) => return RouteResult::Routed,
        Err(TrySendError::Closed(_)) => return RouteResult::ChannelClosed,
        Err(TrySendError::Full(pending)) => pending,
    };

    // Classify the congestion level before waiting, so a timeout can be
    // reported against the matching profile.
    let watermark_pct = self
        .route_backpressure_high_watermark_pct
        .load(Ordering::Relaxed)
        .clamp(1, 100);
    let in_use = self.route_channel_capacity.saturating_sub(sender.capacity());
    let in_use_pct = match self.route_channel_capacity {
        0 => 100,
        total => (in_use.saturating_mul(100) / total) as u8,
    };
    let congested = in_use_pct >= watermark_pct;

    let wait = Duration::from_millis(timeout_ms.max(1));
    match tokio::time::timeout(wait, sender.send(pending)).await {
        Ok(Ok(())) => RouteResult::Routed,
        Ok(Err(_)) => RouteResult::ChannelClosed,
        Err(_) if congested => RouteResult::QueueFullHigh,
        Err(_) => RouteResult::QueueFullBase,
    }
}
pub async fn bind_writer(&self, conn_id: u64, writer_id: u64, meta: ConnMeta) -> bool {
let mut binding = self.binding.inner.lock().await;
// ROUTING IS THE SOURCE OF TRUTH:
// never keep/attach writer binding for a connection that is already
// absent from the routing table.
if !self.routing.map.contains_key(&conn_id) {
return false;
}
if !binding.writers.contains_key(&writer_id) {
return false;
}
let previous_writer_id = binding.writer_for_conn.insert(conn_id, writer_id);
if let Some(previous_writer_id) = previous_writer_id
&& previous_writer_id != writer_id
{
let became_empty =
if let Some(set) = binding.conns_for_writer.get_mut(&previous_writer_id) {
set.remove(&conn_id);
set.is_empty()
} else {
false
};
if became_empty {
binding
.writer_idle_since_epoch_secs
.insert(previous_writer_id, Self::now_epoch_secs());
}
}
binding.meta.insert(conn_id, meta.clone());
binding.last_meta_for_writer.insert(writer_id, meta.clone());
binding.writer_idle_since_epoch_secs.remove(&writer_id);
binding
.conns_for_writer
.entry(writer_id)
.or_insert_with(HashSet::new)
.insert(conn_id);
self.hot_binding
.map
.insert(conn_id, HotConnBinding { writer_id, meta });
true
}
pub async fn mark_writer_idle(&self, writer_id: u64) {
let mut binding = self.binding.inner.lock().await;
binding
.conns_for_writer
.entry(writer_id)
.or_insert_with(HashSet::new);
binding
.writer_idle_since_epoch_secs
.entry(writer_id)
.or_insert(Self::now_epoch_secs());
}
/// Returns a clone of the last connection metadata recorded for `writer_id`, if any.
pub async fn get_last_writer_meta(&self, writer_id: u64) -> Option<ConnMeta> {
    let state = self.binding.inner.lock().await;
    let last_meta = state.last_meta_for_writer.get(&writer_id);
    last_meta.cloned()
}
/// Returns a copy of the whole writer-id → idle-since (epoch seconds) map.
pub async fn writer_idle_since_snapshot(&self) -> HashMap<u64, u64> {
    let state = self.binding.inner.lock().await;
    let snapshot = state.writer_idle_since_epoch_secs.clone();
    snapshot
}
/// Returns the idle-since timestamps for the requested writers only.
///
/// Writers without a recorded timestamp are omitted from the result.
pub async fn writer_idle_since_for_writer_ids(&self, writer_ids: &[u64]) -> HashMap<u64, u64> {
    let state = self.binding.inner.lock().await;
    writer_ids
        .iter()
        .filter_map(|writer_id| {
            state
                .writer_idle_since_epoch_secs
                .get(writer_id)
                .map(|idle_since| (*writer_id, *idle_since))
        })
        .collect()
}
/// Builds per-writer and per-DC activity counters under one binding lock.
///
/// `bound_clients_by_writer` holds the size of every writer's connection
/// set; `active_sessions_by_target_dc` counts connection metadata entries
/// per `target_dc`, skipping entries whose `target_dc` is zero.
pub(in crate::transport::middle_proxy) async fn writer_activity_snapshot(
    &self,
) -> WriterActivitySnapshot {
    let state = self.binding.inner.lock().await;
    let bound_clients_by_writer: HashMap<u64, usize> = state
        .conns_for_writer
        .iter()
        .map(|(writer_id, conn_ids)| (*writer_id, conn_ids.len()))
        .collect();
    let mut active_sessions_by_target_dc = HashMap::<i16, usize>::new();
    for conn_meta in state.meta.values().filter(|meta| meta.target_dc != 0) {
        *active_sessions_by_target_dc
            .entry(conn_meta.target_dc)
            .or_default() += 1;
    }
    WriterActivitySnapshot {
        bound_clients_by_writer,
        active_sessions_by_target_dc,
    }
}
/// Resolves the writer currently bound to `conn_id`.
///
/// Returns `None` when the connection is not in the routing map, has no
/// hot binding, or its bound writer id is no longer in the writers map.
pub async fn get_writer(&self, conn_id: u64) -> Option<ConnWriter> {
    let routed = self.routing.map.contains_key(&conn_id);
    if !routed {
        return None;
    }
    let writer_id = match self.hot_binding.map.get(&conn_id) {
        Some(hot) => hot.writer_id,
        None => return None,
    };
    let tx = match self.writers.map.get(&writer_id) {
        Some(entry) => entry.value().clone(),
        None => return None,
    };
    Some(ConnWriter { writer_id, tx })
}
/// Returns the active writer and routing metadata from one hot-binding lookup.
///
/// Yields `None` when the connection is not routed, has no hot binding, or
/// the bound writer id is missing from the writers map.
pub async fn get_writer_with_meta(&self, conn_id: u64) -> Option<(ConnWriter, ConnMeta)> {
    if !self.routing.map.contains_key(&conn_id) {
        return None;
    }
    let hot = self.hot_binding.map.get(&conn_id)?;
    let tx = self
        .writers
        .map
        .get(&hot.writer_id)
        .map(|entry| entry.value().clone())?;
    let conn_writer = ConnWriter {
        writer_id: hot.writer_id,
        tx,
    };
    Some((conn_writer, hot.meta.clone()))
}
/// Returns the ids of all connections that currently have a writer binding.
pub async fn active_conn_ids(&self) -> Vec<u64> {
    let state = self.binding.inner.lock().await;
    let conn_ids: Vec<u64> = state.writer_for_conn.keys().copied().collect();
    conn_ids
}
/// Removes every trace of `writer_id` and reports the connections it carried.
///
/// The writer is dropped from the binding state and the writers map along
/// with its last-meta and idle-since entries. Connections still mapped to
/// this writer are unbound; their hot binding is removed only if it still
/// points at this writer. The returned list contains those connections for
/// which metadata is still stored.
pub async fn writer_lost(&self, writer_id: u64) -> Vec<BoundConn> {
    let mut state = self.binding.inner.lock().await;
    state.writers.remove(&writer_id);
    self.writers.map.remove(&writer_id);
    state.last_meta_for_writer.remove(&writer_id);
    state.writer_idle_since_epoch_secs.remove(&writer_id);
    let orphaned = state
        .conns_for_writer
        .remove(&writer_id)
        .unwrap_or_default();
    let mut lost = Vec::new();
    for conn_id in orphaned {
        // Skip connections that were rebound to another writer in the meantime.
        let still_owned = state.writer_for_conn.get(&conn_id) == Some(&writer_id);
        if !still_owned {
            continue;
        }
        state.writer_for_conn.remove(&conn_id);
        // The lookup guard is released before the removal below runs.
        let hot_matches = self
            .hot_binding
            .map
            .get(&conn_id)
            .is_some_and(|hot| hot.writer_id == writer_id);
        if hot_matches {
            self.hot_binding.map.remove(&conn_id);
        }
        if let Some(meta) = state.meta.get(&conn_id) {
            lost.push(BoundConn {
                conn_id,
                meta: meta.clone(),
            });
        }
    }
    lost
}
/// Returns a clone of the metadata stored in the hot binding for `conn_id`, if any.
#[allow(dead_code)]
pub async fn get_meta(&self, conn_id: u64) -> Option<ConnMeta> {
    let hot = self.hot_binding.map.get(&conn_id)?;
    Some(hot.meta.clone())
}
/// Reports whether `writer_id` has no bound connections.
///
/// A writer without a connection set is treated as empty.
pub async fn is_writer_empty(&self, writer_id: u64) -> bool {
    let state = self.binding.inner.lock().await;
    match state.conns_for_writer.get(&writer_id) {
        Some(conns) => conns.is_empty(),
        None => true,
    }
}
/// Unregisters `writer_id` only if it carries no connections.
///
/// Returns `true` when the writer has no connection set at all (already
/// absent) or was removed by this call, and `false` when it still has
/// bound connections.
#[allow(dead_code)]
pub async fn unregister_writer_if_empty(&self, writer_id: u64) -> bool {
    let mut state = self.binding.inner.lock().await;
    match state.conns_for_writer.get(&writer_id) {
        // Writer is already absent from the registry.
        None => return true,
        Some(conn_ids) if !conn_ids.is_empty() => return false,
        Some(_) => {}
    }
    state.writers.remove(&writer_id);
    self.writers.map.remove(&writer_id);
    state.last_meta_for_writer.remove(&writer_id);
    state.writer_idle_since_epoch_secs.remove(&writer_id);
    state.conns_for_writer.remove(&writer_id);
    true
}
#[allow(dead_code)]
pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet<u64> {
let binding = self.binding.inner.lock().await;
let mut out = HashSet::<u64>::with_capacity(writer_ids.len());
for writer_id in writer_ids {
if let Some(conns) = binding.conns_for_writer.get(writer_id)
&& !conns.is_empty()
{
out.insert(*writer_id);
}
}
out
}
}

View File

@@ -1,7 +1,6 @@
#![allow(clippy::too_many_arguments)]
use std::cmp::Reverse;
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::Ordering;
@@ -10,16 +9,14 @@ use std::time::{Duration, Instant};
use tokio::sync::mpsc::error::TrySendError;
use tracing::{debug, warn};
use super::MePool;
use super::codec::{ProxyReqCommand, WriterCommand};
use super::registry::ConnMeta;
use super::wire::build_proxy_req_payload;
use crate::config::{MeRouteNoWriterMode, MeWriterPickMode};
use crate::error::{ProxyError, Result};
use crate::network::IpFamily;
use crate::protocol::constants::{RPC_CLOSE_CONN_U32, RPC_CLOSE_EXT_U32};
use super::MePool;
use super::codec::{WriterCommand, build_control_payload};
use super::pool::WriterContour;
use super::registry::ConnMeta;
use super::wire::build_proxy_req_payload;
use crate::stream::PooledBuffer;
use rand::seq::SliceRandom;
const IDLE_WRITER_PENALTY_MID_SECS: u64 = 45;
@@ -33,6 +30,21 @@ const PICK_PENALTY_DRAINING: u64 = 600;
const PICK_PENALTY_STALE: u64 = 300;
const PICK_PENALTY_DEGRADED: u64 = 250;
mod close;
mod recovery;
mod selection;
/// Copies an optional tag slice into a fixed 16-byte array.
///
/// Returns `None` when no tag is given or its length is not exactly 16 bytes.
fn proxy_tag_array(tag: Option<&[u8]>) -> Option<[u8; 16]> {
    match tag {
        Some(bytes) => <[u8; 16]>::try_from(bytes).ok(),
        None => None,
    }
}
/// Recovers the pooled payload from a `WriterCommand::ProxyReq`.
///
/// Any other command variant yields `None`.
fn proxy_req_payload_from_command(cmd: WriterCommand) -> Option<PooledBuffer> {
    if let WriterCommand::ProxyReq(request) = cmd {
        Some(request.payload)
    } else {
        None
    }
}
impl MePool {
/// Send RPC_PROXY_REQ. `tag_override`: per-user ad_tag (from access.user_ad_tags); if None, uses pool default.
pub async fn send_proxy_req(
@@ -84,14 +96,10 @@ impl MePool {
let mut hybrid_wait_current = hybrid_wait_step;
loop {
if let Some((current, current_meta)) =
self.registry.get_writer_with_meta(conn_id).await
if let Some((current, current_meta)) = self.registry.get_writer_with_meta(conn_id).await
{
let (current_payload, _) = build_routed_payload(current_meta.our_addr);
match current
.tx
.try_send(WriterCommand::Data(current_payload))
{
match current.tx.try_send(WriterCommand::Data(current_payload)) {
Ok(()) => {
self.note_hybrid_route_success();
return Ok(());
@@ -528,401 +536,93 @@ impl MePool {
}
}
async fn wait_for_writer_until(&self, deadline: Instant) -> bool {
let mut rx = self.writer_epoch.subscribe();
if !self.writers.read().await.is_empty() {
return true;
}
let now = Instant::now();
if now >= deadline {
return !self.writers.read().await.is_empty();
}
let timeout = deadline.saturating_duration_since(now);
if tokio::time::timeout(timeout, rx.changed()).await.is_ok() {
return !self.writers.read().await.is_empty();
}
!self.writers.read().await.is_empty()
}
async fn wait_for_candidate_until(&self, routed_dc: i32, deadline: Instant) -> bool {
let mut rx = self.writer_epoch.subscribe();
loop {
if self.has_candidate_for_target_dc(routed_dc).await {
return true;
}
let now = Instant::now();
if now >= deadline {
return self.has_candidate_for_target_dc(routed_dc).await;
}
if self.has_candidate_for_target_dc(routed_dc).await {
return true;
}
let remaining = deadline.saturating_duration_since(Instant::now());
if remaining.is_zero() {
return self.has_candidate_for_target_dc(routed_dc).await;
}
if tokio::time::timeout(remaining, rx.changed()).await.is_err() {
return self.has_candidate_for_target_dc(routed_dc).await;
}
}
}
async fn has_candidate_for_target_dc(&self, routed_dc: i32) -> bool {
let writers_snapshot = {
let ws = self.writers.read().await;
if ws.is_empty() {
return false;
}
ws.clone()
};
let mut candidate_indices = self
.candidate_indices_for_dc(&writers_snapshot, routed_dc, false)
.await;
if candidate_indices.is_empty() {
candidate_indices = self
.candidate_indices_for_dc(&writers_snapshot, routed_dc, true)
.await;
}
!candidate_indices.is_empty()
}
async fn trigger_async_recovery_for_target_dc(self: &Arc<Self>, routed_dc: i32) -> bool {
let endpoints = self.endpoint_candidates_for_target_dc(routed_dc).await;
if endpoints.is_empty() {
return false;
}
self.stats.increment_me_async_recovery_trigger_total();
for addr in endpoints.into_iter().take(8) {
self.trigger_immediate_refill_for_dc(addr, routed_dc);
}
true
}
async fn trigger_async_recovery_global(self: &Arc<Self>) {
self.stats.increment_me_async_recovery_trigger_total();
let mut seen = HashSet::<(i32, SocketAddr)>::new();
for family in self.family_order() {
let map_guard = match family {
IpFamily::V4 => self.proxy_map_v4.read().await,
IpFamily::V6 => self.proxy_map_v6.read().await,
};
for (dc, addrs) in map_guard.iter() {
for (ip, port) in addrs {
let addr = SocketAddr::new(*ip, *port);
if seen.insert((*dc, addr)) {
self.trigger_immediate_refill_for_dc(addr, *dc);
}
if seen.len() >= 8 {
return;
}
}
}
}
}
async fn endpoint_candidates_for_target_dc(&self, routed_dc: i32) -> Vec<SocketAddr> {
self.preferred_endpoints_for_dc(routed_dc).await
}
async fn maybe_trigger_hybrid_recovery(
/// Send RPC_PROXY_REQ while keeping the first bound-writer path allocation-light.
pub async fn send_proxy_req_pooled(
self: &Arc<Self>,
routed_dc: i32,
hybrid_recovery_round: &mut u32,
hybrid_last_recovery_at: &mut Option<Instant>,
hybrid_wait_step: Duration,
) {
if !self.try_consume_hybrid_recovery_trigger_slot(HYBRID_RECOVERY_TRIGGER_MIN_INTERVAL_MS) {
return;
}
if let Some(last) = *hybrid_last_recovery_at
&& last.elapsed() < hybrid_wait_step
{
return;
}
conn_id: u64,
target_dc: i16,
client_addr: SocketAddr,
our_addr: SocketAddr,
payload: PooledBuffer,
proto_flags: u32,
tag_override: Option<[u8; 16]>,
) -> Result<()> {
let tag = tag_override.or_else(|| proxy_tag_array(self.proxy_tag.as_deref()));
let round = *hybrid_recovery_round;
let target_triggered = self.trigger_async_recovery_for_target_dc(routed_dc).await;
if !target_triggered || round.is_multiple_of(HYBRID_GLOBAL_BURST_PERIOD_ROUNDS) {
self.trigger_async_recovery_global().await;
}
*hybrid_recovery_round = round.saturating_add(1);
*hybrid_last_recovery_at = Some(Instant::now());
}
fn hybrid_total_wait_budget(&self) -> Duration {
let base = self
.route_runtime
.me_route_hybrid_max_wait
.max(Duration::from_millis(50));
let now_ms = Self::now_epoch_millis();
let last_success_ms = self
.route_runtime
.me_route_last_success_epoch_ms
.load(Ordering::Relaxed);
if last_success_ms != 0
&& now_ms.saturating_sub(last_success_ms) <= HYBRID_RECENT_SUCCESS_WINDOW_MS
{
return base.saturating_mul(2);
}
base
}
fn note_hybrid_route_success(&self) {
self.route_runtime
.me_route_last_success_epoch_ms
.store(Self::now_epoch_millis(), Ordering::Relaxed);
}
fn on_hybrid_timeout(&self, deadline: Instant, routed_dc: i32) {
self.stats.increment_me_hybrid_timeout_total();
let now_ms = Self::now_epoch_millis();
let mut last_warn_ms = self
.route_runtime
.me_route_hybrid_timeout_warn_epoch_ms
.load(Ordering::Relaxed);
while now_ms.saturating_sub(last_warn_ms) >= HYBRID_TIMEOUT_WARN_RATE_LIMIT_MS {
match self
.route_runtime
.me_route_hybrid_timeout_warn_epoch_ms
.compare_exchange_weak(last_warn_ms, now_ms, Ordering::AcqRel, Ordering::Relaxed)
{
Ok(_) => {
warn!(
routed_dc,
budget_ms = self.hybrid_total_wait_budget().as_millis() as u64,
elapsed_ms = deadline.elapsed().as_millis() as u64,
"ME hybrid route timeout reached"
);
break;
if let Some((current, current_meta)) = self.registry.get_writer_with_meta(conn_id).await {
let command = WriterCommand::ProxyReq(ProxyReqCommand {
conn_id,
client_addr,
our_addr: current_meta.our_addr,
proto_flags,
proxy_tag: tag,
payload,
});
match current.tx.try_send(command) {
Ok(()) => {
self.note_hybrid_route_success();
return Ok(());
}
Err(actual) => last_warn_ms = actual,
}
}
}
fn try_consume_hybrid_recovery_trigger_slot(&self, min_interval_ms: u64) -> bool {
let now_ms = Self::now_epoch_millis();
let mut last_trigger_ms = self
.route_runtime
.me_async_recovery_last_trigger_epoch_ms
.load(Ordering::Relaxed);
loop {
if now_ms.saturating_sub(last_trigger_ms) < min_interval_ms {
return false;
}
match self
.route_runtime
.me_async_recovery_last_trigger_epoch_ms
.compare_exchange_weak(last_trigger_ms, now_ms, Ordering::AcqRel, Ordering::Relaxed)
{
Ok(_) => return true,
Err(actual) => last_trigger_ms = actual,
}
}
}
pub async fn send_close(self: &Arc<Self>, conn_id: u64) -> Result<()> {
if let Some(w) = self.registry.get_writer(conn_id).await {
let payload = build_control_payload(RPC_CLOSE_EXT_U32, conn_id);
if w.tx
.send(WriterCommand::ControlAndFlush(payload))
.await
.is_err()
{
debug!("ME close write failed");
self.remove_writer_and_close_clients(w.writer_id).await;
}
} else {
debug!(conn_id, "ME close skipped (writer missing)");
}
self.registry.unregister(conn_id).await;
Ok(())
}
pub async fn send_close_conn(self: &Arc<Self>, conn_id: u64) -> Result<()> {
if let Some(w) = self.registry.get_writer(conn_id).await {
let payload = build_control_payload(RPC_CLOSE_CONN_U32, conn_id);
match w.tx.try_send(WriterCommand::ControlAndFlush(payload)) {
Ok(()) => {}
Err(TrySendError::Full(cmd)) => {
let _ = tokio::time::timeout(Duration::from_millis(50), w.tx.send(cmd)).await;
}
Err(TrySendError::Closed(_)) => {
debug!(conn_id, "ME close_conn skipped: writer channel closed");
Err(TrySendError::Full(cmd)) => match current.tx.send(cmd).await {
Ok(()) => {
self.note_hybrid_route_success();
return Ok(());
}
Err(send_err) => {
let Some(payload) = proxy_req_payload_from_command(send_err.0) else {
return Err(ProxyError::Proxy(
"ME writer rejected unexpected command type".into(),
));
};
warn!(writer_id = current.writer_id, "ME writer channel closed");
self.remove_writer_and_close_clients(current.writer_id)
.await;
return self
.send_proxy_req(
conn_id,
target_dc,
client_addr,
our_addr,
payload.as_ref(),
proto_flags,
tag.as_ref().map(|tag| tag.as_slice()),
)
.await;
}
},
Err(TrySendError::Closed(cmd)) => {
let Some(payload) = proxy_req_payload_from_command(cmd) else {
return Err(ProxyError::Proxy(
"ME writer rejected unexpected command type".into(),
));
};
warn!(writer_id = current.writer_id, "ME writer channel closed");
self.remove_writer_and_close_clients(current.writer_id)
.await;
return self
.send_proxy_req(
conn_id,
target_dc,
client_addr,
our_addr,
payload.as_ref(),
proto_flags,
tag.as_ref().map(|tag| tag.as_slice()),
)
.await;
}
}
} else {
debug!(conn_id, "ME close_conn skipped (writer missing)");
}
self.registry.unregister(conn_id).await;
Ok(())
}
pub async fn shutdown_send_close_conn_all(self: &Arc<Self>) -> usize {
let conn_ids = self.registry.active_conn_ids().await;
let total = conn_ids.len();
for conn_id in conn_ids {
let _ = self.send_close_conn(conn_id).await;
}
total
}
pub fn connection_count(&self) -> usize {
self.conn_count.load(Ordering::Relaxed)
}
pub(super) async fn candidate_indices_for_dc(
&self,
writers: &[super::pool::MeWriter],
routed_dc: i32,
include_warm: bool,
) -> Vec<usize> {
let preferred = self.preferred_endpoints_for_dc(routed_dc).await;
if preferred.is_empty() {
return Vec::new();
}
let mut out = Vec::new();
for (idx, w) in writers.iter().enumerate() {
if !self.writer_eligible_for_selection(w, include_warm) {
continue;
}
if w.writer_dc == routed_dc && preferred.contains(&w.addr) {
out.push(idx);
}
}
out
}
fn writer_eligible_for_selection(
&self,
writer: &super::pool::MeWriter,
include_warm: bool,
) -> bool {
if !self.writer_accepts_new_binding(writer) {
return false;
}
match WriterContour::from_u8(writer.contour.load(Ordering::Relaxed)) {
WriterContour::Active => true,
WriterContour::Warm => include_warm,
WriterContour::Draining => true,
}
}
fn writer_contour_rank_for_selection(&self, writer: &super::pool::MeWriter) -> usize {
match WriterContour::from_u8(writer.contour.load(Ordering::Relaxed)) {
WriterContour::Active => 0,
WriterContour::Warm => 1,
WriterContour::Draining => 2,
}
}
fn writer_idle_rank_for_selection(
&self,
writer: &super::pool::MeWriter,
idle_since_by_writer: &HashMap<u64, u64>,
now_epoch_secs: u64,
) -> usize {
let Some(idle_since) = idle_since_by_writer.get(&writer.id).copied() else {
return 0;
};
let idle_age_secs = now_epoch_secs.saturating_sub(idle_since);
if idle_age_secs >= IDLE_WRITER_PENALTY_HIGH_SECS {
2
} else if idle_age_secs >= IDLE_WRITER_PENALTY_MID_SECS {
1
} else {
0
}
}
fn writer_pick_score(
&self,
writer: &super::pool::MeWriter,
idle_since_by_writer: &HashMap<u64, u64>,
now_epoch_secs: u64,
) -> u64 {
let contour_penalty = match WriterContour::from_u8(writer.contour.load(Ordering::Relaxed)) {
WriterContour::Active => 0,
WriterContour::Warm => PICK_PENALTY_WARM,
WriterContour::Draining => PICK_PENALTY_DRAINING,
};
let stale_penalty = if writer.generation < self.current_generation() {
PICK_PENALTY_STALE
} else {
0
};
let degraded_penalty = if writer.degraded.load(Ordering::Relaxed) {
PICK_PENALTY_DEGRADED
} else {
0
};
let idle_penalty =
(self.writer_idle_rank_for_selection(writer, idle_since_by_writer, now_epoch_secs)
as u64)
* 100;
let queue_cap = self.writer_lifecycle.writer_cmd_channel_capacity.max(1) as u64;
let queue_remaining = writer.tx.capacity() as u64;
let queue_used = queue_cap.saturating_sub(queue_remaining.min(queue_cap));
let queue_util_pct = queue_used.saturating_mul(100) / queue_cap;
let queue_penalty = queue_util_pct.saturating_mul(4);
let rtt_penalty =
((writer.rtt_ema_ms_x10.load(Ordering::Relaxed) as u64).saturating_add(5) / 10)
.min(400);
contour_penalty
.saturating_add(stale_penalty)
.saturating_add(degraded_penalty)
.saturating_add(idle_penalty)
.saturating_add(queue_penalty)
.saturating_add(rtt_penalty)
}
fn p2c_ordered_candidate_indices(
&self,
candidate_indices: &[usize],
writers_snapshot: &[super::pool::MeWriter],
idle_since_by_writer: &HashMap<u64, u64>,
now_epoch_secs: u64,
start: usize,
sample_size: usize,
) -> Vec<usize> {
let total = candidate_indices.len();
if total == 0 {
return Vec::new();
}
let mut sampled = Vec::<usize>::with_capacity(sample_size.min(total));
let mut seen = HashSet::<usize>::with_capacity(total);
for offset in 0..sample_size.min(total) {
let idx = candidate_indices[(start + offset) % total];
if seen.insert(idx) {
sampled.push(idx);
}
}
sampled.sort_by_key(|idx| {
let writer = &writers_snapshot[*idx];
(
self.writer_pick_score(writer, idle_since_by_writer, now_epoch_secs),
writer.addr,
writer.id,
)
});
let mut ordered = Vec::<usize>::with_capacity(total);
ordered.extend(sampled.iter().copied());
for offset in 0..total {
let idx = candidate_indices[(start + offset) % total];
if seen.insert(idx) {
ordered.push(idx);
}
}
ordered
self.send_proxy_req(
conn_id,
target_dc,
client_addr,
our_addr,
payload.as_ref(),
proto_flags,
tag.as_ref().map(|tag| tag.as_slice()),
)
.await
}
}

View File

@@ -0,0 +1,66 @@
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use tokio::sync::mpsc::error::TrySendError;
use tracing::debug;
use crate::error::Result;
use crate::protocol::constants::{RPC_CLOSE_CONN_U32, RPC_CLOSE_EXT_U32};
use super::super::MePool;
use super::super::codec::{WriterCommand, build_control_payload};
impl MePool {
/// Sends an `RPC_CLOSE_EXT_U32` control frame for `conn_id` and unregisters it.
///
/// If the writer's channel rejects the frame, the writer is removed and its
/// clients are closed. A missing writer is only logged. The connection is
/// unregistered in every case and the call always returns `Ok(())`.
pub async fn send_close(self: &Arc<Self>, conn_id: u64) -> Result<()> {
    match self.registry.get_writer(conn_id).await {
        Some(writer) => {
            let frame = build_control_payload(RPC_CLOSE_EXT_U32, conn_id);
            let send_failed = writer
                .tx
                .send(WriterCommand::ControlAndFlush(frame))
                .await
                .is_err();
            if send_failed {
                debug!("ME close write failed");
                self.remove_writer_and_close_clients(writer.writer_id).await;
            }
        }
        None => {
            debug!(conn_id, "ME close skipped (writer missing)");
        }
    }
    self.registry.unregister(conn_id).await;
    Ok(())
}
/// Best-effort delivery of an `RPC_CLOSE_CONN_U32` control frame for `conn_id`.
///
/// Tries a non-blocking send first; when the writer queue is full it waits
/// at most 50 ms for space and ignores the outcome. A closed channel or a
/// missing writer is only logged. The connection is unregistered in every
/// case and the call always returns `Ok(())`.
pub async fn send_close_conn(self: &Arc<Self>, conn_id: u64) -> Result<()> {
    match self.registry.get_writer(conn_id).await {
        None => {
            debug!(conn_id, "ME close_conn skipped (writer missing)");
        }
        Some(writer) => {
            let frame = build_control_payload(RPC_CLOSE_CONN_U32, conn_id);
            if let Err(err) = writer.tx.try_send(WriterCommand::ControlAndFlush(frame)) {
                match err {
                    TrySendError::Closed(_) => {
                        debug!(conn_id, "ME close_conn skipped: writer channel closed");
                    }
                    TrySendError::Full(cmd) => {
                        // Deliberately best-effort: the timed send's result is dropped.
                        let grace = Duration::from_millis(50);
                        let _ = tokio::time::timeout(grace, writer.tx.send(cmd)).await;
                    }
                }
            }
        }
    }
    self.registry.unregister(conn_id).await;
    Ok(())
}
/// Sends a close-connection frame for every active connection.
///
/// Individual failures are ignored. Returns the number of connections
/// that were active when the call started.
pub async fn shutdown_send_close_conn_all(self: &Arc<Self>) -> usize {
    let active = self.registry.active_conn_ids().await;
    let count = active.len();
    for conn_id in active {
        let _ = self.send_close_conn(conn_id).await;
    }
    count
}
/// Returns the current value of the pool's connection counter (relaxed load).
pub fn connection_count(&self) -> usize {
    let counter = &self.conn_count;
    counter.load(Ordering::Relaxed)
}
}

View File

@@ -0,0 +1,218 @@
use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::{Duration, Instant};
use tracing::warn;
use crate::network::IpFamily;
use super::super::MePool;
use super::{
HYBRID_GLOBAL_BURST_PERIOD_ROUNDS, HYBRID_RECENT_SUCCESS_WINDOW_MS,
HYBRID_RECOVERY_TRIGGER_MIN_INTERVAL_MS, HYBRID_TIMEOUT_WARN_RATE_LIMIT_MS,
};
impl MePool {
/// Waits until the pool has at least one writer or `deadline` passes.
///
/// Subscribes to `writer_epoch` before the first check, returns `true`
/// immediately if writers already exist, and otherwise waits for at most
/// one epoch change within the remaining time before reporting whether
/// the writer list is non-empty.
pub(super) async fn wait_for_writer_until(&self, deadline: Instant) -> bool {
    let mut epoch_rx = self.writer_epoch.subscribe();
    if !self.writers.read().await.is_empty() {
        return true;
    }
    let now = Instant::now();
    if now < deadline {
        // The final check below is the same whether the wait completes or times out.
        let remaining = deadline.saturating_duration_since(now);
        let _ = tokio::time::timeout(remaining, epoch_rx.changed()).await;
    }
    !self.writers.read().await.is_empty()
}
/// Waits until a writer candidate for `routed_dc` exists or `deadline` passes.
///
/// Subscribes to `writer_epoch` before the first probe and re-probes after
/// every epoch change. Returns `true` as soon as
/// `has_candidate_for_target_dc` succeeds. When the deadline is already
/// reached right after a failed probe, returns `false`; when the wait times
/// out, returns the result of one final probe.
pub(super) async fn wait_for_candidate_until(&self, routed_dc: i32, deadline: Instant) -> bool {
    let mut epoch_rx = self.writer_epoch.subscribe();
    loop {
        // Probe exactly once per iteration: every probe takes the writers
        // read lock and clones the writer snapshot, so back-to-back repeats
        // only add lock traffic and allocation on the wait path.
        if self.has_candidate_for_target_dc(routed_dc).await {
            return true;
        }
        let remaining = deadline.saturating_duration_since(Instant::now());
        if remaining.is_zero() {
            // The probe above just failed and no time is left to wait.
            return false;
        }
        match tokio::time::timeout(remaining, epoch_rx.changed()).await {
            // The epoch changed: loop around and probe again.
            Ok(Ok(())) => {}
            // Timed out, or the watch channel is closed. A closed channel
            // makes `changed()` fail immediately on every call, so returning
            // here keeps the loop from spinning until the deadline.
            Ok(Err(_)) | Err(_) => {
                return self.has_candidate_for_target_dc(routed_dc).await;
            }
        }
    }
}
/// Reports whether any writer can currently serve `routed_dc`.
///
/// Works on a cloned snapshot of the writer list so the read lock is not
/// held during candidate lookup. Warm writers are considered only when the
/// lookup without them finds nothing.
pub(super) async fn has_candidate_for_target_dc(&self, routed_dc: i32) -> bool {
    let snapshot = {
        let guard = self.writers.read().await;
        if guard.is_empty() {
            return false;
        }
        guard.clone()
    };
    let strict = self
        .candidate_indices_for_dc(&snapshot, routed_dc, false)
        .await;
    if !strict.is_empty() {
        return true;
    }
    let with_warm = self
        .candidate_indices_for_dc(&snapshot, routed_dc, true)
        .await;
    !with_warm.is_empty()
}
/// Requests an immediate refill for up to eight endpoints of `routed_dc`.
///
/// Returns `false` (and counts nothing) when the DC has no endpoint
/// candidates; otherwise increments the recovery-trigger counter once and
/// returns `true`.
pub(super) async fn trigger_async_recovery_for_target_dc(
    self: &Arc<Self>,
    routed_dc: i32,
) -> bool {
    let endpoints = self.endpoint_candidates_for_target_dc(routed_dc).await;
    let has_endpoints = !endpoints.is_empty();
    if has_endpoints {
        self.stats.increment_me_async_recovery_trigger_total();
        endpoints.into_iter().take(8).for_each(|addr| {
            self.trigger_immediate_refill_for_dc(addr, routed_dc);
        });
    }
    has_endpoints
}
/// Requests an immediate refill for up to eight distinct `(dc, endpoint)` pairs.
///
/// Increments the recovery-trigger counter once, then walks the proxy maps
/// in `family_order()` and stops as soon as eight distinct pairs have been
/// seen.
pub(super) async fn trigger_async_recovery_global(self: &Arc<Self>) {
    self.stats.increment_me_async_recovery_trigger_total();
    let mut visited = HashSet::<(i32, SocketAddr)>::new();
    for family in self.family_order() {
        let proxy_map = match family {
            IpFamily::V4 => self.proxy_map_v4.read().await,
            IpFamily::V6 => self.proxy_map_v6.read().await,
        };
        for (dc, endpoints) in proxy_map.iter() {
            for (ip, port) in endpoints {
                let endpoint = SocketAddr::new(*ip, *port);
                let first_visit = visited.insert((*dc, endpoint));
                if first_visit {
                    self.trigger_immediate_refill_for_dc(endpoint, *dc);
                }
                if visited.len() >= 8 {
                    return;
                }
            }
        }
    }
}
/// Returns the endpoint candidates for `routed_dc`.
///
/// Currently a direct pass-through to `preferred_endpoints_for_dc`.
pub(super) async fn endpoint_candidates_for_target_dc(
    &self,
    routed_dc: i32,
) -> Vec<SocketAddr> {
    let endpoints = self.preferred_endpoints_for_dc(routed_dc).await;
    endpoints
}
/// Triggers writer recovery for a hybrid-route wait loop, subject to two limits.
///
/// * Per caller: no trigger while less than `hybrid_wait_step` has elapsed
///   since `hybrid_last_recovery_at`.
/// * Pool-wide: no trigger unless
///   `try_consume_hybrid_recovery_trigger_slot` grants a slot.
///
/// When both allow it, recovery for `routed_dc` is requested. Global
/// recovery is requested as well when the targeted request found no
/// endpoints, or when the round counter is a multiple of
/// `HYBRID_GLOBAL_BURST_PERIOD_ROUNDS`. The round counter and last-recovery
/// timestamp are then updated.
pub(super) async fn maybe_trigger_hybrid_recovery(
    self: &Arc<Self>,
    routed_dc: i32,
    hybrid_recovery_round: &mut u32,
    hybrid_last_recovery_at: &mut Option<Instant>,
    hybrid_wait_step: Duration,
) {
    // Check the caller-local cooldown first. Consuming the shared slot
    // stamps the pool-wide trigger timestamp, so doing it before this check
    // would let a caller that is still cooling down burn the slot without
    // triggering anything, delaying recovery for every other session.
    if let Some(last) = *hybrid_last_recovery_at
        && last.elapsed() < hybrid_wait_step
    {
        return;
    }
    if !self.try_consume_hybrid_recovery_trigger_slot(HYBRID_RECOVERY_TRIGGER_MIN_INTERVAL_MS) {
        return;
    }

    let round = *hybrid_recovery_round;
    let target_triggered = self.trigger_async_recovery_for_target_dc(routed_dc).await;
    if !target_triggered || round.is_multiple_of(HYBRID_GLOBAL_BURST_PERIOD_ROUNDS) {
        self.trigger_async_recovery_global().await;
    }
    *hybrid_recovery_round = round.saturating_add(1);
    *hybrid_last_recovery_at = Some(Instant::now());
}
/// Computes the total wait budget for hybrid routing.
///
/// The base is the configured `me_route_hybrid_max_wait`, floored at 50 ms.
/// It is doubled when a successful route was recorded within the last
/// `HYBRID_RECENT_SUCCESS_WINDOW_MS` milliseconds. A stored timestamp of
/// zero counts as "no success recorded".
pub(super) fn hybrid_total_wait_budget(&self) -> Duration {
    let floor = Duration::from_millis(50);
    let base = self.route_runtime.me_route_hybrid_max_wait.max(floor);
    let now_ms = Self::now_epoch_millis();
    let last_success_ms = self
        .route_runtime
        .me_route_last_success_epoch_ms
        .load(Ordering::Relaxed);
    let recently_succeeded = last_success_ms != 0
        && now_ms.saturating_sub(last_success_ms) <= HYBRID_RECENT_SUCCESS_WINDOW_MS;
    if recently_succeeded {
        base.saturating_mul(2)
    } else {
        base
    }
}
/// Records the current epoch milliseconds as the time of the last successful route.
pub(super) fn note_hybrid_route_success(&self) {
    let now_ms = Self::now_epoch_millis();
    self.route_runtime
        .me_route_last_success_epoch_ms
        .store(now_ms, Ordering::Relaxed);
}
/// Counts a hybrid-route timeout and emits a rate-limited warning.
///
/// The counter is always incremented. The warning is logged only by the
/// caller that wins the compare-exchange on the last-warning timestamp, and
/// only when at least `HYBRID_TIMEOUT_WARN_RATE_LIMIT_MS` have passed since
/// the previous warning. The logged `elapsed_ms` is the time elapsed since
/// `deadline`.
pub(super) fn on_hybrid_timeout(&self, deadline: Instant, routed_dc: i32) {
    self.stats.increment_me_hybrid_timeout_total();
    let now_ms = Self::now_epoch_millis();
    let last_warn = &self.route_runtime.me_route_hybrid_timeout_warn_epoch_ms;
    let mut observed_ms = last_warn.load(Ordering::Relaxed);
    loop {
        if now_ms.saturating_sub(observed_ms) < HYBRID_TIMEOUT_WARN_RATE_LIMIT_MS {
            // Someone warned recently enough; stay quiet.
            return;
        }
        match last_warn.compare_exchange_weak(
            observed_ms,
            now_ms,
            Ordering::AcqRel,
            Ordering::Relaxed,
        ) {
            Ok(_) => break,
            // Lost the race or failed spuriously: retry with the fresh value.
            Err(current_ms) => observed_ms = current_ms,
        }
    }
    warn!(
        routed_dc,
        budget_ms = self.hybrid_total_wait_budget().as_millis() as u64,
        elapsed_ms = deadline.elapsed().as_millis() as u64,
        "ME hybrid route timeout reached"
    );
}
/// Tries to claim the pool-wide recovery trigger slot.
///
/// Succeeds (returns `true`) when at least `min_interval_ms` have passed
/// since the stored last-trigger timestamp, atomically replacing that
/// timestamp with the current epoch milliseconds. Returns `false` without
/// modifying the timestamp otherwise.
pub(super) fn try_consume_hybrid_recovery_trigger_slot(&self, min_interval_ms: u64) -> bool {
    let now_ms = Self::now_epoch_millis();
    self.route_runtime
        .me_async_recovery_last_trigger_epoch_ms
        .fetch_update(Ordering::AcqRel, Ordering::Relaxed, |last_trigger_ms| {
            let interval_elapsed = now_ms.saturating_sub(last_trigger_ms) >= min_interval_ms;
            interval_elapsed.then_some(now_ms)
        })
        .is_ok()
}
}

View File

@@ -0,0 +1,165 @@
use std::collections::{HashMap, HashSet};
use std::sync::atomic::Ordering;
use super::super::MePool;
use super::super::pool::WriterContour;
use super::{
IDLE_WRITER_PENALTY_HIGH_SECS, IDLE_WRITER_PENALTY_MID_SECS, PICK_PENALTY_DEGRADED,
PICK_PENALTY_DRAINING, PICK_PENALTY_STALE, PICK_PENALTY_WARM,
};
impl MePool {
/// Returns the indices into `writers` that can serve `routed_dc`.
///
/// A writer qualifies when it is eligible for selection (warm writers only
/// if `include_warm`), its `writer_dc` equals `routed_dc`, and its address
/// is one of the preferred endpoints for that DC. Yields an empty list when
/// the DC has no preferred endpoints.
pub(super) async fn candidate_indices_for_dc(
    &self,
    writers: &[super::super::pool::MeWriter],
    routed_dc: i32,
    include_warm: bool,
) -> Vec<usize> {
    let preferred = self.preferred_endpoints_for_dc(routed_dc).await;
    if preferred.is_empty() {
        return Vec::new();
    }
    writers
        .iter()
        .enumerate()
        .filter(|(_, writer)| self.writer_eligible_for_selection(writer, include_warm))
        .filter(|(_, writer)| writer.writer_dc == routed_dc && preferred.contains(&writer.addr))
        .map(|(idx, _)| idx)
        .collect()
}
/// Decides whether `writer` may be considered during writer selection.
///
/// Writers that do not accept new bindings are never eligible. Among the
/// rest, `Active` and `Draining` contours are eligible and `Warm` only when
/// `include_warm` is set.
pub(super) fn writer_eligible_for_selection(
    &self,
    writer: &super::super::pool::MeWriter,
    include_warm: bool,
) -> bool {
    if !self.writer_accepts_new_binding(writer) {
        return false;
    }
    let contour = WriterContour::from_u8(writer.contour.load(Ordering::Relaxed));
    match contour {
        WriterContour::Warm => include_warm,
        WriterContour::Active | WriterContour::Draining => true,
    }
}
/// Maps the writer's contour to a selection rank: `Active` = 0, `Warm` = 1, `Draining` = 2.
pub(super) fn writer_contour_rank_for_selection(
    &self,
    writer: &super::super::pool::MeWriter,
) -> usize {
    let raw_contour = writer.contour.load(Ordering::Relaxed);
    match WriterContour::from_u8(raw_contour) {
        WriterContour::Active => 0,
        WriterContour::Warm => 1,
        WriterContour::Draining => 2,
    }
}
/// Ranks a writer by how long it has been idle.
///
/// Returns 0 when the writer has no idle-since entry or has been idle for
/// less than `IDLE_WRITER_PENALTY_MID_SECS`, 1 from that threshold up, and
/// 2 once idle for at least `IDLE_WRITER_PENALTY_HIGH_SECS`.
pub(super) fn writer_idle_rank_for_selection(
    &self,
    writer: &super::super::pool::MeWriter,
    idle_since_by_writer: &HashMap<u64, u64>,
    now_epoch_secs: u64,
) -> usize {
    let idle_age_secs = match idle_since_by_writer.get(&writer.id) {
        Some(idle_since) => now_epoch_secs.saturating_sub(*idle_since),
        None => return 0,
    };
    match idle_age_secs {
        age if age >= IDLE_WRITER_PENALTY_HIGH_SECS => 2,
        age if age >= IDLE_WRITER_PENALTY_MID_SECS => 1,
        _ => 0,
    }
}
/// Computes the selection penalty score for `writer`; lower is better.
///
/// The score is the saturating sum of six penalties: contour (warm or
/// draining), stale generation, degraded flag, idle rank × 100, command
/// queue utilisation percent × 4, and the RTT EMA rounded to whole
/// milliseconds and capped at 400.
pub(super) fn writer_pick_score(
    &self,
    writer: &super::super::pool::MeWriter,
    idle_since_by_writer: &HashMap<u64, u64>,
    now_epoch_secs: u64,
) -> u64 {
    let contour = WriterContour::from_u8(writer.contour.load(Ordering::Relaxed));
    let contour_cost = match contour {
        WriterContour::Active => 0,
        WriterContour::Warm => PICK_PENALTY_WARM,
        WriterContour::Draining => PICK_PENALTY_DRAINING,
    };

    let is_stale = writer.generation < self.current_generation();
    let stale_cost = if is_stale { PICK_PENALTY_STALE } else { 0 };

    let is_degraded = writer.degraded.load(Ordering::Relaxed);
    let degraded_cost = if is_degraded { PICK_PENALTY_DEGRADED } else { 0 };

    let idle_rank =
        self.writer_idle_rank_for_selection(writer, idle_since_by_writer, now_epoch_secs);
    let idle_cost = (idle_rank as u64) * 100;

    // Queue utilisation in percent, derived from the channel's free capacity.
    let capacity = self.writer_lifecycle.writer_cmd_channel_capacity.max(1) as u64;
    let free = (writer.tx.capacity() as u64).min(capacity);
    let used_pct = capacity.saturating_sub(free).saturating_mul(100) / capacity;
    let queue_cost = used_pct.saturating_mul(4);

    // The EMA is stored in tenths of a millisecond; round to whole ms.
    let rtt_x10 = writer.rtt_ema_ms_x10.load(Ordering::Relaxed) as u64;
    let rtt_cost = (rtt_x10.saturating_add(5) / 10).min(400);

    [
        contour_cost,
        stale_cost,
        degraded_cost,
        idle_cost,
        queue_cost,
        rtt_cost,
    ]
    .iter()
    .fold(0u64, |total, cost| total.saturating_add(*cost))
}
/// Produces the order in which candidate writers should be tried.
///
/// `candidate_indices` is walked as a ring beginning at `start`. The first
/// `sample_size` distinct entries form the sample, which is sorted ascending by
/// `(writer_pick_score, addr, id)`. The remaining distinct candidates follow in ring
/// order. Duplicate indices are emitted only once.
///
/// Returns an empty vector when there are no candidates.
///
/// # Panics
///
/// Panics if a sampled candidate index is out of bounds for `writers_snapshot`.
pub(super) fn p2c_ordered_candidate_indices(
    &self,
    candidate_indices: &[usize],
    writers_snapshot: &[super::super::pool::MeWriter],
    idle_since_by_writer: &HashMap<u64, u64>,
    now_epoch_secs: u64,
    start: usize,
    sample_size: usize,
) -> Vec<usize> {
    let total = candidate_indices.len();
    if total == 0 {
        return Vec::new();
    }

    // Reduce `start` up front so `start + offset` cannot overflow for very large cursors;
    // the resulting ring position is unchanged.
    let start = start % total;
    let ring_at = |offset: usize| candidate_indices[(start + offset) % total];

    let sample_len = sample_size.min(total);
    let mut seen = HashSet::<usize>::with_capacity(total);
    let mut ordered = Vec::<usize>::with_capacity(total);

    for offset in 0..sample_len {
        let idx = ring_at(offset);
        if seen.insert(idx) {
            ordered.push(idx);
        }
    }

    // The score reads live atomics and channel capacity, so it may change between calls.
    // Caching evaluates each key exactly once, keeping the ordering self-consistent for the
    // duration of the sort. At this point `ordered` holds only the sample.
    ordered.sort_by_cached_key(|idx| {
        let writer = &writers_snapshot[*idx];
        (
            self.writer_pick_score(writer, idle_since_by_writer, now_epoch_secs),
            writer.addr,
            writer.id,
        )
    });

    // Append every candidate not already sampled, preserving ring order.
    for offset in 0..total {
        let idx = ring_at(offset);
        if seen.insert(idx) {
            ordered.push(idx);
        }
    }
    ordered
}
}

View File

@@ -165,6 +165,7 @@ async fn recv_data_count(rx: &mut mpsc::Receiver<WriterCommand>, budget: Duratio
match tokio::time::timeout(remaining.min(Duration::from_millis(10)), rx.recv()).await {
Ok(Some(WriterCommand::Data(_))) => data_count += 1,
Ok(Some(WriterCommand::DataAndFlush(_))) => data_count += 1,
Ok(Some(WriterCommand::ProxyReq(_))) => data_count += 1,
Ok(Some(WriterCommand::ControlAndFlush(_))) => data_count += 1,
Ok(Some(WriterCommand::Close)) => {}
Ok(None) => break,

View File

@@ -42,22 +42,45 @@ fn append_mapped_addr_and_port(buf: &mut Vec<u8>, addr: SocketAddr) {
buf.extend_from_slice(&(addr.port() as u32).to_le_bytes());
}
pub(crate) fn build_proxy_req_payload(
/// Returns the number of bytes `tag` occupies on the wire.
///
/// The result is a fixed 4-byte prefix, then a length header (1 byte for tags shorter than
/// 254 bytes, 4 bytes otherwise), then the tag bytes, with the header-plus-tag part padded
/// up to a multiple of 4.
fn proxy_tag_wire_len(tag: &[u8]) -> usize {
    // NOTE(review): the leading 4 bytes are presumably the tag's constructor id — confirm
    // against the encoder.
    const PREFIX_LEN: usize = 4;

    let header_len = if tag.len() < 254 { 1 } else { 4 };
    let unpadded = header_len + tag.len();
    let padding = (4 - unpadded % 4) % 4;
    PREFIX_LEN + unpadded + padding
}
/// Computes the exact unencrypted RPC_PROXY_REQ payload length, for pre-sizing frame buffers.
///
/// The length is a fixed 56-byte header plus `data_len`. When `proto_flags` has
/// `RPC_FLAG_HAS_AD_TAG` set, an extra section is added: 4 bytes, plus the wire length of
/// `proxy_tag` when one is supplied.
pub(crate) fn proxy_req_payload_len(
    data_len: usize,
    proxy_tag: Option<&[u8]>,
    proto_flags: u32,
) -> usize {
    // NOTE(review): presumably 4 (type) + 4 (flags) + 8 (conn id) + 2 × 20 (address and port)
    // — confirm against the encoder.
    const FIXED_HEADER_LEN: usize = 4 + 4 + 8 + 20 + 20;

    let has_ad_tag = proto_flags & RPC_FLAG_HAS_AD_TAG != 0;
    if !has_ad_tag {
        return FIXED_HEADER_LEN + data_len;
    }

    // The extra section always carries 4 bytes, even when no tag is supplied.
    let extra_len = 4 + proxy_tag.map_or(0, proxy_tag_wire_len);
    FIXED_HEADER_LEN + extra_len + data_len
}
/// Appends RPC_PROXY_REQ payload bytes without allocating an intermediate payload buffer.
pub(crate) fn append_proxy_req_payload_into(
b: &mut Vec<u8>,
conn_id: u64,
client_addr: SocketAddr,
our_addr: SocketAddr,
data: &[u8],
proxy_tag: Option<&[u8]>,
proto_flags: u32,
) -> Bytes {
let mut b = Vec::with_capacity(128 + data.len());
) {
b.extend_from_slice(&RPC_PROXY_REQ_U32.to_le_bytes());
b.extend_from_slice(&proto_flags.to_le_bytes());
b.extend_from_slice(&conn_id.to_le_bytes());
append_mapped_addr_and_port(&mut b, client_addr);
append_mapped_addr_and_port(&mut b, our_addr);
append_mapped_addr_and_port(b, client_addr);
append_mapped_addr_and_port(b, our_addr);
if proto_flags & RPC_FLAG_HAS_AD_TAG != 0 {
let extra_start = b.len();
@@ -86,6 +109,26 @@ pub(crate) fn build_proxy_req_payload(
}
b.extend_from_slice(data);
}
pub(crate) fn build_proxy_req_payload(
conn_id: u64,
client_addr: SocketAddr,
our_addr: SocketAddr,
data: &[u8],
proxy_tag: Option<&[u8]>,
proto_flags: u32,
) -> Bytes {
let mut b = Vec::with_capacity(128 + data.len());
append_proxy_req_payload_into(
&mut b,
conn_id,
client_addr,
our_addr,
data,
proxy_tag,
proto_flags,
);
Bytes::from(b)
}