Merge branch 'main' into feat/shadowsocks-upstream

This commit is contained in:
Alexey
2026-03-19 17:19:01 +03:00
committed by GitHub
23 changed files with 1761 additions and 449 deletions

View File

@@ -36,12 +36,16 @@ const DEFAULT_ME_HEALTH_INTERVAL_MS_UNHEALTHY: u64 = 1000;
const DEFAULT_ME_HEALTH_INTERVAL_MS_HEALTHY: u64 = 3000;
const DEFAULT_ME_ADMISSION_POLL_MS: u64 = 1000;
const DEFAULT_ME_WARN_RATE_LIMIT_MS: u64 = 5000;
const DEFAULT_ME_ROUTE_HYBRID_MAX_WAIT_MS: u64 = 3000;
const DEFAULT_ME_ROUTE_BLOCKING_SEND_TIMEOUT_MS: u64 = 250;
const DEFAULT_ME_C2ME_SEND_TIMEOUT_MS: u64 = 4000;
const DEFAULT_ME_POOL_DRAIN_SOFT_EVICT_ENABLED: bool = true;
const DEFAULT_ME_POOL_DRAIN_SOFT_EVICT_GRACE_SECS: u64 = 30;
const DEFAULT_ME_POOL_DRAIN_SOFT_EVICT_PER_WRITER: u8 = 1;
const DEFAULT_ME_POOL_DRAIN_SOFT_EVICT_BUDGET_PER_CORE: u16 = 8;
const DEFAULT_ME_POOL_DRAIN_SOFT_EVICT_COOLDOWN_MS: u64 = 5000;
const DEFAULT_USER_MAX_UNIQUE_IPS_WINDOW_SECS: u64 = 30;
const DEFAULT_ACCEPT_PERMIT_TIMEOUT_MS: u64 = 250;
const DEFAULT_UPSTREAM_CONNECT_RETRY_ATTEMPTS: u32 = 2;
const DEFAULT_UPSTREAM_UNHEALTHY_FAIL_THRESHOLD: u32 = 5;
const DEFAULT_UPSTREAM_CONNECT_BUDGET_MS: u64 = 3000;
@@ -156,6 +160,10 @@ pub(crate) fn default_server_max_connections() -> u32 {
10_000
}
pub(crate) fn default_accept_permit_timeout_ms() -> u64 {
DEFAULT_ACCEPT_PERMIT_TIMEOUT_MS
}
pub(crate) fn default_prefer_4() -> u8 {
4
}
@@ -380,6 +388,18 @@ pub(crate) fn default_me_warn_rate_limit_ms() -> u64 {
DEFAULT_ME_WARN_RATE_LIMIT_MS
}
pub(crate) fn default_me_route_hybrid_max_wait_ms() -> u64 {
DEFAULT_ME_ROUTE_HYBRID_MAX_WAIT_MS
}
pub(crate) fn default_me_route_blocking_send_timeout_ms() -> u64 {
DEFAULT_ME_ROUTE_BLOCKING_SEND_TIMEOUT_MS
}
pub(crate) fn default_me_c2me_send_timeout_ms() -> u64 {
DEFAULT_ME_C2ME_SEND_TIMEOUT_MS
}
pub(crate) fn default_upstream_connect_retry_attempts() -> u32 {
DEFAULT_UPSTREAM_CONNECT_RETRY_ATTEMPTS
}

View File

