ME/DC Reroute + ME Upper-limit tuning

This commit is contained in:
Alexey 2026-03-09 00:53:47 +03:00
parent fc52cad109
commit ef2ed3daa0
No known key found for this signature in database
7 changed files with 392 additions and 73 deletions

View File

@ -37,6 +37,7 @@ use crate::crypto::SecureRandom;
use crate::ip_tracker::UserIpTracker; use crate::ip_tracker::UserIpTracker;
use crate::network::probe::{decide_network_capabilities, log_probe_result, run_probe}; use crate::network::probe::{decide_network_capabilities, log_probe_result, run_probe};
use crate::proxy::ClientHandler; use crate::proxy::ClientHandler;
use crate::proxy::route_mode::{ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteRuntimeController};
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::stats::telemetry::TelemetryPolicy; use crate::stats::telemetry::TelemetryPolicy;
use crate::stats::{ReplayChecker, Stats}; use crate::stats::{ReplayChecker, Stats};
@ -261,6 +262,10 @@ async fn wait_until_admission_open(admission_rx: &mut watch::Receiver<bool>) ->
} }
} }
// Heuristic: treats a zero-byte read during the 64-byte obfuscated-handshake
// prelude as a benign client disconnect (port scan / probe) rather than a real
// protocol error, so the caller can log it at a lower severity.
// NOTE(review): matches on the error's Display text ("expected 64 bytes, got 0"),
// which is brittle — confirm against the ProxyError variant that produces it.
fn is_expected_handshake_eof(err: &crate::error::ProxyError) -> bool {
    err.to_string().contains("expected 64 bytes, got 0")
}
async fn load_startup_proxy_config_snapshot( async fn load_startup_proxy_config_snapshot(
url: &str, url: &str,
cache_path: Option<&str>, cache_path: Option<&str>,
@ -519,6 +524,12 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let (api_config_tx, api_config_rx) = watch::channel(Arc::new(config.clone())); let (api_config_tx, api_config_rx) = watch::channel(Arc::new(config.clone()));
let initial_admission_open = !config.general.use_middle_proxy; let initial_admission_open = !config.general.use_middle_proxy;
let (admission_tx, admission_rx) = watch::channel(initial_admission_open); let (admission_tx, admission_rx) = watch::channel(initial_admission_open);
let initial_route_mode = if config.general.use_middle_proxy {
RelayRouteMode::Middle
} else {
RelayRouteMode::Direct
};
let route_runtime = Arc::new(RouteRuntimeController::new(initial_route_mode));
let api_me_pool = Arc::new(RwLock::new(None::<Arc<MePool>>)); let api_me_pool = Arc::new(RwLock::new(None::<Arc<MePool>>));
startup_tracker startup_tracker
.start_component(COMPONENT_API_BOOTSTRAP, Some("spawn API listener task".to_string())) .start_component(COMPONENT_API_BOOTSTRAP, Some("spawn API listener task".to_string()))
@ -1783,9 +1794,11 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
if config.general.use_middle_proxy { if config.general.use_middle_proxy {
if let Some(pool) = me_pool.as_ref() { if let Some(pool) = me_pool.as_ref() {
let initial_open = pool.admission_ready_conditional_cast().await; let fallback_after = Duration::from_secs(6);
admission_tx.send_replace(initial_open); let initial_ready = pool.admission_ready_conditional_cast().await;
if initial_open { admission_tx.send_replace(initial_ready);
let _ = route_runtime.set_mode(RelayRouteMode::Middle);
if initial_ready {
info!("Conditional-admission gate: open (ME pool ready)"); info!("Conditional-admission gate: open (ME pool ready)");
} else { } else {
warn!("Conditional-admission gate: closed (ME pool is not ready)"); warn!("Conditional-admission gate: closed (ME pool is not ready)");
@ -1793,12 +1806,18 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let pool_for_gate = pool.clone(); let pool_for_gate = pool.clone();
let admission_tx_gate = admission_tx.clone(); let admission_tx_gate = admission_tx.clone();
let route_runtime_gate = route_runtime.clone();
let mut config_rx_gate = config_rx.clone(); let mut config_rx_gate = config_rx.clone();
let mut admission_poll_ms = config.general.me_admission_poll_ms.max(1); let mut admission_poll_ms = config.general.me_admission_poll_ms.max(1);
let mut fallback_enabled = config.general.me2dc_fallback;
tokio::spawn(async move { tokio::spawn(async move {
let mut gate_open = initial_open; let mut gate_open = initial_ready;
let mut open_streak = if initial_open { 1u32 } else { 0u32 }; let mut route_mode = RelayRouteMode::Middle;
let mut close_streak = if initial_open { 0u32 } else { 1u32 }; let mut not_ready_since = if initial_ready {
None
} else {
Some(Instant::now())
};
loop { loop {
tokio::select! { tokio::select! {
changed = config_rx_gate.changed() => { changed = config_rx_gate.changed() => {
@ -1807,42 +1826,70 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
} }
let cfg = config_rx_gate.borrow_and_update().clone(); let cfg = config_rx_gate.borrow_and_update().clone();
admission_poll_ms = cfg.general.me_admission_poll_ms.max(1); admission_poll_ms = cfg.general.me_admission_poll_ms.max(1);
fallback_enabled = cfg.general.me2dc_fallback;
continue; continue;
} }
_ = tokio::time::sleep(Duration::from_millis(admission_poll_ms)) => {} _ = tokio::time::sleep(Duration::from_millis(admission_poll_ms)) => {}
} }
let ready = pool_for_gate.admission_ready_conditional_cast().await; let ready = pool_for_gate.admission_ready_conditional_cast().await;
if ready { let now = Instant::now();
open_streak = open_streak.saturating_add(1); let (next_gate_open, next_route_mode, next_fallback_active) = if ready {
close_streak = 0; not_ready_since = None;
if !gate_open && open_streak >= 2 { (true, RelayRouteMode::Middle, false)
gate_open = true;
admission_tx_gate.send_replace(true);
info!(
open_streak,
"Conditional-admission gate opened (ME pool recovered)"
);
}
} else { } else {
close_streak = close_streak.saturating_add(1); let not_ready_started_at = *not_ready_since.get_or_insert(now);
open_streak = 0; let not_ready_for = now.saturating_duration_since(not_ready_started_at);
if gate_open && close_streak >= 2 { if fallback_enabled && not_ready_for > fallback_after {
gate_open = false; (true, RelayRouteMode::Direct, true)
admission_tx_gate.send_replace(false); } else {
warn!( (false, RelayRouteMode::Middle, false)
close_streak, }
"Conditional-admission gate closed (ME pool has uncovered DC groups)" };
);
if next_route_mode != route_mode {
route_mode = next_route_mode;
if let Some(snapshot) = route_runtime_gate.set_mode(route_mode) {
if matches!(route_mode, RelayRouteMode::Middle) {
info!(
target_mode = route_mode.as_str(),
cutover_generation = snapshot.generation,
"Middle-End routing restored for new sessions"
);
} else {
warn!(
target_mode = route_mode.as_str(),
cutover_generation = snapshot.generation,
grace_secs = fallback_after.as_secs(),
"ME pool stayed not-ready beyond grace; routing new sessions via Direct-DC"
);
}
} }
} }
if next_gate_open != gate_open {
gate_open = next_gate_open;
admission_tx_gate.send_replace(gate_open);
if gate_open {
if next_fallback_active {
warn!("Conditional-admission gate opened in ME fallback mode");
} else {
info!("Conditional-admission gate opened (ME pool ready)");
}
} else {
warn!("Conditional-admission gate closed (ME pool is not ready)");
}
}
} }
}); });
} else { } else {
admission_tx.send_replace(false); admission_tx.send_replace(false);
let _ = route_runtime.set_mode(RelayRouteMode::Direct);
warn!("Conditional-admission gate: closed (ME pool is unavailable)"); warn!("Conditional-admission gate: closed (ME pool is unavailable)");
} }
} else { } else {
admission_tx.send_replace(true); admission_tx.send_replace(true);
let _ = route_runtime.set_mode(RelayRouteMode::Direct);
} }
let _admission_tx_hold = admission_tx; let _admission_tx_hold = admission_tx;
@ -1886,6 +1933,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let buffer_pool = buffer_pool.clone(); let buffer_pool = buffer_pool.clone();
let rng = rng.clone(); let rng = rng.clone();
let me_pool = me_pool.clone(); let me_pool = me_pool.clone();
let route_runtime = route_runtime.clone();
let tls_cache = tls_cache.clone(); let tls_cache = tls_cache.clone();
let ip_tracker = ip_tracker.clone(); let ip_tracker = ip_tracker.clone();
let beobachten = beobachten.clone(); let beobachten = beobachten.clone();
@ -1918,6 +1966,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let buffer_pool = buffer_pool.clone(); let buffer_pool = buffer_pool.clone();
let rng = rng.clone(); let rng = rng.clone();
let me_pool = me_pool.clone(); let me_pool = me_pool.clone();
let route_runtime = route_runtime.clone();
let tls_cache = tls_cache.clone(); let tls_cache = tls_cache.clone();
let ip_tracker = ip_tracker.clone(); let ip_tracker = ip_tracker.clone();
let beobachten = beobachten.clone(); let beobachten = beobachten.clone();
@ -1928,7 +1977,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
if let Err(e) = crate::proxy::client::handle_client_stream( if let Err(e) = crate::proxy::client::handle_client_stream(
stream, fake_peer, config, stats, stream, fake_peer, config, stats,
upstream_manager, replay_checker, buffer_pool, rng, upstream_manager, replay_checker, buffer_pool, rng,
me_pool, tls_cache, ip_tracker, beobachten, proxy_protocol_enabled, me_pool, route_runtime, tls_cache, ip_tracker, beobachten, proxy_protocol_enabled,
).await { ).await {
debug!(error = %e, "Unix socket connection error"); debug!(error = %e, "Unix socket connection error");
} }
@ -2039,6 +2088,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let buffer_pool = buffer_pool.clone(); let buffer_pool = buffer_pool.clone();
let rng = rng.clone(); let rng = rng.clone();
let me_pool = me_pool.clone(); let me_pool = me_pool.clone();
let route_runtime = route_runtime.clone();
let tls_cache = tls_cache.clone(); let tls_cache = tls_cache.clone();
let ip_tracker = ip_tracker.clone(); let ip_tracker = ip_tracker.clone();
let beobachten = beobachten.clone(); let beobachten = beobachten.clone();
@ -2066,6 +2116,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let buffer_pool = buffer_pool.clone(); let buffer_pool = buffer_pool.clone();
let rng = rng.clone(); let rng = rng.clone();
let me_pool = me_pool.clone(); let me_pool = me_pool.clone();
let route_runtime = route_runtime.clone();
let tls_cache = tls_cache.clone(); let tls_cache = tls_cache.clone();
let ip_tracker = ip_tracker.clone(); let ip_tracker = ip_tracker.clone();
let beobachten = beobachten.clone(); let beobachten = beobachten.clone();
@ -2083,6 +2134,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
buffer_pool, buffer_pool,
rng, rng,
me_pool, me_pool,
route_runtime,
tls_cache, tls_cache,
ip_tracker, ip_tracker,
beobachten, beobachten,
@ -2119,10 +2171,20 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
&e, &e,
crate::error::ProxyError::Proxy(msg) if msg == "ME connection lost" crate::error::ProxyError::Proxy(msg) if msg == "ME connection lost"
); );
let route_switched = matches!(
&e,
crate::error::ProxyError::Proxy(msg) if msg == ROUTE_SWITCH_ERROR_MSG
);
match (peer_closed, me_closed) { match (peer_closed, me_closed) {
(true, _) => debug!(peer = %peer_addr, error = %e, "Connection closed by client"), (true, _) => debug!(peer = %peer_addr, error = %e, "Connection closed by client"),
(_, true) => warn!(peer = %peer_addr, error = %e, "Connection closed: Middle-End dropped session"), (_, true) => warn!(peer = %peer_addr, error = %e, "Connection closed: Middle-End dropped session"),
_ if route_switched => {
info!(peer = %peer_addr, error = %e, "Connection closed by controlled route cutover")
}
_ if is_expected_handshake_eof(&e) => {
info!(peer = %peer_addr, error = %e, "Connection closed during initial handshake")
}
_ => warn!(peer = %peer_addr, error = %e, "Connection closed with error"), _ => warn!(peer = %peer_addr, error = %e, "Connection closed with error"),
} }
} }

