diff --git a/src/config/types.rs b/src/config/types.rs
index 4c19f8a..88bf8d3 100644
--- a/src/config/types.rs
+++ b/src/config/types.rs
@@ -187,9 +187,10 @@ impl MeFloorMode {
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
 #[serde(rename_all = "snake_case")]
 pub enum MeRouteNoWriterMode {
-    #[default]
     AsyncRecoveryFailfast,
     InlineRecoveryLegacy,
+    #[default]
+    HybridAsyncPersistent,
 }
 
 impl MeRouteNoWriterMode {
@@ -197,13 +198,16 @@ impl MeRouteNoWriterMode {
         match self {
             MeRouteNoWriterMode::AsyncRecoveryFailfast => 0,
             MeRouteNoWriterMode::InlineRecoveryLegacy => 1,
+            MeRouteNoWriterMode::HybridAsyncPersistent => 2,
         }
     }
 
     pub fn from_u8(raw: u8) -> Self {
         match raw {
+            0 => MeRouteNoWriterMode::AsyncRecoveryFailfast,
             1 => MeRouteNoWriterMode::InlineRecoveryLegacy,
-            _ => MeRouteNoWriterMode::AsyncRecoveryFailfast,
+            2 => MeRouteNoWriterMode::HybridAsyncPersistent,
+            _ => MeRouteNoWriterMode::HybridAsyncPersistent,
         }
     }
 }
diff --git a/src/main.rs b/src/main.rs
index 798790a..9f81edf 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -8,7 +8,7 @@ use std::time::{Duration, Instant};
 use rand::Rng;
 use tokio::net::TcpListener;
 use tokio::signal;
-use tokio::sync::{Semaphore, mpsc};
+use tokio::sync::{Semaphore, mpsc, watch};
 use tracing::{debug, error, info, warn};
 use tracing_subscriber::{EnvFilter, fmt, prelude::*, reload};
 #[cfg(unix)]
@@ -241,6 +241,17 @@ fn format_uptime(total_secs: u64) -> String {
     format!("{} / {} seconds", parts.join(", "), total_secs)
 }
 
+async fn wait_until_admission_open(admission_rx: &mut watch::Receiver<bool>) -> bool {
+    loop {
+        if *admission_rx.borrow() {
+            return true;
+        }
+        if admission_rx.changed().await.is_err() {
+            return *admission_rx.borrow();
+        }
+    }
+}
+
 async fn load_startup_proxy_config_snapshot(
     url: &str,
     cache_path: Option<&str>,
@@ -1325,6 +1336,60 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
         print_proxy_links(&host, port, &config);
     }
 
+    let (admission_tx, admission_rx) = watch::channel(true);
+    if config.general.use_middle_proxy {
+        if let Some(pool) = me_pool.as_ref() {
+            let initial_open = pool.admission_ready_full_floor().await;
+            admission_tx.send_replace(initial_open);
+            if initial_open {
+                info!("Conditional-admission gate: open (ME pool ready)");
+            } else {
+                warn!("Conditional-admission gate: closed (ME pool is not ready)");
+            }
+
+            let pool_for_gate = pool.clone();
+            let admission_tx_gate = admission_tx.clone();
+            tokio::spawn(async move {
+                let mut gate_open = initial_open;
+                let mut open_streak = if initial_open { 1u32 } else { 0u32 };
+                let mut close_streak = if initial_open { 0u32 } else { 1u32 };
+                loop {
+                    let ready = pool_for_gate.admission_ready_full_floor().await;
+                    if ready {
+                        open_streak = open_streak.saturating_add(1);
+                        close_streak = 0;
+                        if !gate_open && open_streak >= 2 {
+                            gate_open = true;
+                            admission_tx_gate.send_replace(true);
+                            info!(
+                                open_streak,
+                                "Conditional-admission gate opened (ME pool recovered)"
+                            );
+                        }
+                    } else {
+                        close_streak = close_streak.saturating_add(1);
+                        open_streak = 0;
+                        if gate_open && close_streak >= 2 {
+                            gate_open = false;
+                            admission_tx_gate.send_replace(false);
+                            warn!(
+                                close_streak,
+                                "Conditional-admission gate closed (ME pool below required floor)"
+                            );
+                        }
+                    }
+                    tokio::time::sleep(Duration::from_millis(250)).await;
+                }
+            });
+        } else {
+            admission_tx.send_replace(false);
+            warn!("Conditional-admission gate: closed (ME pool is unavailable)");
+        }
+    } else {
+        admission_tx.send_replace(true);
+    }
+    let _admission_tx_hold = admission_tx;
+
     // Unix socket setup (before listeners check so unix-only config works)
     let mut has_unix_listener = false;
     #[cfg(unix)]
@@ -1358,6 +1423,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
         has_unix_listener = true;
 
         let mut config_rx_unix: tokio::sync::watch::Receiver<Arc<Config>> = config_rx.clone();
+        let mut admission_rx_unix = admission_rx.clone();
         let stats = stats.clone();
         let upstream_manager = upstream_manager.clone();
         let replay_checker = replay_checker.clone();
@@ -1373,6 +1439,10 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
 
             let unix_conn_counter = std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1));
             loop {
+                if !wait_until_admission_open(&mut admission_rx_unix).await {
+                    warn!("Conditional-admission gate channel closed for unix listener");
+                    break;
+                }
                 match unix_listener.accept().await {
                     Ok((stream, _)) => {
                         let permit = match max_connections_unix.clone().acquire_owned().await {
@@ -1507,6 +1577,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
 
     for (listener, listener_proxy_protocol) in listeners {
         let mut config_rx: tokio::sync::watch::Receiver<Arc<Config>> = config_rx.clone();
+        let mut admission_rx_tcp = admission_rx.clone();
         let stats = stats.clone();
         let upstream_manager = upstream_manager.clone();
         let replay_checker = replay_checker.clone();
@@ -1520,6 +1591,10 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
 
         tokio::spawn(async move {
             loop {
+                if !wait_until_admission_open(&mut admission_rx_tcp).await {
+                    warn!("Conditional-admission gate channel closed for tcp listener");
+                    break;
+                }
                 match listener.accept().await {
                     Ok((stream, peer_addr)) => {
                         let permit = match max_connections_tcp.clone().acquire_owned().await {
diff --git a/src/transport/middle_proxy/pool_status.rs b/src/transport/middle_proxy/pool_status.rs
index c01f74b..9dd3d07 100644
--- a/src/transport/middle_proxy/pool_status.rs
+++ b/src/transport/middle_proxy/pool_status.rs
@@ -100,6 +100,72 @@ pub(crate) struct MeApiRuntimeSnapshot {
 }
 
 impl MePool {
+    pub(crate) async fn admission_ready_full_floor(&self) -> bool {
+        let mut endpoints_by_dc = BTreeMap::<i16, BTreeSet<SocketAddr>>::new();
+        if self.decision.ipv4_me {
+            let map = self.proxy_map_v4.read().await.clone();
+            for (dc, addrs) in map {
+                let abs_dc = dc.abs();
+                if abs_dc == 0 {
+                    continue;
+                }
+                let Ok(dc_idx) = i16::try_from(abs_dc) else {
+                    continue;
+                };
+                let entry = endpoints_by_dc.entry(dc_idx).or_default();
+                for (ip, port) in addrs {
+                    entry.insert(SocketAddr::new(ip, port));
+                }
+            }
+        }
+        if self.decision.ipv6_me {
+            let map = self.proxy_map_v6.read().await.clone();
+            for (dc, addrs) in map {
+                let abs_dc = dc.abs();
+                if abs_dc == 0 {
+                    continue;
+                }
+                let Ok(dc_idx) = i16::try_from(abs_dc) else {
+                    continue;
+                };
+                let entry = endpoints_by_dc.entry(dc_idx).or_default();
+                for (ip, port) in addrs {
+                    entry.insert(SocketAddr::new(ip, port));
+                }
+            }
+        }
+
+        if endpoints_by_dc.is_empty() {
+            return false;
+        }
+
+        let writers = self.writers.read().await.clone();
+        let mut live_writers_by_endpoint = HashMap::<SocketAddr, usize>::new();
+        for writer in writers {
+            if writer.draining.load(Ordering::Relaxed) {
+                continue;
+            }
+            *live_writers_by_endpoint.entry(writer.addr).or_insert(0) += 1;
+        }
+
+        for endpoints in endpoints_by_dc.values() {
+            let endpoint_count = endpoints.len();
+            if endpoint_count == 0 {
+                return false;
+            }
+            let required = self.required_writers_for_dc_with_floor_mode(endpoint_count, false);
+            let alive: usize = endpoints
+                .iter()
+                .map(|endpoint| live_writers_by_endpoint.get(endpoint).copied().unwrap_or(0))
+                .sum();
+            if alive < required {
+                return false;
+            }
+        }
+
+        true
+    }
+
     pub(crate) async fn api_status_snapshot(&self) -> MeApiStatusSnapshot {
         let now_epoch_secs = Self::now_epoch_secs();
 
diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs
index c6db028..b442a8a 100644
--- a/src/transport/middle_proxy/send.rs
+++ b/src/transport/middle_proxy/send.rs
@@ -22,6 +22,7 @@ use super::registry::ConnMeta;
 
 const IDLE_WRITER_PENALTY_MID_SECS: u64 = 45;
 const IDLE_WRITER_PENALTY_HIGH_SECS: u64 = 55;
+const HYBRID_GLOBAL_BURST_PERIOD_ROUNDS: u32 = 4;
 
 impl MePool {
     /// Send RPC_PROXY_REQ. `tag_override`: per-user ad_tag (from access.user_ad_tags); if None, uses pool default.
@@ -55,6 +56,9 @@ impl MePool {
         let mut no_writer_deadline: Option<Instant> = None;
         let mut emergency_attempts = 0u32;
         let mut async_recovery_triggered = false;
+        let mut hybrid_recovery_round = 0u32;
+        let mut hybrid_last_recovery_at: Option<Instant> = None;
+        let hybrid_wait_step = self.me_route_no_writer_wait.max(Duration::from_millis(50));
 
         loop {
             if let Some(current) = self.registry.get_writer(conn_id).await {
@@ -138,6 +142,18 @@ impl MePool {
                     }
                     continue;
                 }
+                MeRouteNoWriterMode::HybridAsyncPersistent => {
+                    self.maybe_trigger_hybrid_recovery(
+                        target_dc,
+                        &mut hybrid_recovery_round,
+                        &mut hybrid_last_recovery_at,
+                        hybrid_wait_step,
+                    )
+                    .await;
+                    let deadline = Instant::now() + hybrid_wait_step;
+                    let _ = self.wait_for_writer_until(deadline).await;
+                    continue;
+                }
             }
         }
         ws.clone()
@@ -215,6 +231,18 @@ impl MePool {
                         return Err(ProxyError::Proxy("No ME writers available for target DC".into()));
                     }
                 }
+                MeRouteNoWriterMode::HybridAsyncPersistent => {
+                    self.maybe_trigger_hybrid_recovery(
+                        target_dc,
+                        &mut hybrid_recovery_round,
+                        &mut hybrid_last_recovery_at,
+                        hybrid_wait_step,
+                    )
+                    .await;
+                    let deadline = Instant::now() + hybrid_wait_step;
+                    let _ = self.wait_for_candidate_until(target_dc, deadline).await;
+                    continue;
+                }
             }
         }
         let writer_idle_since = self.registry.writer_idle_since_snapshot().await;
@@ -459,6 +487,28 @@ impl MePool {
         preferred
     }
 
+    async fn maybe_trigger_hybrid_recovery(
+        self: &Arc<Self>,
+        target_dc: i16,
+        hybrid_recovery_round: &mut u32,
+        hybrid_last_recovery_at: &mut Option<Instant>,
+        hybrid_wait_step: Duration,
+    ) {
+        if let Some(last) = *hybrid_last_recovery_at
+            && last.elapsed() < hybrid_wait_step
+        {
+            return;
+        }
+
+        let round = *hybrid_recovery_round;
+        let target_triggered = self.trigger_async_recovery_for_target_dc(target_dc).await;
+        if !target_triggered || round % HYBRID_GLOBAL_BURST_PERIOD_ROUNDS == 0 {
+            self.trigger_async_recovery_global().await;
+        }
+        *hybrid_recovery_round = round.saturating_add(1);
+        *hybrid_last_recovery_at = Some(Instant::now());
+    }
+
     pub async fn send_close(self: &Arc<Self>, conn_id: u64) -> Result<()> {
         if let Some(w) = self.registry.get_writer(conn_id).await {
             let mut p = Vec::with_capacity(12);