@@ -39,6 +39,7 @@ use super::load::{LoadedConfig, ProxyConfig};
const HOT_RELOAD_STABLE_SNAPSHOTS: u8 = 2;
const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50);
const HOT_RELOAD_STABLE_RECHECK: Duration = Duration::from_millis(75);
// ── Hot fields ────────────────────────────────────────────────────────────────
@@ -379,6 +380,14 @@ impl ReloadState {
self.applied_snapshot_hash = Some(hash);
self.reset_candidate();
}
fn pending_candidate(&self) -> Option<(u64, u8)> {
let hash = self.candidate_snapshot_hash?;
if self.candidate_hits < HOT_RELOAD_STABLE_SNAPSHOTS {
return Some((hash, self.candidate_hits));
}
None
}
}
fn normalize_watch_path(path: &Path) -> PathBuf {
@@ -603,6 +612,8 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
|| old.server.listen_tcp != new.server.listen_tcp
|| old.server.listen_unix_sock != new.server.listen_unix_sock
|| old.server.listen_unix_sock_perm != new.server.listen_unix_sock_perm
|| old.server.max_connections != new.server.max_connections
|| old.server.accept_permit_timeout_ms != new.server.accept_permit_timeout_ms
{
warned = true;
warn!("config reload: server listener settings changed; restart required");
@@ -662,6 +673,9 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
}
if old.general.me_route_no_writer_mode != new.general.me_route_no_writer_mode
|| old.general.me_route_no_writer_wait_ms != new.general.me_route_no_writer_wait_ms
|| old.general.me_route_hybrid_max_wait_ms != new.general.me_route_hybrid_max_wait_ms
|| old.general.me_route_blocking_send_timeout_ms
!= new.general.me_route_blocking_send_timeout_ms
|| old.general.me_route_inline_recovery_attempts
!= new.general.me_route_inline_recovery_attempts
|| old.general.me_route_inline_recovery_wait_ms
@@ -670,6 +684,10 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
warned = true;
warn!("config reload: general.me_route_no_writer_* changed; restart required");
}
if old.general.me_c2me_send_timeout_ms != new.general.me_c2me_send_timeout_ms {
warned = true;
warn!("config reload: general.me_c2me_send_timeout_ms changed; restart required");
}
if old.general.unknown_dc_log_path != new.general.unknown_dc_log_path
|| old.general.unknown_dc_file_log_enabled != new.general.unknown_dc_file_log_enabled
{
@@ -1253,6 +1271,73 @@ fn reload_config(
Some(next_manifest)
}
async fn reload_with_internal_stable_rechecks(
config_path: &PathBuf,
config_tx: &watch::Sender<Arc<ProxyConfig>>,
log_tx: &watch::Sender<LogLevel>,
detected_ip_v4: Option<IpAddr>,
detected_ip_v6: Option<IpAddr>,
reload_state: &mut ReloadState,
) -> Option<WatchManifest> {
let mut next_manifest = reload_config(
config_path,
config_tx,
log_tx,
detected_ip_v4,
detected_ip_v6,
reload_state,
);
let mut rechecks_left = HOT_RELOAD_STABLE_SNAPSHOTS.saturating_sub(1);
while rechecks_left > 0 {
let Some((snapshot_hash, candidate_hits)) = reload_state.pending_candidate() else {
break;
};
info!(
snapshot_hash,
candidate_hits,
required_hits = HOT_RELOAD_STABLE_SNAPSHOTS,
rechecks_left,
recheck_delay_ms = HOT_RELOAD_STABLE_RECHECK.as_millis(),
"config reload: scheduling internal stable recheck"
);
tokio::time::sleep(HOT_RELOAD_STABLE_RECHECK).await;
let recheck_manifest = reload_config(
config_path,
config_tx,
log_tx,
detected_ip_v4,
detected_ip_v6,
reload_state,
);
if recheck_manifest.is_some() {
next_manifest = recheck_manifest;
}
if reload_state.is_applied(snapshot_hash) {
info!(
snapshot_hash,
"config reload: applied after internal stable recheck"
);
break;
}
if reload_state.pending_candidate().is_none() {
info!(
snapshot_hash,
"config reload: internal stable recheck aborted"
);
break;
}
rechecks_left = rechecks_left.saturating_sub(1);
}
next_manifest
}
// ── Public API ────────────────────────────────────────────────────────────────
/// Spawn the hot-reload watcher task.
@@ -1376,14 +1461,16 @@ pub fn spawn_config_watcher(
tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await;
while notify_rx.try_recv().is_ok() {}
if let Some(next_manifest) = reload_config(
if let Some(next_manifest) = reload_with_internal_stable_rechecks(
&config_path,
&config_tx,
&log_tx,
detected_ip_v4,
detected_ip_v6,
&mut reload_state,
) {
)
.await
{
apply_watch_manifest(
inotify_watcher.as_mut(),
poll_watcher.as_mut(),
@@ -1540,6 +1627,35 @@ mod tests {
let _ = std::fs::remove_file(path);
}
#[tokio::test]
async fn reload_cycle_applies_after_single_external_event() {
let initial_tag = "10101010101010101010101010101010";
let final_tag = "20202020202020202020202020202020";
let path = temp_config_path("telemt_hot_reload_single_event");
write_reload_config(&path, Some(initial_tag), None);
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash;
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
let mut reload_state = ReloadState::new(Some(initial_hash));
write_reload_config(&path, Some(final_tag), None);
reload_with_internal_stable_rechecks(
&path,
&config_tx,
&log_tx,
None,
None,
&mut reload_state,
)
.await
.unwrap();
assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag));
let _ = std::fs::remove_file(path);
}
#[test]
fn reload_keeps_hot_apply_when_non_hot_fields_change() {
let initial_tag = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";

View File

@@ -378,6 +378,12 @@ impl ProxyConfig {
));
}
if config.general.me_c2me_send_timeout_ms > 60_000 {
return Err(ProxyError::Config(
"general.me_c2me_send_timeout_ms must be within [0, 60000]".to_string(),
));
}
if config.general.me_reader_route_data_wait_ms > 20 {
return Err(ProxyError::Config(
"general.me_reader_route_data_wait_ms must be within [0, 20]".to_string(),
@@ -640,6 +646,11 @@ impl ProxyConfig {
"general.me_route_backpressure_base_timeout_ms must be > 0".to_string(),
));
}
if config.general.me_route_backpressure_base_timeout_ms > 5000 {
return Err(ProxyError::Config(
"general.me_route_backpressure_base_timeout_ms must be within [1, 5000]".to_string(),
));
}
if config.general.me_route_backpressure_high_timeout_ms
< config.general.me_route_backpressure_base_timeout_ms
@@ -648,6 +659,11 @@ impl ProxyConfig {
"general.me_route_backpressure_high_timeout_ms must be >= general.me_route_backpressure_base_timeout_ms".to_string(),
));
}
if config.general.me_route_backpressure_high_timeout_ms > 5000 {
return Err(ProxyError::Config(
"general.me_route_backpressure_high_timeout_ms must be within [1, 5000]".to_string(),
));
}
if !(1..=100).contains(&config.general.me_route_backpressure_high_watermark_pct) {
return Err(ProxyError::Config(
@@ -662,6 +678,18 @@ impl ProxyConfig {
));
}
if !(50..=60_000).contains(&config.general.me_route_hybrid_max_wait_ms) {
return Err(ProxyError::Config(
"general.me_route_hybrid_max_wait_ms must be within [50, 60000]".to_string(),
));
}
if config.general.me_route_blocking_send_timeout_ms > 5000 {
return Err(ProxyError::Config(
"general.me_route_blocking_send_timeout_ms must be within [0, 5000]".to_string(),
));
}
if !(2..=4).contains(&config.general.me_writer_pick_sample_size) {
return Err(ProxyError::Config(
"general.me_writer_pick_sample_size must be within [2, 4]".to_string(),
@@ -722,6 +750,12 @@ impl ProxyConfig {
));
}
if config.server.accept_permit_timeout_ms > 60_000 {
return Err(ProxyError::Config(
"server.accept_permit_timeout_ms must be within [0, 60000]".to_string(),
));
}
if config.general.effective_me_pool_force_close_secs() > 0
&& config.general.effective_me_pool_force_close_secs()
< config.general.me_pool_drain_ttl_secs
@@ -1644,6 +1678,47 @@ mod tests {
let _ = std::fs::remove_file(path_valid);
}
#[test]
fn me_route_backpressure_base_timeout_ms_out_of_range_is_rejected() {
let toml = r#"
[general]
me_route_backpressure_base_timeout_ms = 5001
[censorship]
tls_domain = "example.com"
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_me_route_backpressure_base_timeout_ms_out_of_range_test.toml");
std::fs::write(&path, toml).unwrap();
let err = ProxyConfig::load(&path).unwrap_err().to_string();
assert!(err.contains("general.me_route_backpressure_base_timeout_ms must be within [1, 5000]"));
let _ = std::fs::remove_file(path);
}
#[test]
fn me_route_backpressure_high_timeout_ms_out_of_range_is_rejected() {
let toml = r#"
[general]
me_route_backpressure_base_timeout_ms = 100
me_route_backpressure_high_timeout_ms = 5001
[censorship]
tls_domain = "example.com"
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_me_route_backpressure_high_timeout_ms_out_of_range_test.toml");
std::fs::write(&path, toml).unwrap();
let err = ProxyConfig::load(&path).unwrap_err().to_string();
assert!(err.contains("general.me_route_backpressure_high_timeout_ms must be within [1, 5000]"));
let _ = std::fs::remove_file(path);
}
#[test]
fn me_route_no_writer_wait_ms_out_of_range_is_rejected() {
let toml = r#"

View File

@@ -462,6 +462,11 @@ pub struct GeneralConfig {
#[serde(default = "default_me_c2me_channel_capacity")]
pub me_c2me_channel_capacity: usize,
/// Maximum wait in milliseconds for enqueueing C2ME commands when the queue is full.
/// `0` keeps legacy unbounded wait behavior.
#[serde(default = "default_me_c2me_send_timeout_ms")]
pub me_c2me_send_timeout_ms: u64,
/// Bounded wait in milliseconds for routing ME DATA to per-connection queue.
/// `0` keeps legacy no-wait behavior.
#[serde(default = "default_me_reader_route_data_wait_ms")]
@@ -716,6 +721,15 @@ pub struct GeneralConfig {
#[serde(default = "default_me_route_no_writer_wait_ms")]
pub me_route_no_writer_wait_ms: u64,
/// Maximum cumulative wait in milliseconds for hybrid no-writer mode before failfast.
#[serde(default = "default_me_route_hybrid_max_wait_ms")]
pub me_route_hybrid_max_wait_ms: u64,
/// Maximum wait in milliseconds for blocking ME writer channel send fallback.
/// `0` keeps legacy unbounded wait behavior.
#[serde(default = "default_me_route_blocking_send_timeout_ms")]
pub me_route_blocking_send_timeout_ms: u64,
/// Number of inline recovery attempts in legacy mode.
#[serde(default = "default_me_route_inline_recovery_attempts")]
pub me_route_inline_recovery_attempts: u32,
@@ -921,6 +935,7 @@ impl Default for GeneralConfig {
me_writer_cmd_channel_capacity: default_me_writer_cmd_channel_capacity(),
me_route_channel_capacity: default_me_route_channel_capacity(),
me_c2me_channel_capacity: default_me_c2me_channel_capacity(),
me_c2me_send_timeout_ms: default_me_c2me_send_timeout_ms(),
me_reader_route_data_wait_ms: default_me_reader_route_data_wait_ms(),
me_d2c_flush_batch_max_frames: default_me_d2c_flush_batch_max_frames(),
me_d2c_flush_batch_max_bytes: default_me_d2c_flush_batch_max_bytes(),
@@ -990,6 +1005,8 @@ impl Default for GeneralConfig {
me_warn_rate_limit_ms: default_me_warn_rate_limit_ms(),
me_route_no_writer_mode: MeRouteNoWriterMode::default(),
me_route_no_writer_wait_ms: default_me_route_no_writer_wait_ms(),
me_route_hybrid_max_wait_ms: default_me_route_hybrid_max_wait_ms(),
me_route_blocking_send_timeout_ms: default_me_route_blocking_send_timeout_ms(),
me_route_inline_recovery_attempts: default_me_route_inline_recovery_attempts(),
me_route_inline_recovery_wait_ms: default_me_route_inline_recovery_wait_ms(),
links: LinksConfig::default(),
@@ -1225,6 +1242,11 @@ pub struct ServerConfig {
/// 0 means unlimited.
#[serde(default = "default_server_max_connections")]
pub max_connections: u32,
/// Maximum wait in milliseconds while acquiring a connection slot permit.
/// `0` keeps legacy unbounded wait behavior.
#[serde(default = "default_accept_permit_timeout_ms")]
pub accept_permit_timeout_ms: u64,
}
impl Default for ServerConfig {
@@ -1244,6 +1266,7 @@ impl Default for ServerConfig {
api: ApiConfig::default(),
listeners: Vec::new(),
max_connections: default_server_max_connections(),
accept_permit_timeout_ms: default_accept_permit_timeout_ms(),
}
}
}

View File

@@ -205,6 +205,7 @@ pub(crate) fn format_uptime(total_secs: u64) -> String {
format!("{} / {} seconds", parts.join(", "), total_secs)
}
#[allow(dead_code)]
pub(crate) async fn wait_until_admission_open(admission_rx: &mut watch::Receiver<bool>) -> bool {
loop {
if *admission_rx.borrow() {

View File

@@ -24,7 +24,7 @@ use crate::transport::{
ListenOptions, UpstreamManager, create_listener, find_listener_processes,
};
use super::helpers::{is_expected_handshake_eof, print_proxy_links, wait_until_admission_open};
use super::helpers::{is_expected_handshake_eof, print_proxy_links};
pub(crate) struct BoundListeners {
pub(crate) listeners: Vec<(TcpListener, bool)>,
@@ -195,7 +195,7 @@ pub(crate) async fn bind_listeners(
has_unix_listener = true;
let mut config_rx_unix: watch::Receiver<Arc<ProxyConfig>> = config_rx.clone();
let mut admission_rx_unix = admission_rx.clone();
let admission_rx_unix = admission_rx.clone();
let stats = stats.clone();
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
@@ -212,17 +212,44 @@ pub(crate) async fn bind_listeners(
let unix_conn_counter = Arc::new(std::sync::atomic::AtomicU64::new(1));
loop {
if !wait_until_admission_open(&mut admission_rx_unix).await {
warn!("Conditional-admission gate channel closed for unix listener");
break;
}
match unix_listener.accept().await {
Ok((stream, _)) => {
let permit = match max_connections_unix.clone().acquire_owned().await {
Ok(permit) => permit,
Err(_) => {
error!("Connection limiter is closed");
break;
if !*admission_rx_unix.borrow() {
drop(stream);
continue;
}
let accept_permit_timeout_ms = config_rx_unix
.borrow()
.server
.accept_permit_timeout_ms;
let permit = if accept_permit_timeout_ms == 0 {
match max_connections_unix.clone().acquire_owned().await {
Ok(permit) => permit,
Err(_) => {
error!("Connection limiter is closed");
break;
}
}
} else {
match tokio::time::timeout(
Duration::from_millis(accept_permit_timeout_ms),
max_connections_unix.clone().acquire_owned(),
)
.await
{
Ok(Ok(permit)) => permit,
Ok(Err(_)) => {
error!("Connection limiter is closed");
break;
}
Err(_) => {
debug!(
timeout_ms = accept_permit_timeout_ms,
"Dropping accepted unix connection: permit wait timeout"
);
drop(stream);
continue;
}
}
};
let conn_id =
@@ -312,7 +339,7 @@ pub(crate) fn spawn_tcp_accept_loops(
) {
for (listener, listener_proxy_protocol) in listeners {
let mut config_rx: watch::Receiver<Arc<ProxyConfig>> = config_rx.clone();
let mut admission_rx_tcp = admission_rx.clone();
let admission_rx_tcp = admission_rx.clone();
let stats = stats.clone();
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
@@ -327,17 +354,46 @@ pub(crate) fn spawn_tcp_accept_loops(
tokio::spawn(async move {
loop {
if !wait_until_admission_open(&mut admission_rx_tcp).await {
warn!("Conditional-admission gate channel closed for tcp listener");
break;
}
match listener.accept().await {
Ok((stream, peer_addr)) => {
let permit = match max_connections_tcp.clone().acquire_owned().await {
Ok(permit) => permit,
Err(_) => {
error!("Connection limiter is closed");
break;
if !*admission_rx_tcp.borrow() {
debug!(peer = %peer_addr, "Admission gate closed, dropping connection");
drop(stream);
continue;
}
let accept_permit_timeout_ms = config_rx
.borrow()
.server
.accept_permit_timeout_ms;
let permit = if accept_permit_timeout_ms == 0 {
match max_connections_tcp.clone().acquire_owned().await {
Ok(permit) => permit,
Err(_) => {
error!("Connection limiter is closed");
break;
}
}
} else {
match tokio::time::timeout(
Duration::from_millis(accept_permit_timeout_ms),
max_connections_tcp.clone().acquire_owned(),
)
.await
{
Ok(Ok(permit)) => permit,
Ok(Err(_)) => {
error!("Connection limiter is closed");
break;
}
Err(_) => {
debug!(
peer = %peer_addr,
timeout_ms = accept_permit_timeout_ms,
"Dropping accepted connection: permit wait timeout"
);
drop(stream);
continue;
}
}
};
let config = config_rx.borrow_and_update().clone();

View File

@@ -267,6 +267,8 @@ pub(crate) async fn initialize_me_pool(
config.general.me_warn_rate_limit_ms,
config.general.me_route_no_writer_mode,
config.general.me_route_no_writer_wait_ms,
config.general.me_route_hybrid_max_wait_ms,
config.general.me_route_blocking_send_timeout_ms,
config.general.me_route_inline_recovery_attempts,
config.general.me_route_inline_recovery_wait_ms,
);

View File

@@ -1692,6 +1692,57 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
}
);
let _ = writeln!(
out,
"# HELP telemt_me_writer_close_signal_drop_total Close-signal drops for already-removed ME writers"
);
let _ = writeln!(out, "# TYPE telemt_me_writer_close_signal_drop_total counter");
let _ = writeln!(
out,
"telemt_me_writer_close_signal_drop_total {}",
if me_allows_normal {
stats.get_me_writer_close_signal_drop_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_writer_close_signal_channel_full_total Close-signal drops caused by full writer command channels"
);
let _ = writeln!(
out,
"# TYPE telemt_me_writer_close_signal_channel_full_total counter"
);
let _ = writeln!(
out,
"telemt_me_writer_close_signal_channel_full_total {}",
if me_allows_normal {
stats.get_me_writer_close_signal_channel_full_total()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_draining_writers_reap_progress_total Draining-writer removals processed by reap cleanup"
);
let _ = writeln!(
out,
"# TYPE telemt_me_draining_writers_reap_progress_total counter"
);
let _ = writeln!(
out,
"telemt_me_draining_writers_reap_progress_total {}",
if me_allows_normal {
stats.get_me_draining_writers_reap_progress_total()
} else {
0
}
);
let _ = writeln!(out, "# HELP telemt_me_writer_removed_total Total ME writer removals");
let _ = writeln!(out, "# TYPE telemt_me_writer_removed_total counter");
let _ = writeln!(
@@ -2124,6 +2175,13 @@ mod tests {
assert!(output.contains("# TYPE telemt_me_rpc_proxy_req_signal_sent_total counter"));
assert!(output.contains("# TYPE telemt_me_idle_close_by_peer_total counter"));
assert!(output.contains("# TYPE telemt_me_writer_removed_total counter"));
assert!(output.contains("# TYPE telemt_me_writer_close_signal_drop_total counter"));
assert!(output.contains(
"# TYPE telemt_me_writer_close_signal_channel_full_total counter"
));
assert!(output.contains(
"# TYPE telemt_me_draining_writers_reap_progress_total counter"
));
assert!(output.contains("# TYPE telemt_pool_drain_soft_evict_total counter"));
assert!(output.contains("# TYPE telemt_pool_drain_soft_evict_writer_total counter"));
assert!(output.contains(

View File

@@ -222,6 +222,7 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
async fn enqueue_c2me_command(
tx: &mpsc::Sender<C2MeCommand>,
cmd: C2MeCommand,
send_timeout: Duration,
) -> std::result::Result<(), mpsc::error::SendError<C2MeCommand>> {
match tx.try_send(cmd) {
Ok(()) => Ok(()),
@@ -231,7 +232,17 @@ async fn enqueue_c2me_command(
if tx.capacity() <= C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS {
tokio::task::yield_now().await;
}
tx.send(cmd).await
if send_timeout.is_zero() {
return tx.send(cmd).await;
}
match tokio::time::timeout(send_timeout, tx.reserve()).await {
Ok(Ok(permit)) => {
permit.send(cmd);
Ok(())
}
Ok(Err(_)) => Err(mpsc::error::SendError(cmd)),
Err(_) => Err(mpsc::error::SendError(cmd)),
}
}
}
}
@@ -355,6 +366,7 @@ where
.general
.me_c2me_channel_capacity
.max(C2ME_CHANNEL_CAPACITY_FALLBACK);
let c2me_send_timeout = Duration::from_millis(config.general.me_c2me_send_timeout_ms);
let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(c2me_channel_capacity);
let me_pool_c2me = me_pool.clone();
let effective_tag = effective_tag;
@@ -363,15 +375,42 @@ where
while let Some(cmd) = c2me_rx.recv().await {
match cmd {
C2MeCommand::Data { payload, flags } => {
me_pool_c2me.send_proxy_req(
conn_id,
success.dc_idx,
peer,
translated_local_addr,
payload.as_ref(),
flags,
effective_tag.as_deref(),
).await?;
if c2me_send_timeout.is_zero() {
me_pool_c2me
.send_proxy_req(
conn_id,
success.dc_idx,
peer,
translated_local_addr,
payload.as_ref(),
flags,
effective_tag.as_deref(),
)
.await?;
} else {
match tokio::time::timeout(
c2me_send_timeout,
me_pool_c2me.send_proxy_req(
conn_id,
success.dc_idx,
peer,
translated_local_addr,
payload.as_ref(),
flags,
effective_tag.as_deref(),
),
)
.await
{
Ok(send_result) => send_result?,
Err(_) => {
return Err(ProxyError::Proxy(format!(
"ME send timeout after {}ms",
c2me_send_timeout.as_millis()
)));
}
}
}
sent_since_yield = sent_since_yield.saturating_add(1);
if should_yield_c2me_sender(sent_since_yield, !c2me_rx.is_empty()) {
sent_since_yield = 0;
@@ -555,7 +594,7 @@ where
loop {
if session_lease.is_stale() {
stats.increment_reconnect_stale_close_total();
let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await;
let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close, c2me_send_timeout).await;
main_result = Err(ProxyError::Proxy("Session evicted by reconnect".to_string()));
break;
}
@@ -573,7 +612,7 @@ where
"Cutover affected middle session, closing client connection"
);
tokio::time::sleep(delay).await;
let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await;
let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close, c2me_send_timeout).await;
main_result = Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string()));
break;
}
@@ -607,9 +646,13 @@ where
flags |= RPC_FLAG_NOT_ENCRYPTED;
}
// Keep client read loop lightweight: route heavy ME send path via a dedicated task.
if enqueue_c2me_command(&c2me_tx, C2MeCommand::Data { payload, flags })
.await
.is_err()
if enqueue_c2me_command(
&c2me_tx,
C2MeCommand::Data { payload, flags },
c2me_send_timeout,
)
.await
.is_err()
{
main_result = Err(ProxyError::Proxy("ME sender channel closed".into()));
break;
@@ -618,7 +661,12 @@ where
Ok(None) => {
debug!(conn_id, "Client EOF");
client_closed = true;
let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await;
let _ = enqueue_c2me_command(
&c2me_tx,
C2MeCommand::Close,
c2me_send_timeout,
)
.await;
break;
}
Err(e) => {
@@ -993,6 +1041,7 @@ mod tests {
payload: Bytes::from_static(&[1, 2, 3]),
flags: 0,
},
TokioDuration::from_millis(50),
)
.await
.unwrap();
@@ -1028,6 +1077,7 @@ mod tests {
payload: Bytes::from_static(&[7, 7]),
flags: 7,
},
TokioDuration::from_millis(100),
)
.await
.unwrap();

View File

@@ -123,6 +123,9 @@ pub struct Stats {
pool_drain_soft_evict_total: AtomicU64,
pool_drain_soft_evict_writer_total: AtomicU64,
pool_stale_pick_total: AtomicU64,
me_writer_close_signal_drop_total: AtomicU64,
me_writer_close_signal_channel_full_total: AtomicU64,
me_draining_writers_reap_progress_total: AtomicU64,
me_writer_removed_total: AtomicU64,
me_writer_removed_unexpected_total: AtomicU64,
me_refill_triggered_total: AtomicU64,
@@ -734,6 +737,24 @@ impl Stats {
self.pool_stale_pick_total.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_me_writer_close_signal_drop_total(&self) {
if self.telemetry_me_allows_normal() {
self.me_writer_close_signal_drop_total
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_me_writer_close_signal_channel_full_total(&self) {
if self.telemetry_me_allows_normal() {
self.me_writer_close_signal_channel_full_total
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_me_draining_writers_reap_progress_total(&self) {
if self.telemetry_me_allows_normal() {
self.me_draining_writers_reap_progress_total
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_me_writer_removed_total(&self) {
if self.telemetry_me_allows_debug() {
self.me_writer_removed_total.fetch_add(1, Ordering::Relaxed);
@@ -1259,6 +1280,17 @@ impl Stats {
pub fn get_pool_stale_pick_total(&self) -> u64 {
self.pool_stale_pick_total.load(Ordering::Relaxed)
}
pub fn get_me_writer_close_signal_drop_total(&self) -> u64 {
self.me_writer_close_signal_drop_total.load(Ordering::Relaxed)
}
pub fn get_me_writer_close_signal_channel_full_total(&self) -> u64 {
self.me_writer_close_signal_channel_full_total
.load(Ordering::Relaxed)
}
pub fn get_me_draining_writers_reap_progress_total(&self) -> u64 {
self.me_draining_writers_reap_progress_total
.load(Ordering::Relaxed)
}
pub fn get_me_writer_removed_total(&self) -> u64 {
self.me_writer_removed_total.load(Ordering::Relaxed)
}

View File

@@ -314,6 +314,8 @@ pub(super) async fn reap_draining_writers(
}
pool.stats.increment_pool_force_close_total();
pool.remove_writer_and_close_clients(writer_id).await;
pool.stats
.increment_me_draining_writers_reap_progress_total();
closed_total = closed_total.saturating_add(1);
}
for writer_id in empty_writer_ids {
@@ -324,6 +326,8 @@ pub(super) async fn reap_draining_writers(
continue;
}
pool.remove_writer_and_close_clients(writer_id).await;
pool.stats
.increment_me_draining_writers_reap_progress_total();
closed_total = closed_total.saturating_add(1);
}
@@ -1574,6 +1578,8 @@ mod tests {
general.me_warn_rate_limit_ms,
MeRouteNoWriterMode::default(),
general.me_route_no_writer_wait_ms,
general.me_route_hybrid_max_wait_ms,
general.me_route_blocking_send_timeout_ms,
general.me_route_inline_recovery_attempts,
general.me_route_inline_recovery_wait_ms,
)

View File

@@ -111,6 +111,8 @@ async fn make_pool(
general.me_warn_rate_limit_ms,
MeRouteNoWriterMode::default(),
general.me_route_no_writer_wait_ms,
general.me_route_hybrid_max_wait_ms,
general.me_route_blocking_send_timeout_ms,
general.me_route_inline_recovery_attempts,
general.me_route_inline_recovery_wait_ms,
);

View File

@@ -110,6 +110,8 @@ async fn make_pool(
general.me_warn_rate_limit_ms,
MeRouteNoWriterMode::default(),
general.me_route_no_writer_wait_ms,
general.me_route_hybrid_max_wait_ms,
general.me_route_blocking_send_timeout_ms,
general.me_route_inline_recovery_attempts,
general.me_route_inline_recovery_wait_ms,
);

View File

@@ -4,6 +4,7 @@ use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use bytes::Bytes;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
@@ -103,6 +104,8 @@ async fn make_pool(me_pool_drain_threshold: u64) -> Arc<MePool> {
general.me_warn_rate_limit_ms,
MeRouteNoWriterMode::default(),
general.me_route_no_writer_wait_ms,
general.me_route_hybrid_max_wait_ms,
general.me_route_blocking_send_timeout_ms,
general.me_route_inline_recovery_attempts,
general.me_route_inline_recovery_wait_ms,
)
@@ -207,6 +210,89 @@ async fn reap_draining_writers_removes_empty_draining_writers() {
assert_eq!(current_writer_ids(&pool).await, vec![3]);
}
#[tokio::test]
async fn reap_draining_writers_does_not_block_on_stuck_writer_close_signal() {
let pool = make_pool(128).await;
let now_epoch_secs = MePool::now_epoch_secs();
let (blocked_tx, blocked_rx) = mpsc::channel::<WriterCommand>(1);
assert!(
blocked_tx
.try_send(WriterCommand::Data(Bytes::from_static(b"stuck")))
.is_ok()
);
let blocked_rx_guard = tokio::spawn(async move {
let _hold_rx = blocked_rx;
tokio::time::sleep(Duration::from_secs(30)).await;
});
let blocked_writer_id = 90u64;
let blocked_writer = MeWriter {
id: blocked_writer_id,
addr: SocketAddr::new(
IpAddr::V4(Ipv4Addr::LOCALHOST),
4500 + blocked_writer_id as u16,
),
source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
writer_dc: 2,
generation: 1,
contour: Arc::new(AtomicU8::new(WriterContour::Draining.as_u8())),
created_at: Instant::now() - Duration::from_secs(blocked_writer_id),
tx: blocked_tx.clone(),
cancel: CancellationToken::new(),
degraded: Arc::new(AtomicBool::new(false)),
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
draining: Arc::new(AtomicBool::new(true)),
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(
now_epoch_secs.saturating_sub(120),
)),
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)),
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
};
pool.writers.write().await.push(blocked_writer);
pool.registry
.register_writer(blocked_writer_id, blocked_tx)
.await;
pool.conn_count.fetch_add(1, Ordering::Relaxed);
insert_draining_writer(&pool, 91, now_epoch_secs.saturating_sub(110), 0, 0).await;
let mut warn_next_allowed = HashMap::new();
let mut soft_evict_next_allowed = HashMap::new();
let reap_res = tokio::time::timeout(
Duration::from_millis(500),
reap_draining_writers(&pool, &mut warn_next_allowed, &mut soft_evict_next_allowed),
)
.await;
blocked_rx_guard.abort();
assert!(reap_res.is_ok(), "reap should not block on close signal");
assert!(current_writer_ids(&pool).await.is_empty());
assert_eq!(pool.stats.get_me_writer_close_signal_drop_total(), 2);
assert_eq!(pool.stats.get_me_writer_close_signal_channel_full_total(), 1);
assert_eq!(pool.stats.get_me_draining_writers_reap_progress_total(), 2);
let activity = pool.registry.writer_activity_snapshot().await;
assert!(!activity.bound_clients_by_writer.contains_key(&blocked_writer_id));
assert!(!activity.bound_clients_by_writer.contains_key(&91));
let (probe_conn_id, _rx) = pool.registry.register().await;
assert!(
!pool.registry
.bind_writer(
probe_conn_id,
blocked_writer_id,
ConnMeta {
target_dc: 2,
client_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 6400),
our_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443),
proto_flags: 0,
},
)
.await
);
let _ = pool.registry.unregister(probe_conn_id).await;
}
#[tokio::test]
async fn reap_draining_writers_overflow_closes_oldest_non_empty_writers() {
let pool = make_pool(2).await;

View File

@@ -193,6 +193,8 @@ pub struct MePool {
pub(super) me_reader_route_data_wait_ms: Arc<AtomicU64>,
pub(super) me_route_no_writer_mode: AtomicU8,
pub(super) me_route_no_writer_wait: Duration,
pub(super) me_route_hybrid_max_wait: Duration,
pub(super) me_route_blocking_send_timeout: Duration,
pub(super) me_route_inline_recovery_attempts: u32,
pub(super) me_route_inline_recovery_wait: Duration,
pub(super) me_health_interval_ms_unhealthy: AtomicU64,
@@ -307,6 +309,8 @@ impl MePool {
me_warn_rate_limit_ms: u64,
me_route_no_writer_mode: MeRouteNoWriterMode,
me_route_no_writer_wait_ms: u64,
me_route_hybrid_max_wait_ms: u64,
me_route_blocking_send_timeout_ms: u64,
me_route_inline_recovery_attempts: u32,
me_route_inline_recovery_wait_ms: u64,
) -> Arc<Self> {
@@ -490,6 +494,10 @@ impl MePool {
me_reader_route_data_wait_ms: Arc::new(AtomicU64::new(me_reader_route_data_wait_ms)),
me_route_no_writer_mode: AtomicU8::new(me_route_no_writer_mode.as_u8()),
me_route_no_writer_wait: Duration::from_millis(me_route_no_writer_wait_ms),
me_route_hybrid_max_wait: Duration::from_millis(me_route_hybrid_max_wait_ms),
me_route_blocking_send_timeout: Duration::from_millis(
me_route_blocking_send_timeout_ms,
),
me_route_inline_recovery_attempts,
me_route_inline_recovery_wait: Duration::from_millis(me_route_inline_recovery_wait_ms),
me_health_interval_ms_unhealthy: AtomicU64::new(me_health_interval_ms_unhealthy.max(1)),

View File

@@ -8,6 +8,7 @@ use bytes::Bytes;
use bytes::BytesMut;
use rand::Rng;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
@@ -312,41 +313,28 @@ impl MePool {
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_PING_U32.to_le_bytes());
p.extend_from_slice(&sent_id.to_le_bytes());
{
let mut tracker = ping_tracker_ping.lock().await;
let now_epoch_ms = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
let mut run_cleanup = false;
if let Some(pool) = pool_ping.upgrade() {
let last_cleanup_ms = pool
let now_epoch_ms = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
let mut run_cleanup = false;
if let Some(pool) = pool_ping.upgrade() {
let last_cleanup_ms = pool
.ping_tracker_last_cleanup_epoch_ms
.load(Ordering::Relaxed);
if now_epoch_ms.saturating_sub(last_cleanup_ms) >= 30_000
&& pool
.ping_tracker_last_cleanup_epoch_ms
.load(Ordering::Relaxed);
if now_epoch_ms.saturating_sub(last_cleanup_ms) >= 30_000
&& pool
.ping_tracker_last_cleanup_epoch_ms
.compare_exchange(
last_cleanup_ms,
now_epoch_ms,
Ordering::AcqRel,
Ordering::Relaxed,
)
.is_ok()
{
run_cleanup = true;
}
.compare_exchange(
last_cleanup_ms,
now_epoch_ms,
Ordering::AcqRel,
Ordering::Relaxed,
)
.is_ok()
{
run_cleanup = true;
}
if run_cleanup {
let before = tracker.len();
tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120));
let expired = before.saturating_sub(tracker.len());
if expired > 0 {
stats_ping.increment_me_keepalive_timeout_by(expired as u64);
}
}
tracker.insert(sent_id, (std::time::Instant::now(), writer_id));
}
ping_id = ping_id.wrapping_add(1);
stats_ping.increment_me_keepalive_sent();
@@ -367,6 +355,16 @@ impl MePool {
}
break;
}
let mut tracker = ping_tracker_ping.lock().await;
if run_cleanup {
let before = tracker.len();
tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120));
let expired = before.saturating_sub(tracker.len());
if expired > 0 {
stats_ping.increment_me_keepalive_timeout_by(expired as u64);
}
}
tracker.insert(sent_id, (std::time::Instant::now(), writer_id));
}
});
@@ -494,11 +492,9 @@ impl MePool {
}
pub(crate) async fn remove_writer_and_close_clients(self: &Arc<Self>, writer_id: u64) {
let conns = self.remove_writer_only(writer_id).await;
for bound in conns {
let _ = self.registry.route(bound.conn_id, super::MeResponse::Close).await;
let _ = self.registry.unregister(bound.conn_id).await;
}
// Full client cleanup now happens inside `registry.writer_lost` to keep
// writer reap/remove paths strictly non-blocking per connection.
let _ = self.remove_writer_only(writer_id).await;
}
async fn remove_writer_only(self: &Arc<Self>, writer_id: u64) -> Vec<BoundConn> {
@@ -528,6 +524,11 @@ impl MePool {
self.conn_count.fetch_sub(1, Ordering::Relaxed);
}
}
// State invariant:
// - writer is removed from `self.writers` (pool visibility),
// - writer is removed from registry routing/binding maps via `writer_lost`.
// The close command below is only a best-effort accelerator for task shutdown.
// Cleanup progress must never depend on command-channel availability.
let conns = self.registry.writer_lost(writer_id).await;
{
let mut tracker = self.ping_tracker.lock().await;
@@ -535,7 +536,25 @@ impl MePool {
}
self.rtt_stats.lock().await.remove(&writer_id);
if let Some(tx) = close_tx {
let _ = tx.send(WriterCommand::Close).await;
match tx.try_send(WriterCommand::Close) {
Ok(()) => {}
Err(TrySendError::Full(_)) => {
self.stats.increment_me_writer_close_signal_drop_total();
self.stats
.increment_me_writer_close_signal_channel_full_total();
debug!(
writer_id,
"Skipping close signal for removed writer: command channel is full"
);
}
Err(TrySendError::Closed(_)) => {
self.stats.increment_me_writer_close_signal_drop_total();
debug!(
writer_id,
"Skipping close signal for removed writer: command channel is closed"
);
}
}
}
if trigger_refill
&& let Some(addr) = removed_addr

View File

@@ -8,6 +8,7 @@ use bytes::{Bytes, BytesMut};
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
use tokio::sync::{Mutex, mpsc};
use tokio::sync::mpsc::error::TrySendError;
use tokio_util::sync::CancellationToken;
use tracing::{debug, trace, warn};
@@ -173,12 +174,12 @@ pub(crate) async fn reader_loop(
} else if pt == RPC_CLOSE_EXT_U32 && body.len() >= 8 {
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
debug!(cid, "RPC_CLOSE_EXT from ME");
reg.route(cid, MeResponse::Close).await;
let _ = reg.route_nowait(cid, MeResponse::Close).await;
reg.unregister(cid).await;
} else if pt == RPC_CLOSE_CONN_U32 && body.len() >= 8 {
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
debug!(cid, "RPC_CLOSE_CONN from ME");
reg.route(cid, MeResponse::Close).await;
let _ = reg.route_nowait(cid, MeResponse::Close).await;
reg.unregister(cid).await;
} else if pt == RPC_PING_U32 && body.len() >= 8 {
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
@@ -186,13 +187,15 @@ pub(crate) async fn reader_loop(
let mut pong = Vec::with_capacity(12);
pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes());
pong.extend_from_slice(&ping_id.to_le_bytes());
if tx
.send(WriterCommand::DataAndFlush(Bytes::from(pong)))
.await
.is_err()
{
warn!("PONG send failed");
break;
match tx.try_send(WriterCommand::DataAndFlush(Bytes::from(pong))) {
Ok(()) => {}
Err(TrySendError::Full(_)) => {
debug!(ping_id, "PONG dropped: writer command channel is full");
}
Err(TrySendError::Closed(_)) => {
warn!("PONG send failed: writer channel closed");
break;
}
}
} else if pt == RPC_PONG_U32 && body.len() >= 8 {
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
@@ -232,6 +235,13 @@ async fn send_close_conn(tx: &mpsc::Sender<WriterCommand>, conn_id: u64) {
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes());
p.extend_from_slice(&conn_id.to_le_bytes());
let _ = tx.send(WriterCommand::DataAndFlush(Bytes::from(p))).await;
match tx.try_send(WriterCommand::DataAndFlush(Bytes::from(p))) {
Ok(()) => {}
Err(TrySendError::Full(_)) => {
debug!(conn_id, "ME close_conn signal skipped: writer command channel is full");
}
Err(TrySendError::Closed(_)) => {
debug!(conn_id, "ME close_conn signal skipped: writer command channel is closed");
}
}
}

View File

@@ -169,6 +169,7 @@ impl ConnRegistry {
None
}
#[allow(dead_code)]
pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult {
let tx = {
let inner = self.inner.read().await;
@@ -445,30 +446,38 @@ impl ConnRegistry {
}
pub async fn writer_lost(&self, writer_id: u64) -> Vec<BoundConn> {
let mut inner = self.inner.write().await;
inner.writers.remove(&writer_id);
inner.last_meta_for_writer.remove(&writer_id);
inner.writer_idle_since_epoch_secs.remove(&writer_id);
let conns = inner
.conns_for_writer
.remove(&writer_id)
.unwrap_or_default()
.into_iter()
.collect::<Vec<_>>();
let mut close_txs = Vec::<mpsc::Sender<MeResponse>>::new();
let mut out = Vec::new();
for conn_id in conns {
if inner.writer_for_conn.get(&conn_id).copied() != Some(writer_id) {
continue;
}
inner.writer_for_conn.remove(&conn_id);
if let Some(m) = inner.meta.get(&conn_id) {
out.push(BoundConn {
conn_id,
meta: m.clone(),
});
{
let mut inner = self.inner.write().await;
inner.writers.remove(&writer_id);
inner.last_meta_for_writer.remove(&writer_id);
inner.writer_idle_since_epoch_secs.remove(&writer_id);
let conns = inner
.conns_for_writer
.remove(&writer_id)
.unwrap_or_default()
.into_iter()
.collect::<Vec<_>>();
for conn_id in conns {
if inner.writer_for_conn.get(&conn_id).copied() != Some(writer_id) {
continue;
}
inner.writer_for_conn.remove(&conn_id);
if let Some(client_tx) = inner.map.remove(&conn_id) {
close_txs.push(client_tx);
}
if let Some(meta) = inner.meta.remove(&conn_id) {
out.push(BoundConn { conn_id, meta });
}
}
}
for client_tx in close_txs {
let _ = client_tx.try_send(MeResponse::Close);
}
out
}
@@ -491,6 +500,7 @@ impl ConnRegistry {
#[cfg(test)]
mod tests {
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::Duration;
use super::ConnMeta;
use super::ConnRegistry;
@@ -663,6 +673,39 @@ mod tests {
assert!(registry.is_writer_empty(20).await);
}
#[tokio::test]
async fn writer_lost_removes_bound_conn_from_registry_and_signals_close() {
let registry = ConnRegistry::new();
let (conn_id, mut rx) = registry.register().await;
let (writer_tx, _writer_rx) = tokio::sync::mpsc::channel(8);
registry.register_writer(10, writer_tx).await;
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443);
assert!(
registry
.bind_writer(
conn_id,
10,
ConnMeta {
target_dc: 2,
client_addr: addr,
our_addr: addr,
proto_flags: 0,
},
)
.await
);
let lost = registry.writer_lost(10).await;
assert_eq!(lost.len(), 1);
assert_eq!(lost[0].conn_id, conn_id);
assert!(registry.get_writer(conn_id).await.is_none());
assert!(registry.get_meta(conn_id).await.is_none());
assert_eq!(registry.unregister(conn_id).await, None);
let close = tokio::time::timeout(Duration::from_millis(50), rx.recv()).await;
assert!(matches!(close, Ok(Some(MeResponse::Close))));
}
#[tokio::test]
async fn bind_writer_rejects_unregistered_writer() {
let registry = ConnRegistry::new();

View File

@@ -6,6 +6,7 @@ use std::sync::atomic::Ordering;
use std::time::{Duration, Instant};
use bytes::Bytes;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;
use tracing::{debug, warn};
@@ -29,6 +30,29 @@ const PICK_PENALTY_DRAINING: u64 = 600;
const PICK_PENALTY_STALE: u64 = 300;
const PICK_PENALTY_DEGRADED: u64 = 250;
enum TimedSendError<T> {
Closed(T),
Timeout(T),
}
async fn send_writer_command_with_timeout(
tx: &mpsc::Sender<WriterCommand>,
cmd: WriterCommand,
timeout: Duration,
) -> std::result::Result<(), TimedSendError<WriterCommand>> {
if timeout.is_zero() {
return tx.send(cmd).await.map_err(|err| TimedSendError::Closed(err.0));
}
match tokio::time::timeout(timeout, tx.reserve()).await {
Ok(Ok(permit)) => {
permit.send(cmd);
Ok(())
}
Ok(Err(_)) => Err(TimedSendError::Closed(cmd)),
Err(_) => Err(TimedSendError::Timeout(cmd)),
}
}
impl MePool {
/// Send RPC_PROXY_REQ. `tag_override`: per-user ad_tag (from access.user_ad_tags); if None, uses pool default.
pub async fn send_proxy_req(
@@ -78,8 +102,18 @@ impl MePool {
let mut hybrid_last_recovery_at: Option<Instant> = None;
let hybrid_wait_step = self.me_route_no_writer_wait.max(Duration::from_millis(50));
let mut hybrid_wait_current = hybrid_wait_step;
let hybrid_deadline = Instant::now() + self.me_route_hybrid_max_wait;
loop {
if matches!(no_writer_mode, MeRouteNoWriterMode::HybridAsyncPersistent)
&& Instant::now() >= hybrid_deadline
{
self.stats.increment_me_no_writer_failfast_total();
return Err(ProxyError::Proxy(
"No ME writer available in hybrid wait window".into(),
));
}
let mut skip_writer_id: Option<u64> = None;
let current_meta = self
.registry
.get_meta(conn_id)
@@ -90,12 +124,30 @@ impl MePool {
match current.tx.try_send(WriterCommand::Data(current_payload.clone())) {
Ok(()) => return Ok(()),
Err(TrySendError::Full(cmd)) => {
if current.tx.send(cmd).await.is_ok() {
return Ok(());
match send_writer_command_with_timeout(
&current.tx,
cmd,
self.me_route_blocking_send_timeout,
)
.await
{
Ok(()) => return Ok(()),
Err(TimedSendError::Closed(_)) => {
warn!(writer_id = current.writer_id, "ME writer channel closed");
self.remove_writer_and_close_clients(current.writer_id).await;
continue;
}
Err(TimedSendError::Timeout(_)) => {
debug!(
conn_id,
writer_id = current.writer_id,
timeout_ms = self.me_route_blocking_send_timeout.as_millis()
as u64,
"ME writer send timed out for bound writer, trying reroute"
);
skip_writer_id = Some(current.writer_id);
}
}
warn!(writer_id = current.writer_id, "ME writer channel closed");
self.remove_writer_and_close_clients(current.writer_id).await;
continue;
}
Err(TrySendError::Closed(_)) => {
warn!(writer_id = current.writer_id, "ME writer channel closed");
@@ -200,6 +252,9 @@ impl MePool {
.candidate_indices_for_dc(&writers_snapshot, routed_dc, true)
.await;
}
if let Some(skip_writer_id) = skip_writer_id {
candidate_indices.retain(|idx| writers_snapshot[*idx].id != skip_writer_id);
}
if candidate_indices.is_empty() {
let pick_mode = self.writer_pick_mode();
match no_writer_mode {
@@ -422,7 +477,13 @@ impl MePool {
self.stats.increment_me_writer_pick_blocking_fallback_total();
let effective_our_addr = SocketAddr::new(w.source_ip, our_addr.port());
let (payload, meta) = build_routed_payload(effective_our_addr);
match w.tx.send(WriterCommand::Data(payload.clone())).await {
match send_writer_command_with_timeout(
&w.tx,
WriterCommand::Data(payload.clone()),
self.me_route_blocking_send_timeout,
)
.await
{
Ok(()) => {
self.stats
.increment_me_writer_pick_success_fallback_total(pick_mode);
@@ -439,11 +500,20 @@ impl MePool {
}
return Ok(());
}
Err(_) => {
Err(TimedSendError::Closed(_)) => {
self.stats.increment_me_writer_pick_closed_total(pick_mode);
warn!(writer_id = w.id, "ME writer channel closed (blocking)");
self.remove_writer_and_close_clients(w.id).await;
}
Err(TimedSendError::Timeout(_)) => {
self.stats.increment_me_writer_pick_full_total(pick_mode);
debug!(
conn_id,
writer_id = w.id,
timeout_ms = self.me_route_blocking_send_timeout.as_millis() as u64,
"ME writer blocking fallback send timed out"
);
}
}
}
}
@@ -573,13 +643,19 @@ impl MePool {
let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes());
p.extend_from_slice(&conn_id.to_le_bytes());
if w.tx
.send(WriterCommand::DataAndFlush(Bytes::from(p)))
.await
.is_err()
{
debug!("ME close write failed");
self.remove_writer_and_close_clients(w.writer_id).await;
match w.tx.try_send(WriterCommand::DataAndFlush(Bytes::from(p))) {
Ok(()) => {}
Err(TrySendError::Full(_)) => {
debug!(
conn_id,
writer_id = w.writer_id,
"ME close skipped: writer command channel is full"
);
}
Err(TrySendError::Closed(_)) => {
debug!("ME close write failed");
self.remove_writer_and_close_clients(w.writer_id).await;
}
}
} else {
debug!(conn_id, "ME close skipped (writer missing)");
@@ -596,8 +672,12 @@ impl MePool {
p.extend_from_slice(&conn_id.to_le_bytes());
match w.tx.try_send(WriterCommand::DataAndFlush(Bytes::from(p))) {
Ok(()) => {}
Err(TrySendError::Full(cmd)) => {
let _ = tokio::time::timeout(Duration::from_millis(50), w.tx.send(cmd)).await;
Err(TrySendError::Full(_)) => {
debug!(
conn_id,
writer_id = w.writer_id,
"ME close_conn skipped: writer command channel is full"
);
}
Err(TrySendError::Closed(_)) => {
debug!(conn_id, "ME close_conn skipped: writer channel closed");