View File

@ -39,6 +39,7 @@ use crate::proxy::direct_relay::handle_via_direct;
use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake}; use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake};
use crate::proxy::masking::handle_bad_client; use crate::proxy::masking::handle_bad_client;
use crate::proxy::middle_relay::handle_via_middle_proxy; use crate::proxy::middle_relay::handle_via_middle_proxy;
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
fn beobachten_ttl(config: &ProxyConfig) -> Duration { fn beobachten_ttl(config: &ProxyConfig) -> Duration {
Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60)) Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60))
@ -80,6 +81,7 @@ pub async fn handle_client_stream<S>(
buffer_pool: Arc<BufferPool>, buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>, me_pool: Option<Arc<MePool>>,
route_runtime: Arc<RouteRuntimeController>,
tls_cache: Option<Arc<TlsFrontCache>>, tls_cache: Option<Arc<TlsFrontCache>>,
ip_tracker: Arc<UserIpTracker>, ip_tracker: Arc<UserIpTracker>,
beobachten: Arc<BeobachtenStore>, beobachten: Arc<BeobachtenStore>,
@ -214,6 +216,7 @@ where
RunningClientHandler::handle_authenticated_static( RunningClientHandler::handle_authenticated_static(
crypto_reader, crypto_writer, success, crypto_reader, crypto_writer, success,
upstream_manager, stats, config, buffer_pool, rng, me_pool, upstream_manager, stats, config, buffer_pool, rng, me_pool,
route_runtime.clone(),
local_addr, real_peer, ip_tracker.clone(), local_addr, real_peer, ip_tracker.clone(),
), ),
))) )))
@ -274,6 +277,7 @@ where
buffer_pool, buffer_pool,
rng, rng,
me_pool, me_pool,
route_runtime.clone(),
local_addr, local_addr,
real_peer, real_peer,
ip_tracker.clone(), ip_tracker.clone(),
@ -324,6 +328,7 @@ pub struct RunningClientHandler {
buffer_pool: Arc<BufferPool>, buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>, me_pool: Option<Arc<MePool>>,
route_runtime: Arc<RouteRuntimeController>,
tls_cache: Option<Arc<TlsFrontCache>>, tls_cache: Option<Arc<TlsFrontCache>>,
ip_tracker: Arc<UserIpTracker>, ip_tracker: Arc<UserIpTracker>,
beobachten: Arc<BeobachtenStore>, beobachten: Arc<BeobachtenStore>,
@ -341,6 +346,7 @@ impl ClientHandler {
buffer_pool: Arc<BufferPool>, buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>, me_pool: Option<Arc<MePool>>,
route_runtime: Arc<RouteRuntimeController>,
tls_cache: Option<Arc<TlsFrontCache>>, tls_cache: Option<Arc<TlsFrontCache>>,
ip_tracker: Arc<UserIpTracker>, ip_tracker: Arc<UserIpTracker>,
beobachten: Arc<BeobachtenStore>, beobachten: Arc<BeobachtenStore>,
@ -356,6 +362,7 @@ impl ClientHandler {
buffer_pool, buffer_pool,
rng, rng,
me_pool, me_pool,
route_runtime,
tls_cache, tls_cache,
ip_tracker, ip_tracker,
beobachten, beobachten,
@ -597,6 +604,7 @@ impl RunningClientHandler {
buffer_pool, buffer_pool,
self.rng, self.rng,
self.me_pool, self.me_pool,
self.route_runtime.clone(),
local_addr, local_addr,
peer, peer,
self.ip_tracker, self.ip_tracker,
@ -677,6 +685,7 @@ impl RunningClientHandler {
buffer_pool, buffer_pool,
self.rng, self.rng,
self.me_pool, self.me_pool,
self.route_runtime.clone(),
local_addr, local_addr,
peer, peer,
self.ip_tracker, self.ip_tracker,
@ -698,6 +707,7 @@ impl RunningClientHandler {
buffer_pool: Arc<BufferPool>, buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
me_pool: Option<Arc<MePool>>, me_pool: Option<Arc<MePool>>,
route_runtime: Arc<RouteRuntimeController>,
local_addr: SocketAddr, local_addr: SocketAddr,
peer_addr: SocketAddr, peer_addr: SocketAddr,
ip_tracker: Arc<UserIpTracker>, ip_tracker: Arc<UserIpTracker>,
@ -713,7 +723,11 @@ impl RunningClientHandler {
return Err(e); return Err(e);
} }
let relay_result = if config.general.use_middle_proxy { let route_snapshot = route_runtime.snapshot();
let session_id = rng.u64();
let relay_result = if config.general.use_middle_proxy
&& matches!(route_snapshot.mode, RelayRouteMode::Middle)
{
if let Some(ref pool) = me_pool { if let Some(ref pool) = me_pool {
handle_via_middle_proxy( handle_via_middle_proxy(
client_reader, client_reader,
@ -725,6 +739,9 @@ impl RunningClientHandler {
buffer_pool, buffer_pool,
local_addr, local_addr,
rng, rng,
route_runtime.subscribe(),
route_snapshot,
session_id,
) )
.await .await
} else { } else {
@ -738,6 +755,9 @@ impl RunningClientHandler {
config, config,
buffer_pool, buffer_pool,
rng, rng,
route_runtime.subscribe(),
route_snapshot,
session_id,
) )
.await .await
} }
@ -752,6 +772,9 @@ impl RunningClientHandler {
config, config,
buffer_pool, buffer_pool,
rng, rng,
route_runtime.subscribe(),
route_snapshot,
session_id,
) )
.await .await
}; };

View File

@ -5,14 +5,19 @@ use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::watch;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::error::Result; use crate::error::{ProxyError, Result};
use crate::protocol::constants::*; use crate::protocol::constants::*;
use crate::proxy::handshake::{HandshakeSuccess, encrypt_tg_nonce_with_ciphers, generate_tg_nonce}; use crate::proxy::handshake::{HandshakeSuccess, encrypt_tg_nonce_with_ciphers, generate_tg_nonce};
use crate::proxy::relay::relay_bidirectional; use crate::proxy::relay::relay_bidirectional;
use crate::proxy::route_mode::{
RelayRouteMode, RouteCutoverState, ROUTE_SWITCH_ERROR_MSG, affected_cutover_state,
cutover_stagger_delay,
};
use crate::stats::Stats; use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::UpstreamManager; use crate::transport::UpstreamManager;
@ -26,6 +31,9 @@ pub(crate) async fn handle_via_direct<R, W>(
config: Arc<ProxyConfig>, config: Arc<ProxyConfig>,
buffer_pool: Arc<BufferPool>, buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
mut route_rx: watch::Receiver<RouteCutoverState>,
route_snapshot: RouteCutoverState,
session_id: u64,
) -> Result<()> ) -> Result<()>
where where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
@ -69,8 +77,36 @@ where
user, user,
Arc::clone(&stats), Arc::clone(&stats),
buffer_pool, buffer_pool,
) );
.await; tokio::pin!(relay_result);
let relay_result = loop {
if let Some(cutover) = affected_cutover_state(
&route_rx,
RelayRouteMode::Direct,
route_snapshot.generation,
) {
let delay = cutover_stagger_delay(session_id, cutover.generation);
warn!(
user = %user,
target_mode = cutover.mode.as_str(),
cutover_generation = cutover.generation,
delay_ms = delay.as_millis() as u64,
"Cutover affected direct session, closing client connection"
);
tokio::time::sleep(delay).await;
break Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string()));
}
tokio::select! {
result = &mut relay_result => {
break result;
}
changed = route_rx.changed() => {
if changed.is_err() {
break relay_result.await;
}
}
}
};
stats.decrement_current_connections_direct(); stats.decrement_current_connections_direct();
stats.decrement_user_curr_connects(user); stats.decrement_user_curr_connects(user);

