diff --git a/src/main.rs b/src/main.rs index 3b6a543..a706373 100644 --- a/src/main.rs +++ b/src/main.rs @@ -37,6 +37,7 @@ use crate::crypto::SecureRandom; use crate::ip_tracker::UserIpTracker; use crate::network::probe::{decide_network_capabilities, log_probe_result, run_probe}; use crate::proxy::ClientHandler; +use crate::proxy::route_mode::{ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteRuntimeController}; use crate::stats::beobachten::BeobachtenStore; use crate::stats::telemetry::TelemetryPolicy; use crate::stats::{ReplayChecker, Stats}; @@ -261,6 +262,10 @@ async fn wait_until_admission_open(admission_rx: &mut watch::Receiver) -> } } +fn is_expected_handshake_eof(err: &crate::error::ProxyError) -> bool { + err.to_string().contains("expected 64 bytes, got 0") +} + async fn load_startup_proxy_config_snapshot( url: &str, cache_path: Option<&str>, @@ -519,6 +524,12 @@ async fn main() -> std::result::Result<(), Box> { let (api_config_tx, api_config_rx) = watch::channel(Arc::new(config.clone())); let initial_admission_open = !config.general.use_middle_proxy; let (admission_tx, admission_rx) = watch::channel(initial_admission_open); + let initial_route_mode = if config.general.use_middle_proxy { + RelayRouteMode::Middle + } else { + RelayRouteMode::Direct + }; + let route_runtime = Arc::new(RouteRuntimeController::new(initial_route_mode)); let api_me_pool = Arc::new(RwLock::new(None::>)); startup_tracker .start_component(COMPONENT_API_BOOTSTRAP, Some("spawn API listener task".to_string())) @@ -1783,9 +1794,11 @@ async fn main() -> std::result::Result<(), Box> { if config.general.use_middle_proxy { if let Some(pool) = me_pool.as_ref() { - let initial_open = pool.admission_ready_conditional_cast().await; - admission_tx.send_replace(initial_open); - if initial_open { + let fallback_after = Duration::from_secs(6); + let initial_ready = pool.admission_ready_conditional_cast().await; + admission_tx.send_replace(initial_ready); + let _ = route_runtime.set_mode(RelayRouteMode::Middle); + if initial_ready { info!("Conditional-admission gate: open (ME pool ready)"); } else { warn!("Conditional-admission gate: closed (ME pool is not ready)"); @@ -1793,12 +1806,18 @@ async fn main() -> std::result::Result<(), Box> { let pool_for_gate = pool.clone(); let admission_tx_gate = admission_tx.clone(); + let route_runtime_gate = route_runtime.clone(); let mut config_rx_gate = config_rx.clone(); let mut admission_poll_ms = config.general.me_admission_poll_ms.max(1); + let mut fallback_enabled = config.general.me2dc_fallback; tokio::spawn(async move { - let mut gate_open = initial_open; - let mut open_streak = if initial_open { 1u32 } else { 0u32 }; - let mut close_streak = if initial_open { 0u32 } else { 1u32 }; + let mut gate_open = initial_ready; + let mut route_mode = RelayRouteMode::Middle; + let mut not_ready_since = if initial_ready { + None + } else { + Some(Instant::now()) + }; loop { tokio::select! { changed = config_rx_gate.changed() => { @@ -1807,42 +1826,70 @@ async fn main() -> std::result::Result<(), Box> { } let cfg = config_rx_gate.borrow_and_update().clone(); admission_poll_ms = cfg.general.me_admission_poll_ms.max(1); + fallback_enabled = cfg.general.me2dc_fallback; continue; } _ = tokio::time::sleep(Duration::from_millis(admission_poll_ms)) => {} } let ready = pool_for_gate.admission_ready_conditional_cast().await; - if ready { - open_streak = open_streak.saturating_add(1); - close_streak = 0; - if !gate_open && open_streak >= 2 { - gate_open = true; - admission_tx_gate.send_replace(true); - info!( - open_streak, - "Conditional-admission gate opened (ME pool recovered)" - ); - } + let now = Instant::now(); + let (next_gate_open, next_route_mode, next_fallback_active) = if ready { + not_ready_since = None; + (true, RelayRouteMode::Middle, false) } else { - close_streak = close_streak.saturating_add(1); - open_streak = 0; - if gate_open && close_streak >= 2 { - gate_open = false; - admission_tx_gate.send_replace(false); - warn!( - close_streak, - "Conditional-admission gate closed (ME pool has uncovered DC groups)" - ); + let not_ready_started_at = *not_ready_since.get_or_insert(now); + let not_ready_for = now.saturating_duration_since(not_ready_started_at); + if fallback_enabled && not_ready_for > fallback_after { + (true, RelayRouteMode::Direct, true) + } else { + (false, RelayRouteMode::Middle, false) + } + }; + + if next_route_mode != route_mode { + route_mode = next_route_mode; + if let Some(snapshot) = route_runtime_gate.set_mode(route_mode) { + if matches!(route_mode, RelayRouteMode::Middle) { + info!( + target_mode = route_mode.as_str(), + cutover_generation = snapshot.generation, + "Middle-End routing restored for new sessions" + ); + } else { + warn!( + target_mode = route_mode.as_str(), + cutover_generation = snapshot.generation, + grace_secs = fallback_after.as_secs(), + "ME pool stayed not-ready beyond grace; routing new sessions via Direct-DC" + ); + } } } + + if next_gate_open != gate_open { + gate_open = next_gate_open; + admission_tx_gate.send_replace(gate_open); + if gate_open { + if next_fallback_active { + warn!("Conditional-admission gate opened in ME fallback mode"); + } else { + info!("Conditional-admission gate opened (ME pool ready)"); + } + } else { + warn!("Conditional-admission gate closed (ME pool is not ready)"); + } + } + } }); } else { admission_tx.send_replace(false); + let _ = route_runtime.set_mode(RelayRouteMode::Direct); warn!("Conditional-admission gate: closed (ME pool is unavailable)"); } } else { admission_tx.send_replace(true); + let _ = route_runtime.set_mode(RelayRouteMode::Direct); } let _admission_tx_hold = admission_tx; @@ -1886,6 +1933,7 @@ async fn main() -> std::result::Result<(), Box> { let buffer_pool = buffer_pool.clone(); let rng = rng.clone(); let me_pool = me_pool.clone(); + let route_runtime = route_runtime.clone(); let tls_cache = tls_cache.clone(); let ip_tracker = ip_tracker.clone(); let beobachten = beobachten.clone(); @@ -1918,6 +1966,7 @@ async fn main() -> std::result::Result<(), Box> { let buffer_pool = buffer_pool.clone(); let rng = rng.clone(); let me_pool = me_pool.clone(); + let route_runtime = route_runtime.clone(); let tls_cache = tls_cache.clone(); let ip_tracker = ip_tracker.clone(); let beobachten = beobachten.clone(); @@ -1928,7 +1977,7 @@ async fn main() -> std::result::Result<(), Box> { if let Err(e) = crate::proxy::client::handle_client_stream( stream, fake_peer, config, stats, upstream_manager, replay_checker, buffer_pool, rng, - me_pool, tls_cache, ip_tracker, beobachten, proxy_protocol_enabled, + me_pool, route_runtime, tls_cache, ip_tracker, beobachten, proxy_protocol_enabled, ).await { debug!(error = %e, "Unix socket connection error"); } @@ -2039,6 +2088,7 @@ async fn main() -> std::result::Result<(), Box> { let buffer_pool = buffer_pool.clone(); let rng = rng.clone(); let me_pool = me_pool.clone(); + let route_runtime = route_runtime.clone(); let tls_cache = tls_cache.clone(); let ip_tracker = ip_tracker.clone(); let beobachten = beobachten.clone(); @@ -2066,6 +2116,7 @@ async fn main() -> std::result::Result<(), Box> { let buffer_pool = buffer_pool.clone(); let rng = rng.clone(); let me_pool = me_pool.clone(); + let route_runtime = route_runtime.clone(); let tls_cache = tls_cache.clone(); let ip_tracker = ip_tracker.clone(); let beobachten = beobachten.clone(); @@ -2083,6 +2134,7 @@ async fn main() -> std::result::Result<(), Box> { buffer_pool, rng, me_pool, + route_runtime, tls_cache, ip_tracker, beobachten, @@ -2119,10 +2171,20 @@ async fn main() -> std::result::Result<(), Box> { &e, crate::error::ProxyError::Proxy(msg) if msg == "ME connection lost" ); + let route_switched = matches!( + &e, + crate::error::ProxyError::Proxy(msg) if msg == ROUTE_SWITCH_ERROR_MSG + ); match (peer_closed, me_closed) { (true, _) => debug!(peer = %peer_addr, error = %e, "Connection closed by client"), (_, true) => warn!(peer = %peer_addr, error = %e, "Connection closed: Middle-End dropped session"), + _ if route_switched => { + info!(peer = %peer_addr, error = %e, "Connection closed by controlled route cutover") + } + _ if is_expected_handshake_eof(&e) => { + info!(peer = %peer_addr, error = %e, "Connection closed during initial handshake") + } _ => warn!(peer = %peer_addr, error = %e, "Connection closed with error"), } } diff --git a/src/proxy/client.rs b/src/proxy/client.rs index ebfabcb..cbe59ce 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -39,6 +39,7 @@ use crate::proxy::direct_relay::handle_via_direct; use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake}; use crate::proxy::masking::handle_bad_client; use crate::proxy::middle_relay::handle_via_middle_proxy; +use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; fn beobachten_ttl(config: &ProxyConfig) -> Duration { Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60)) @@ -80,6 +81,7 @@ pub async fn handle_client_stream( buffer_pool: Arc, rng: Arc, me_pool: Option>, + route_runtime: Arc, tls_cache: Option>, ip_tracker: Arc, beobachten: Arc, @@ -214,6 +216,7 @@ where RunningClientHandler::handle_authenticated_static( crypto_reader, crypto_writer, success, upstream_manager, stats, config, buffer_pool, rng, me_pool, + route_runtime.clone(), local_addr, real_peer, ip_tracker.clone(), ), ))) @@ -274,6 +277,7 @@ where buffer_pool, rng, me_pool, + route_runtime.clone(), local_addr, real_peer, ip_tracker.clone(), @@ -324,6 +328,7 @@ pub struct RunningClientHandler { buffer_pool: Arc, rng: Arc, me_pool: Option>, + route_runtime: Arc, tls_cache: Option>, ip_tracker: Arc, beobachten: Arc, @@ -341,6 +346,7 @@ impl ClientHandler { buffer_pool: Arc, rng: Arc, me_pool: Option>, + route_runtime: Arc, tls_cache: Option>, ip_tracker: Arc, beobachten: Arc, @@ -356,6 +362,7 @@ impl ClientHandler { buffer_pool, rng, me_pool, + route_runtime, tls_cache, ip_tracker, beobachten, @@ -597,6 +604,7 @@ impl RunningClientHandler { buffer_pool, self.rng, self.me_pool, + self.route_runtime.clone(), local_addr, peer, self.ip_tracker, @@ -677,6 +685,7 @@ impl RunningClientHandler { buffer_pool, self.rng, self.me_pool, + self.route_runtime.clone(), local_addr, peer, self.ip_tracker, @@ -698,6 +707,7 @@ impl RunningClientHandler { buffer_pool: Arc, rng: Arc, me_pool: Option>, + route_runtime: Arc, local_addr: SocketAddr, peer_addr: SocketAddr, ip_tracker: Arc, @@ -713,7 +723,11 @@ impl RunningClientHandler { return Err(e); } - let relay_result = if config.general.use_middle_proxy { + let route_snapshot = route_runtime.snapshot(); + let session_id = rng.u64(); + let relay_result = if config.general.use_middle_proxy + && matches!(route_snapshot.mode, RelayRouteMode::Middle) + { if let Some(ref pool) = me_pool { handle_via_middle_proxy( client_reader, @@ -725,6 +739,9 @@ impl RunningClientHandler { buffer_pool, local_addr, rng, + route_runtime.subscribe(), + route_snapshot, + session_id, ) .await } else { @@ -738,6 +755,9 @@ impl RunningClientHandler { config, buffer_pool, rng, + route_runtime.subscribe(), + route_snapshot, + session_id, ) .await } @@ -752,6 +772,9 @@ impl RunningClientHandler { config, buffer_pool, rng, + route_runtime.subscribe(), + route_snapshot, + session_id, ) .await }; diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index d4b0f2e..7a7810a 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -5,14 +5,19 @@ use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; +use tokio::sync::watch; use tracing::{debug, info, warn}; use crate::config::ProxyConfig; use crate::crypto::SecureRandom; -use crate::error::Result; +use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; use crate::proxy::handshake::{HandshakeSuccess, encrypt_tg_nonce_with_ciphers, generate_tg_nonce}; use crate::proxy::relay::relay_bidirectional; +use crate::proxy::route_mode::{ + RelayRouteMode, RouteCutoverState, ROUTE_SWITCH_ERROR_MSG, affected_cutover_state, + cutover_stagger_delay, +}; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::UpstreamManager; @@ -26,6 +31,9 @@ pub(crate) async fn handle_via_direct( config: Arc, buffer_pool: Arc, rng: Arc, + mut route_rx: watch::Receiver, + route_snapshot: RouteCutoverState, + session_id: u64, ) -> Result<()> where R: AsyncRead + Unpin + Send + 'static, @@ -69,8 +77,36 @@ where user, Arc::clone(&stats), buffer_pool, - ) - .await; + ); + tokio::pin!(relay_result); + let relay_result = loop { + if let Some(cutover) = affected_cutover_state( + &route_rx, + RelayRouteMode::Direct, + route_snapshot.generation, + ) { + let delay = cutover_stagger_delay(session_id, cutover.generation); + warn!( + user = %user, + target_mode = cutover.mode.as_str(), + cutover_generation = cutover.generation, + delay_ms = delay.as_millis() as u64, + "Cutover affected direct session, closing client connection" + ); + tokio::time::sleep(delay).await; + break Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); + } + tokio::select! { + result = &mut relay_result => { + break result; + } + changed = route_rx.changed() => { + if changed.is_err() { + break relay_result.await; + } + } + } + }; stats.decrement_current_connections_direct(); stats.decrement_user_curr_connects(user); diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 0006914..efaa8ba 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -8,7 +8,7 @@ use std::time::{Duration, Instant}; use bytes::Bytes; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::{mpsc, oneshot, watch}; use tracing::{debug, trace, warn}; use crate::config::ProxyConfig; @@ -16,6 +16,10 @@ use crate::crypto::SecureRandom; use crate::error::{ProxyError, Result}; use crate::protocol::constants::{*, secure_padding_len}; use crate::proxy::handshake::HandshakeSuccess; +use crate::proxy::route_mode::{ + RelayRouteMode, RouteCutoverState, ROUTE_SWITCH_ERROR_MSG, affected_cutover_state, + cutover_stagger_delay, +}; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; @@ -228,6 +232,9 @@ pub(crate) async fn handle_via_middle_proxy( _buffer_pool: Arc, local_addr: SocketAddr, rng: Arc, + mut route_rx: watch::Receiver, + route_snapshot: RouteCutoverState, + session_id: u64, ) -> Result<()> where R: AsyncRead + Unpin + Send + 'static, @@ -267,6 +274,27 @@ where stats.increment_user_curr_connects(&user); stats.increment_current_connections_me(); + if let Some(cutover) = affected_cutover_state( + &route_rx, + RelayRouteMode::Middle, + route_snapshot.generation, + ) { + let delay = cutover_stagger_delay(session_id, cutover.generation); + warn!( + conn_id, + target_mode = cutover.mode.as_str(), + cutover_generation = cutover.generation, + delay_ms = delay.as_millis() as u64, + "Cutover affected middle session before relay start, closing client connection" + ); + tokio::time::sleep(delay).await; + let _ = me_pool.send_close(conn_id).await; + me_pool.registry().unregister(conn_id).await; + stats.decrement_current_connections_me(); + stats.decrement_user_curr_connects(&user); + return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); + } + // Per-user ad_tag from access.user_ad_tags; fallback to general.ad_tag (hot-reloadable) let user_tag: Option> = config .access @@ -498,46 +526,75 @@ where let mut main_result: Result<()> = Ok(()); let mut client_closed = false; let mut frame_counter: u64 = 0; + let mut route_watch_open = true; loop { - match read_client_payload( - &mut crypto_reader, - proto_tag, - frame_limit, - &forensics, - &mut frame_counter, - &stats, - ).await { - Ok(Some((payload, quickack))) => { - trace!(conn_id, bytes = payload.len(), "C->ME frame"); - forensics.bytes_c2me = forensics - .bytes_c2me - .saturating_add(payload.len() as u64); - stats.add_user_octets_from(&user, payload.len() as u64); - let mut flags = proto_flags; - if quickack { - flags |= RPC_FLAG_QUICKACK; - } - if payload.len() >= 8 && payload[..8].iter().all(|b| *b == 0) { - flags |= RPC_FLAG_NOT_ENCRYPTED; - } - // Keep client read loop lightweight: route heavy ME send path via a dedicated task. - if enqueue_c2me_command(&c2me_tx, C2MeCommand::Data { payload, flags }) - .await - .is_err() - { - main_result = Err(ProxyError::Proxy("ME sender channel closed".into())); - break; + if let Some(cutover) = affected_cutover_state( + &route_rx, + RelayRouteMode::Middle, + route_snapshot.generation, + ) { + let delay = cutover_stagger_delay(session_id, cutover.generation); + warn!( + conn_id, + target_mode = cutover.mode.as_str(), + cutover_generation = cutover.generation, + delay_ms = delay.as_millis() as u64, + "Cutover affected middle session, closing client connection" + ); + tokio::time::sleep(delay).await; + let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await; + main_result = Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); + break; + } + + tokio::select! { + changed = route_rx.changed(), if route_watch_open => { + if changed.is_err() { + route_watch_open = false; } } - Ok(None) => { - debug!(conn_id, "Client EOF"); - client_closed = true; - let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await; - break; - } - Err(e) => { - main_result = Err(e); - break; + payload_result = read_client_payload( + &mut crypto_reader, + proto_tag, + frame_limit, + &forensics, + &mut frame_counter, + &stats, + ) => { + match payload_result { + Ok(Some((payload, quickack))) => { + trace!(conn_id, bytes = payload.len(), "C->ME frame"); + forensics.bytes_c2me = forensics + .bytes_c2me + .saturating_add(payload.len() as u64); + stats.add_user_octets_from(&user, payload.len() as u64); + let mut flags = proto_flags; + if quickack { + flags |= RPC_FLAG_QUICKACK; + } + if payload.len() >= 8 && payload[..8].iter().all(|b| *b == 0) { + flags |= RPC_FLAG_NOT_ENCRYPTED; + } + // Keep client read loop lightweight: route heavy ME send path via a dedicated task. + if enqueue_c2me_command(&c2me_tx, C2MeCommand::Data { payload, flags }) + .await + .is_err() + { + main_result = Err(ProxyError::Proxy("ME sender channel closed".into())); + break; + } + } + Ok(None) => { + debug!(conn_id, "Client EOF"); + client_closed = true; + let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await; + break; + } + Err(e) => { + main_result = Err(e); + break; + } + } } } } diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index bedae1a..1eed469 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -5,6 +5,7 @@ pub mod direct_relay; pub mod handshake; pub mod masking; pub mod middle_relay; +pub mod route_mode; pub mod relay; pub use client::ClientHandler; diff --git a/src/proxy/route_mode.rs b/src/proxy/route_mode.rs new file mode 100644 index 0000000..57830ca --- /dev/null +++ b/src/proxy/route_mode.rs @@ -0,0 +1,117 @@ +use std::sync::Arc; +use std::sync::atomic::{AtomicU8, AtomicU64, Ordering}; +use std::time::Duration; + +use tokio::sync::watch; + +pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Route mode switched by cutover"; + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[repr(u8)] +pub(crate) enum RelayRouteMode { + Direct = 0, + Middle = 1, +} + +impl RelayRouteMode { + pub(crate) fn as_u8(self) -> u8 { + self as u8 + } + + pub(crate) fn from_u8(value: u8) -> Self { + match value { + 1 => Self::Middle, + _ => Self::Direct, + } + } + + pub(crate) fn as_str(self) -> &'static str { + match self { + Self::Direct => "direct", + Self::Middle => "middle", + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) struct RouteCutoverState { + pub mode: RelayRouteMode, + pub generation: u64, +} + +#[derive(Clone)] +pub(crate) struct RouteRuntimeController { + mode: Arc, + generation: Arc, + tx: watch::Sender, +} + +impl RouteRuntimeController { + pub(crate) fn new(initial_mode: RelayRouteMode) -> Self { + let initial = RouteCutoverState { + mode: initial_mode, + generation: 0, + }; + let (tx, _rx) = watch::channel(initial); + Self { + mode: Arc::new(AtomicU8::new(initial_mode.as_u8())), + generation: Arc::new(AtomicU64::new(0)), + tx, + } + } + + pub(crate) fn snapshot(&self) -> RouteCutoverState { + RouteCutoverState { + mode: RelayRouteMode::from_u8(self.mode.load(Ordering::Relaxed)), + generation: self.generation.load(Ordering::Relaxed), + } + } + + pub(crate) fn subscribe(&self) -> watch::Receiver { + self.tx.subscribe() + } + + pub(crate) fn set_mode(&self, mode: RelayRouteMode) -> Option { + let previous = self.mode.swap(mode.as_u8(), Ordering::Relaxed); + if previous == mode.as_u8() { + return None; + } + let generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1; + let next = RouteCutoverState { mode, generation }; + self.tx.send_replace(next); + Some(next) + } +} + +pub(crate) fn is_session_affected_by_cutover( + current: RouteCutoverState, + _session_mode: RelayRouteMode, + session_generation: u64, +) -> bool { + current.generation > session_generation +} + +pub(crate) fn affected_cutover_state( + rx: &watch::Receiver, + session_mode: RelayRouteMode, + session_generation: u64, +) -> Option { + let current = *rx.borrow(); + if is_session_affected_by_cutover(current, session_mode, session_generation) { + return Some(current); + } + None +} + +pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duration { + let mut value = session_id + ^ generation.rotate_left(17) + ^ 0x9e37_79b9_7f4a_7c15; + value ^= value >> 30; + value = value.wrapping_mul(0xbf58_476d_1ce4_e5b9); + value ^= value >> 27; + value = value.wrapping_mul(0x94d0_49bb_1331_11eb); + value ^= value >> 31; + let ms = 1000 + (value % 1000); + Duration::from_millis(ms) +} diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 8d5b110..2d81e63 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -828,10 +828,29 @@ impl MePool { effective } + // Keeps per-contour (active/warm) writer budget bounded by CPU count. + // Baseline is 86 writers on the first core and +48 for each extra core. + fn adaptive_floor_cpu_budget_per_contour_cap(&self, cores: usize) -> usize { + const FIRST_CORE_WRITER_BUDGET: usize = 86; + const EXTRA_CORE_WRITER_BUDGET: usize = 48; + if cores == 0 { + return FIRST_CORE_WRITER_BUDGET; + } + FIRST_CORE_WRITER_BUDGET.saturating_add( + cores + .saturating_sub(1) + .saturating_mul(EXTRA_CORE_WRITER_BUDGET), + ) + } + pub(super) fn adaptive_floor_active_cap_configured_total(&self) -> usize { let cores = self.adaptive_floor_effective_cpu_cores(); - let per_core_cap = cores.saturating_mul(self.adaptive_floor_max_active_writers_per_core()); - let configured = per_core_cap.min(self.adaptive_floor_max_active_writers_global()); + let per_contour_budget = self.adaptive_floor_cpu_budget_per_contour_cap(cores); + let configured = cores + .saturating_mul(self.adaptive_floor_max_active_writers_per_core()) + .min(self.adaptive_floor_max_active_writers_global()) + .min(per_contour_budget) + .max(1); self.me_adaptive_floor_active_cap_configured .store(configured as u64, Ordering::Relaxed); self.stats @@ -841,8 +860,12 @@ impl MePool { pub(super) fn adaptive_floor_warm_cap_configured_total(&self) -> usize { let cores = self.adaptive_floor_effective_cpu_cores(); - let per_core_cap = cores.saturating_mul(self.adaptive_floor_max_warm_writers_per_core()); - let configured = per_core_cap.min(self.adaptive_floor_max_warm_writers_global()); + let per_contour_budget = self.adaptive_floor_cpu_budget_per_contour_cap(cores); + let configured = cores + .saturating_mul(self.adaptive_floor_max_warm_writers_per_core()) + .min(self.adaptive_floor_max_warm_writers_global()) + .min(per_contour_budget) + .max(1); self.me_adaptive_floor_warm_cap_configured .store(configured as u64, Ordering::Relaxed); self.stats