View File

@ -8,7 +8,7 @@ use std::time::{Duration, Instant};
use bytes::Bytes; use bytes::Bytes;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot, watch};
use tracing::{debug, trace, warn}; use tracing::{debug, trace, warn};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
@ -16,6 +16,10 @@ use crate::crypto::SecureRandom;
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::protocol::constants::{*, secure_padding_len}; use crate::protocol::constants::{*, secure_padding_len};
use crate::proxy::handshake::HandshakeSuccess; use crate::proxy::handshake::HandshakeSuccess;
use crate::proxy::route_mode::{
RelayRouteMode, RouteCutoverState, ROUTE_SWITCH_ERROR_MSG, affected_cutover_state,
cutover_stagger_delay,
};
use crate::stats::Stats; use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
@ -228,6 +232,9 @@ pub(crate) async fn handle_via_middle_proxy<R, W>(
_buffer_pool: Arc<BufferPool>, _buffer_pool: Arc<BufferPool>,
local_addr: SocketAddr, local_addr: SocketAddr,
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
mut route_rx: watch::Receiver<RouteCutoverState>,
route_snapshot: RouteCutoverState,
session_id: u64,
) -> Result<()> ) -> Result<()>
where where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
@ -267,6 +274,27 @@ where
stats.increment_user_curr_connects(&user); stats.increment_user_curr_connects(&user);
stats.increment_current_connections_me(); stats.increment_current_connections_me();
if let Some(cutover) = affected_cutover_state(
&route_rx,
RelayRouteMode::Middle,
route_snapshot.generation,
) {
let delay = cutover_stagger_delay(session_id, cutover.generation);
warn!(
conn_id,
target_mode = cutover.mode.as_str(),
cutover_generation = cutover.generation,
delay_ms = delay.as_millis() as u64,
"Cutover affected middle session before relay start, closing client connection"
);
tokio::time::sleep(delay).await;
let _ = me_pool.send_close(conn_id).await;
me_pool.registry().unregister(conn_id).await;
stats.decrement_current_connections_me();
stats.decrement_user_curr_connects(&user);
return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string()));
}
// Per-user ad_tag from access.user_ad_tags; fallback to general.ad_tag (hot-reloadable) // Per-user ad_tag from access.user_ad_tags; fallback to general.ad_tag (hot-reloadable)
let user_tag: Option<Vec<u8>> = config let user_tag: Option<Vec<u8>> = config
.access .access
@ -498,46 +526,75 @@ where
let mut main_result: Result<()> = Ok(()); let mut main_result: Result<()> = Ok(());
let mut client_closed = false; let mut client_closed = false;
let mut frame_counter: u64 = 0; let mut frame_counter: u64 = 0;
let mut route_watch_open = true;
loop { loop {
match read_client_payload( if let Some(cutover) = affected_cutover_state(
&mut crypto_reader, &route_rx,
proto_tag, RelayRouteMode::Middle,
frame_limit, route_snapshot.generation,
&forensics, ) {
&mut frame_counter, let delay = cutover_stagger_delay(session_id, cutover.generation);
&stats, warn!(
).await { conn_id,
Ok(Some((payload, quickack))) => { target_mode = cutover.mode.as_str(),
trace!(conn_id, bytes = payload.len(), "C->ME frame"); cutover_generation = cutover.generation,
forensics.bytes_c2me = forensics delay_ms = delay.as_millis() as u64,
.bytes_c2me "Cutover affected middle session, closing client connection"
.saturating_add(payload.len() as u64); );
stats.add_user_octets_from(&user, payload.len() as u64); tokio::time::sleep(delay).await;
let mut flags = proto_flags; let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await;
if quickack { main_result = Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string()));
flags |= RPC_FLAG_QUICKACK; break;
} }
if payload.len() >= 8 && payload[..8].iter().all(|b| *b == 0) {
flags |= RPC_FLAG_NOT_ENCRYPTED; tokio::select! {
} changed = route_rx.changed(), if route_watch_open => {
// Keep client read loop lightweight: route heavy ME send path via a dedicated task. if changed.is_err() {
if enqueue_c2me_command(&c2me_tx, C2MeCommand::Data { payload, flags }) route_watch_open = false;
.await
.is_err()
{
main_result = Err(ProxyError::Proxy("ME sender channel closed".into()));
break;
} }
} }
Ok(None) => { payload_result = read_client_payload(
debug!(conn_id, "Client EOF"); &mut crypto_reader,
client_closed = true; proto_tag,
let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await; frame_limit,
break; &forensics,
} &mut frame_counter,
Err(e) => { &stats,
main_result = Err(e); ) => {
break; match payload_result {
Ok(Some((payload, quickack))) => {
trace!(conn_id, bytes = payload.len(), "C->ME frame");
forensics.bytes_c2me = forensics
.bytes_c2me
.saturating_add(payload.len() as u64);
stats.add_user_octets_from(&user, payload.len() as u64);
let mut flags = proto_flags;
if quickack {
flags |= RPC_FLAG_QUICKACK;
}
if payload.len() >= 8 && payload[..8].iter().all(|b| *b == 0) {
flags |= RPC_FLAG_NOT_ENCRYPTED;
}
// Keep client read loop lightweight: route heavy ME send path via a dedicated task.
if enqueue_c2me_command(&c2me_tx, C2MeCommand::Data { payload, flags })
.await
.is_err()
{
main_result = Err(ProxyError::Proxy("ME sender channel closed".into()));
break;
}
}
Ok(None) => {
debug!(conn_id, "Client EOF");
client_closed = true;
let _ = enqueue_c2me_command(&c2me_tx, C2MeCommand::Close).await;
break;
}
Err(e) => {
main_result = Err(e);
break;
}
}
} }
} }
} }

View File

@ -5,6 +5,7 @@ pub mod direct_relay;
pub mod handshake; pub mod handshake;
pub mod masking; pub mod masking;
pub mod middle_relay; pub mod middle_relay;
pub mod route_mode;
pub mod relay; pub mod relay;
pub use client::ClientHandler; pub use client::ClientHandler;

117
src/proxy/route_mode.rs Normal file
View File

@ -0,0 +1,117 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
use std::time::Duration;
use tokio::sync::watch;
/// Sentinel error text used to recognize sessions that were intentionally
/// closed by a route-mode cutover; callers compare the `ProxyError::Proxy`
/// payload against this exact string, so it must stay byte-for-byte stable.
pub(crate) const ROUTE_SWITCH_ERROR_MSG: &str = "Route mode switched by cutover";
/// Which upstream path new relay sessions should take: straight to the
/// datacenter (`Direct`) or through the Middle-End proxy pool (`Middle`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub(crate) enum RelayRouteMode {
    Direct = 0,
    Middle = 1,
}

impl RelayRouteMode {
    /// Compact representation used for the atomic mode cell.
    pub(crate) fn as_u8(self) -> u8 {
        self as u8
    }

    /// Inverse of `as_u8`; any unknown value conservatively decodes to
    /// `Direct`, never to `Middle`.
    pub(crate) fn from_u8(value: u8) -> Self {
        if value == 1 { Self::Middle } else { Self::Direct }
    }

    /// Lowercase label used in structured log fields.
    pub(crate) fn as_str(self) -> &'static str {
        if matches!(self, Self::Middle) { "middle" } else { "direct" }
    }
}
/// Immutable (mode, generation) pair published through the route watch channel.
/// `generation` increments on every successful mode switch; sessions compare it
/// against the generation they captured at start to detect a cutover.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) struct RouteCutoverState {
    // Route selected for new sessions at this generation.
    pub mode: RelayRouteMode,
    // Monotonic cutover counter; starts at 0 for the initial mode.
    pub generation: u64,
}
/// Shared runtime switch for the relay route (Direct vs Middle). Cloning the
/// controller yields handles over the same underlying state (Arc'd atomics and
/// a shared watch channel).
#[derive(Clone)]
pub(crate) struct RouteRuntimeController {
    // Current mode encoded via RelayRouteMode::as_u8/from_u8.
    // NOTE(review): the Arc wrappers look redundant — main.rs already shares
    // the controller as Arc<RouteRuntimeController>; confirm before removing.
    mode: Arc<AtomicU8>,
    // Monotonic cutover generation counter (0 = initial state).
    generation: Arc<AtomicU64>,
    // Publishes RouteCutoverState updates to per-session subscribers.
    tx: watch::Sender<RouteCutoverState>,
}
impl RouteRuntimeController {
    /// Builds a controller whose initial state is `initial_mode` at generation 0.
    pub(crate) fn new(initial_mode: RelayRouteMode) -> Self {
        let initial = RouteCutoverState {
            mode: initial_mode,
            generation: 0,
        };
        let (tx, _rx) = watch::channel(initial);
        Self {
            mode: Arc::new(AtomicU8::new(initial_mode.as_u8())),
            generation: Arc::new(AtomicU64::new(0)),
            tx,
        }
    }

    /// Returns the latest published (mode, generation) pair.
    ///
    /// Reads the watch channel's value rather than the two atomics separately,
    /// so a concurrent `set_mode` can never yield a torn snapshot (new
    /// generation paired with the old mode, or vice versa).
    pub(crate) fn snapshot(&self) -> RouteCutoverState {
        *self.tx.borrow()
    }

    /// Subscribes a session to future cutover announcements.
    pub(crate) fn subscribe(&self) -> watch::Receiver<RouteCutoverState> {
        self.tx.subscribe()
    }

    /// Switches the route mode. Returns the newly published state when the
    /// mode actually changed, `None` when `mode` was already current.
    ///
    /// The compare-and-bump runs inside `send_if_modified`, which executes the
    /// closure under the channel's internal lock: concurrent callers are
    /// serialized, so generations are strictly increasing and are published in
    /// order. (The previous swap + fetch_add + send_replace sequence could
    /// interleave and publish a stale mode/generation last.)
    pub(crate) fn set_mode(&self, mode: RelayRouteMode) -> Option<RouteCutoverState> {
        let mut changed_to = None;
        self.tx.send_if_modified(|state| {
            if state.mode == mode {
                return false; // no change, no notification
            }
            let next = RouteCutoverState {
                mode,
                generation: state.generation + 1,
            };
            // Mirror into the atomics so any reader of those fields stays
            // consistent with the published state.
            self.mode.store(mode.as_u8(), Ordering::Relaxed);
            self.generation.store(next.generation, Ordering::Relaxed);
            *state = next;
            changed_to = Some(next);
            true
        });
        changed_to
    }
}
/// True when at least one mode switch was published after the session captured
/// its snapshot, i.e. the current generation has moved past the session's.
///
/// `_session_mode` is not consulted: a generation bump invalidates sessions on
/// either route (callers still pass their mode, keeping the signature stable).
pub(crate) fn is_session_affected_by_cutover(
    current: RouteCutoverState,
    _session_mode: RelayRouteMode,
    session_generation: u64,
) -> bool {
    session_generation < current.generation
}
/// Non-blocking peek at the latest published cutover state.
///
/// Returns `Some(state)` when the session (identified by the generation it
/// captured at start) has been invalidated by a later cutover, `None` when the
/// session may keep running.
pub(crate) fn affected_cutover_state(
    rx: &watch::Receiver<RouteCutoverState>,
    session_mode: RelayRouteMode,
    session_generation: u64,
) -> Option<RouteCutoverState> {
    let latest = *rx.borrow();
    is_session_affected_by_cutover(latest, session_mode, session_generation).then_some(latest)
}
/// Deterministic per-session jitter in [1000 ms, 2000 ms) applied before
/// closing a session during a cutover, so affected connections do not all
/// drop at the same instant.
///
/// Mixes `session_id` and `generation` through a SplitMix64-style finalizer;
/// identical inputs always produce the identical delay.
pub(crate) fn cutover_stagger_delay(session_id: u64, generation: u64) -> Duration {
    // Avalanche the seed so nearby ids spread across the whole jitter window.
    fn mix(mut x: u64) -> u64 {
        x = (x ^ (x >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
        x = (x ^ (x >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
        x ^ (x >> 31)
    }
    let seed = session_id ^ generation.rotate_left(17) ^ 0x9e37_79b9_7f4a_7c15;
    let jitter_ms = 1000 + mix(seed) % 1000;
    Duration::from_millis(jitter_ms)
}

View File

@ -828,10 +828,29 @@ impl MePool {
effective effective
} }
// Keeps per-contour (active/warm) writer budget bounded by CPU count.
// Baseline is 86 writers on the first core and +48 for each extra core.
fn adaptive_floor_cpu_budget_per_contour_cap(&self, cores: usize) -> usize {
const FIRST_CORE_WRITER_BUDGET: usize = 86;
const EXTRA_CORE_WRITER_BUDGET: usize = 48;
if cores == 0 {
return FIRST_CORE_WRITER_BUDGET;
}
FIRST_CORE_WRITER_BUDGET.saturating_add(
cores
.saturating_sub(1)
.saturating_mul(EXTRA_CORE_WRITER_BUDGET),
)
}
pub(super) fn adaptive_floor_active_cap_configured_total(&self) -> usize { pub(super) fn adaptive_floor_active_cap_configured_total(&self) -> usize {
let cores = self.adaptive_floor_effective_cpu_cores(); let cores = self.adaptive_floor_effective_cpu_cores();
let per_core_cap = cores.saturating_mul(self.adaptive_floor_max_active_writers_per_core()); let per_contour_budget = self.adaptive_floor_cpu_budget_per_contour_cap(cores);
let configured = per_core_cap.min(self.adaptive_floor_max_active_writers_global()); let configured = cores
.saturating_mul(self.adaptive_floor_max_active_writers_per_core())
.min(self.adaptive_floor_max_active_writers_global())
.min(per_contour_budget)
.max(1);
self.me_adaptive_floor_active_cap_configured self.me_adaptive_floor_active_cap_configured
.store(configured as u64, Ordering::Relaxed); .store(configured as u64, Ordering::Relaxed);
self.stats self.stats
@ -841,8 +860,12 @@ impl MePool {
pub(super) fn adaptive_floor_warm_cap_configured_total(&self) -> usize { pub(super) fn adaptive_floor_warm_cap_configured_total(&self) -> usize {
let cores = self.adaptive_floor_effective_cpu_cores(); let cores = self.adaptive_floor_effective_cpu_cores();
let per_core_cap = cores.saturating_mul(self.adaptive_floor_max_warm_writers_per_core()); let per_contour_budget = self.adaptive_floor_cpu_budget_per_contour_cap(cores);
let configured = per_core_cap.min(self.adaptive_floor_max_warm_writers_global()); let configured = cores
.saturating_mul(self.adaptive_floor_max_warm_writers_per_core())
.min(self.adaptive_floor_max_warm_writers_global())
.min(per_contour_budget)
.max(1);
self.me_adaptive_floor_warm_cap_configured self.me_adaptive_floor_warm_cap_configured
.store(configured as u64, Ordering::Relaxed); .store(configured as u64, Ordering::Relaxed);
self.stats self.stats