ME Pool Hardswap

This commit is contained in:
Alexey
2026-02-24 00:04:12 +03:00
parent 1f486e0df2
commit 0e2d42624f
13 changed files with 491 additions and 64 deletions

View File

@@ -131,6 +131,13 @@ pub async fn fetch_proxy_config(url: &str) -> Result<ProxyConfigData> {
}
async fn run_update_cycle(pool: &Arc<MePool>, rng: &Arc<SecureRandom>, cfg: &ProxyConfig) {
pool.update_runtime_reinit_policy(
cfg.general.hardswap,
cfg.general.me_pool_drain_ttl_secs,
cfg.general.effective_me_pool_force_close_secs(),
cfg.general.me_pool_min_fresh_ratio,
);
let mut maps_changed = false;
// Update proxy config v4
@@ -162,12 +169,7 @@ async fn run_update_cycle(pool: &Arc<MePool>, rng: &Arc<SecureRandom>, cfg: &Pro
}
if maps_changed {
let drain_timeout = if cfg.general.me_reinit_drain_timeout_secs == 0 {
None
} else {
Some(Duration::from_secs(cfg.general.me_reinit_drain_timeout_secs))
};
pool.zero_downtime_reinit_after_map_change(rng.as_ref(), drain_timeout)
pool.zero_downtime_reinit_after_map_change(rng.as_ref())
.await;
}
@@ -224,6 +226,12 @@ pub async fn me_config_updater(
break;
}
let cfg = config_rx.borrow().clone();
pool.update_runtime_reinit_policy(
cfg.general.hardswap,
cfg.general.me_pool_drain_ttl_secs,
cfg.general.effective_me_pool_force_close_secs(),
cfg.general.me_pool_min_fresh_ratio,
);
let new_secs = cfg.general.effective_update_every_secs().max(1);
if new_secs == update_every_secs {
continue;

View File

@@ -68,6 +68,7 @@ async fn check_family(
.read()
.await
.iter()
.filter(|w| !w.draining.load(std::sync::atomic::Ordering::Relaxed))
.map(|w| w.addr)
.collect();

View File

@@ -1,14 +1,14 @@
use std::collections::{HashMap, HashSet};
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, AtomicUsize, Ordering};
use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU32, AtomicU64, AtomicUsize, Ordering};
use bytes::BytesMut;
use rand::Rng;
use rand::seq::SliceRandom;
use tokio::sync::{Mutex, RwLock, mpsc, Notify};
use tokio_util::sync::CancellationToken;
use tracing::{debug, info, warn};
use std::time::{Duration, Instant};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use crate::crypto::SecureRandom;
use crate::error::{ProxyError, Result};
@@ -27,10 +27,13 @@ const ME_ACTIVE_PING_JITTER_SECS: i64 = 5;
/// Handle to one ME writer connection held by the pool.
///
/// Shared mutable state lives behind `Arc`-wrapped atomics so clones of the
/// handle (the pool stores and snapshots clones) observe the same flags.
pub struct MeWriter {
    /// Unique id assigned from the pool's writer-id counter.
    pub id: u64,
    /// Remote endpoint this writer is connected to.
    pub addr: SocketAddr,
    /// Pool generation this writer was created under; writers older than the
    /// pool's current generation are treated as stale during hardswap.
    pub generation: u64,
    /// Command channel into the writer task.
    pub tx: mpsc::Sender<WriterCommand>,
    /// Cancels the writer's background task.
    pub cancel: CancellationToken,
    /// Degraded-connection flag — set elsewhere (presumably by health
    /// checks; not visible in this chunk, confirm in the writer task).
    pub degraded: Arc<AtomicBool>,
    /// Set once the writer is draining and should stop taking new bindings.
    pub draining: Arc<AtomicBool>,
    /// Unix-epoch seconds at which draining started; 0 = never recorded.
    pub draining_started_at_epoch_secs: Arc<AtomicU64>,
    /// Whether a draining writer may still accept new bindings as a
    /// fallback while inside the drain TTL window.
    pub allow_drain_fallback: Arc<AtomicBool>,
}
pub struct MePool {
@@ -73,6 +76,11 @@ pub struct MePool {
pub(super) writer_available: Arc<Notify>,
pub(super) conn_count: AtomicUsize,
pub(super) stats: Arc<crate::stats::Stats>,
pub(super) generation: AtomicU64,
pub(super) hardswap: AtomicBool,
pub(super) me_pool_drain_ttl_secs: AtomicU64,
pub(super) me_pool_force_close_secs: AtomicU64,
pub(super) me_pool_min_fresh_ratio_permille: AtomicU32,
pool_size: usize,
}
@@ -83,6 +91,22 @@ pub struct NatReflectionCache {
}
impl MePool {
/// Convert a fractional ratio into permille (0..=1000), clamping
/// out-of-range input first. A NaN ratio survives the clamp but the
/// saturating `as` cast turns it into 0.
fn ratio_to_permille(ratio: f32) -> u32 {
    (ratio.clamp(0.0, 1.0) * 1000.0).round() as u32
}
/// Convert a stored permille value back into a ratio in `[0.0, 1.0]`.
/// Values above 1000 are capped at 1000 before dividing.
fn permille_to_ratio(permille: u32) -> f32 {
    let capped = permille.min(1000);
    capped as f32 / 1000.0
}
/// Current wall-clock time as whole seconds since the Unix epoch.
/// A clock set before the epoch yields 0 instead of panicking.
fn now_epoch_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
pub fn new(
proxy_tag: Option<Vec<u8>>,
proxy_secret: Vec<u8>,
@@ -110,6 +134,10 @@ impl MePool {
me_reconnect_backoff_base_ms: u64,
me_reconnect_backoff_cap_ms: u64,
me_reconnect_fast_retry_count: u32,
hardswap: bool,
me_pool_drain_ttl_secs: u64,
me_pool_force_close_secs: u64,
me_pool_min_fresh_ratio: f32,
) -> Arc<Self> {
Arc::new(Self {
registry: Arc::new(ConnRegistry::new()),
@@ -152,6 +180,11 @@ impl MePool {
nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())),
writer_available: Arc::new(Notify::new()),
conn_count: AtomicUsize::new(0),
generation: AtomicU64::new(1),
hardswap: AtomicBool::new(hardswap),
me_pool_drain_ttl_secs: AtomicU64::new(me_pool_drain_ttl_secs),
me_pool_force_close_secs: AtomicU64::new(me_pool_force_close_secs),
me_pool_min_fresh_ratio_permille: AtomicU32::new(Self::ratio_to_permille(me_pool_min_fresh_ratio)),
})
}
@@ -159,6 +192,25 @@ impl MePool {
self.proxy_tag.is_some()
}
/// Returns the pool's current reinit generation — a monotonically
/// increasing counter bumped when a map change triggers a reinit; writers
/// carrying an older generation are considered stale.
pub fn current_generation(&self) -> u64 {
    self.generation.load(Ordering::Relaxed)
}
/// Refresh the pool's runtime drain/reinit tuning from the latest config.
///
/// Each knob is an independent atomic updated with relaxed ordering, so
/// readers simply pick up the new values on their next load.
/// `min_fresh_ratio` is clamped and stored as permille (0..=1000).
pub fn update_runtime_reinit_policy(
    &self,
    hardswap: bool,
    drain_ttl_secs: u64,
    force_close_secs: u64,
    min_fresh_ratio: f32,
) {
    let min_fresh_permille = Self::ratio_to_permille(min_fresh_ratio);
    self.me_pool_min_fresh_ratio_permille
        .store(min_fresh_permille, Ordering::Relaxed);
    self.me_pool_force_close_secs
        .store(force_close_secs, Ordering::Relaxed);
    self.me_pool_drain_ttl_secs
        .store(drain_ttl_secs, Ordering::Relaxed);
    self.hardswap.store(hardswap, Ordering::Relaxed);
}
pub fn reset_stun_state(&self) {
self.nat_probe_attempts.store(0, Ordering::Relaxed);
self.nat_probe_disabled.store(false, Ordering::Relaxed);
@@ -177,6 +229,42 @@ impl MePool {
self.writers.clone()
}
/// Effective force-close timeout for draining writers, or `None` when the
/// feature is disabled (configured as 0 seconds).
fn force_close_timeout(&self) -> Option<Duration> {
    match self.me_pool_force_close_secs.load(Ordering::Relaxed) {
        0 => None,
        secs => Some(Duration::from_secs(secs)),
    }
}
/// Fraction of desired DCs that have at least one active writer, plus the
/// sorted list of DC ids with no active coverage.
///
/// DCs whose desired endpoint set is empty require no writer, so they are
/// excluded from both numerator and denominator. (The previous version
/// skipped them in the numerator but still counted them in the
/// denominator, permanently deflating the ratio; an all-empty map produced
/// `(0.0, [])`, which made the caller keep stale writers forever while
/// logging an empty `missing_dc`.) When nothing requires coverage —
/// empty map or only empty endpoint sets — coverage is complete: `(1.0, [])`.
fn coverage_ratio(
    desired_by_dc: &HashMap<i32, HashSet<SocketAddr>>,
    active_writer_addrs: &HashSet<SocketAddr>,
) -> (f32, Vec<i32>) {
    let mut missing_dc = Vec::<i32>::new();
    let mut covered = 0usize;
    let mut total = 0usize;
    for (dc, endpoints) in desired_by_dc {
        // An empty endpoint set can never be covered and must not drag the
        // ratio down; it needs no writers at all.
        if endpoints.is_empty() {
            continue;
        }
        total += 1;
        if endpoints.iter().any(|addr| active_writer_addrs.contains(addr)) {
            covered += 1;
        } else {
            missing_dc.push(*dc);
        }
    }
    if total == 0 {
        return (1.0, missing_dc);
    }
    missing_dc.sort_unstable();
    let ratio = (covered as f32) / (total as f32);
    (ratio, missing_dc)
}
pub async fn reconcile_connections(self: &Arc<Self>, rng: &SecureRandom) {
let writers = self.writers.read().await;
let current: HashSet<SocketAddr> = writers
@@ -235,39 +323,104 @@ impl MePool {
out
}
/// Ensure every DC in `desired_by_dc` has at least one live, non-draining
/// writer of the given (fresh) `generation`, dialing new connections where
/// needed.
///
/// For each DC that lacks such a writer, its endpoints are tried in random
/// order until one `connect_one` succeeds; if all dials fail the DC is
/// simply left uncovered for this pass (no error is propagated).
async fn warmup_generation_for_all_dcs(
    self: &Arc<Self>,
    rng: &SecureRandom,
    generation: u64,
    desired_by_dc: &HashMap<i32, HashSet<SocketAddr>>,
) {
    for endpoints in desired_by_dc.values() {
        // A DC with no desired endpoints needs no writers.
        if endpoints.is_empty() {
            continue;
        }
        // Block-scope the read guard so the `writers` lock is released
        // before awaiting `connect_one` below.
        let has_fresh = {
            let ws = self.writers.read().await;
            ws.iter().any(|w| {
                !w.draining.load(Ordering::Relaxed)
                    && w.generation == generation
                    && endpoints.contains(&w.addr)
            })
        };
        if has_fresh {
            continue;
        }
        // Randomize dial order so repeated warmups spread load across the
        // DC's endpoints instead of always hitting the first address.
        let mut shuffled: Vec<SocketAddr> = endpoints.iter().copied().collect();
        shuffled.shuffle(&mut rand::rng());
        for addr in shuffled {
            // One successful connection per DC is enough.
            if self.connect_one(addr, rng).await.is_ok() {
                break;
            }
        }
    }
}
pub async fn zero_downtime_reinit_after_map_change(
self: &Arc<Self>,
rng: &SecureRandom,
drain_timeout: Option<Duration>,
) {
// Stage 1: prewarm writers for new endpoint maps before draining old ones.
self.reconcile_connections(rng).await;
let desired_by_dc = self.desired_dc_endpoints().await;
if desired_by_dc.is_empty() {
warn!("ME endpoint map is empty after update; skipping stale writer drain");
return;
}
let previous_generation = self.current_generation();
let generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1;
let hardswap = self.hardswap.load(Ordering::Relaxed);
if hardswap {
self.warmup_generation_for_all_dcs(rng, generation, &desired_by_dc)
.await;
} else {
self.reconcile_connections(rng).await;
}
let writers = self.writers.read().await;
let active_writer_addrs: HashSet<SocketAddr> = writers
.iter()
.filter(|w| !w.draining.load(Ordering::Relaxed))
.map(|w| w.addr)
.collect();
let mut missing_dc = Vec::<i32>::new();
for (dc, endpoints) in &desired_by_dc {
if endpoints.is_empty() {
continue;
}
if !endpoints.iter().any(|addr| active_writer_addrs.contains(addr)) {
missing_dc.push(*dc);
}
let min_ratio = Self::permille_to_ratio(
self.me_pool_min_fresh_ratio_permille
.load(Ordering::Relaxed),
);
let (coverage_ratio, missing_dc) = Self::coverage_ratio(&desired_by_dc, &active_writer_addrs);
if !hardswap && coverage_ratio < min_ratio {
warn!(
previous_generation,
generation,
coverage_ratio = format_args!("{coverage_ratio:.3}"),
min_ratio = format_args!("{min_ratio:.3}"),
missing_dc = ?missing_dc,
"ME reinit coverage below threshold; keeping stale writers"
);
return;
}
if !missing_dc.is_empty() {
missing_dc.sort_unstable();
if hardswap {
let fresh_writer_addrs: HashSet<SocketAddr> = writers
.iter()
.filter(|w| !w.draining.load(Ordering::Relaxed))
.filter(|w| w.generation == generation)
.map(|w| w.addr)
.collect();
let (fresh_ratio, fresh_missing_dc) =
Self::coverage_ratio(&desired_by_dc, &fresh_writer_addrs);
if !fresh_missing_dc.is_empty() {
warn!(
previous_generation,
generation,
fresh_ratio = format_args!("{fresh_ratio:.3}"),
missing_dc = ?fresh_missing_dc,
"ME hardswap pending: fresh generation coverage incomplete"
);
return;
}
} else if !missing_dc.is_empty() {
warn!(
missing_dc = ?missing_dc,
// Keep stale writers alive when fresh coverage is incomplete.
@@ -284,7 +437,13 @@ impl MePool {
let stale_writer_ids: Vec<u64> = writers
.iter()
.filter(|w| !w.draining.load(Ordering::Relaxed))
.filter(|w| !desired_addrs.contains(&w.addr))
.filter(|w| {
if hardswap {
w.generation < generation
} else {
!desired_addrs.contains(&w.addr)
}
})
.map(|w| w.id)
.collect();
drop(writers);
@@ -294,14 +453,21 @@ impl MePool {
return;
}
let drain_timeout = self.force_close_timeout();
let drain_timeout_secs = drain_timeout.map(|d| d.as_secs()).unwrap_or(0);
info!(
stale_writers = stale_writer_ids.len(),
previous_generation,
generation,
hardswap,
coverage_ratio = format_args!("{coverage_ratio:.3}"),
min_ratio = format_args!("{min_ratio:.3}"),
drain_timeout_secs,
"ME map update covered; draining stale writers"
);
self.stats.increment_pool_swap_total();
for writer_id in stale_writer_ids {
self.mark_writer_draining_with_timeout(writer_id, drain_timeout)
self.mark_writer_draining_with_timeout(writer_id, drain_timeout, !hardswap)
.await;
}
}
@@ -507,9 +673,12 @@ impl MePool {
let hs = self.handshake_only(stream, addr, rng).await?;
let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed);
let generation = self.current_generation();
let cancel = CancellationToken::new();
let degraded = Arc::new(AtomicBool::new(false));
let draining = Arc::new(AtomicBool::new(false));
let draining_started_at_epoch_secs = Arc::new(AtomicU64::new(0));
let allow_drain_fallback = Arc::new(AtomicBool::new(false));
let (tx, mut rx) = mpsc::channel::<WriterCommand>(4096);
let mut rpc_writer = RpcWriter {
writer: hs.wr,
@@ -540,10 +709,13 @@ impl MePool {
let writer = MeWriter {
id: writer_id,
addr,
generation,
tx: tx.clone(),
cancel: cancel.clone(),
degraded: degraded.clone(),
draining: draining.clone(),
draining_started_at_epoch_secs: draining_started_at_epoch_secs.clone(),
allow_drain_fallback: allow_drain_fallback.clone(),
};
self.writers.write().await.push(writer.clone());
self.conn_count.fetch_add(1, Ordering::Relaxed);
@@ -715,6 +887,9 @@ impl MePool {
let mut ws = self.writers.write().await;
if let Some(pos) = ws.iter().position(|w| w.id == writer_id) {
let w = ws.remove(pos);
if w.draining.load(Ordering::Relaxed) {
self.stats.decrement_pool_drain_active();
}
w.cancel.cancel();
close_tx = Some(w.tx.clone());
self.conn_count.fetch_sub(1, Ordering::Relaxed);
@@ -731,11 +906,20 @@ impl MePool {
self: &Arc<Self>,
writer_id: u64,
timeout: Option<Duration>,
allow_drain_fallback: bool,
) {
let timeout = timeout.filter(|d| !d.is_zero());
let found = {
let mut ws = self.writers.write().await;
if let Some(w) = ws.iter_mut().find(|w| w.id == writer_id) {
let already_draining = w.draining.swap(true, Ordering::Relaxed);
w.allow_drain_fallback
.store(allow_drain_fallback, Ordering::Relaxed);
w.draining_started_at_epoch_secs
.store(Self::now_epoch_secs(), Ordering::Relaxed);
if !already_draining {
self.stats.increment_pool_drain_active();
}
w.draining.store(true, Ordering::Relaxed);
true
} else {
@@ -748,7 +932,12 @@ impl MePool {
}
let timeout_secs = timeout.map(|d| d.as_secs()).unwrap_or(0);
debug!(writer_id, timeout_secs, "ME writer marked draining");
debug!(
writer_id,
timeout_secs,
allow_drain_fallback,
"ME writer marked draining"
);
let pool = Arc::downgrade(self);
tokio::spawn(async move {
@@ -758,6 +947,7 @@ impl MePool {
if let Some(deadline_at) = deadline {
if Instant::now() >= deadline_at {
warn!(writer_id, "Drain timeout, force-closing");
p.stats.increment_pool_force_close_total();
let _ = p.remove_writer_and_close_clients(writer_id).await;
break;
}
@@ -775,10 +965,31 @@ impl MePool {
}
pub(crate) async fn mark_writer_draining(self: &Arc<Self>, writer_id: u64) {
self.mark_writer_draining_with_timeout(writer_id, Some(Duration::from_secs(300)))
self.mark_writer_draining_with_timeout(writer_id, Some(Duration::from_secs(300)), false)
.await;
}
/// Decide whether `writer` may receive a brand-new client binding.
///
/// Non-draining writers always accept. A draining writer accepts only if
/// fallback was allowed for it, and then only while inside the drain TTL
/// window (a TTL of 0 means the window never closes; a missing drain start
/// timestamp means refuse).
///
/// NOTE(review): the window is measured in wall-clock epoch seconds, so a
/// system clock jump can lengthen or shorten it — confirm this is acceptable.
pub(super) fn writer_accepts_new_binding(&self, writer: &MeWriter) -> bool {
    let is_draining = writer.draining.load(Ordering::Relaxed);
    if !is_draining {
        return true;
    }
    let fallback_allowed = writer.allow_drain_fallback.load(Ordering::Relaxed);
    if !fallback_allowed {
        return false;
    }
    let ttl_secs = self.me_pool_drain_ttl_secs.load(Ordering::Relaxed);
    let started = writer.draining_started_at_epoch_secs.load(Ordering::Relaxed);
    match (ttl_secs, started) {
        // TTL of zero: the fallback window never expires.
        (0, _) => true,
        // Drain start never recorded: refuse to be safe.
        (_, 0) => false,
        (ttl, at) => Self::now_epoch_secs().saturating_sub(at) <= ttl,
    }
}
}
fn hex_dump(data: &[u8]) -> String {

View File

@@ -134,8 +134,8 @@ impl MePool {
candidate_indices.sort_by_key(|idx| {
let w = &writers_snapshot[*idx];
let degraded = w.degraded.load(Ordering::Relaxed);
let draining = w.draining.load(Ordering::Relaxed);
(draining as usize, degraded as usize)
let stale = (w.generation < self.current_generation()) as usize;
(stale, degraded as usize)
});
let start = self.rr.fetch_add(1, Ordering::Relaxed) as usize % candidate_indices.len();
@@ -143,13 +143,23 @@ impl MePool {
for offset in 0..candidate_indices.len() {
let idx = candidate_indices[(start + offset) % candidate_indices.len()];
let w = &writers_snapshot[idx];
if w.draining.load(Ordering::Relaxed) {
if !self.writer_accepts_new_binding(w) {
continue;
}
if w.tx.send(WriterCommand::Data(payload.clone())).await.is_ok() {
self.registry
.bind_writer(conn_id, w.id, w.tx.clone(), meta.clone())
.await;
if w.generation < self.current_generation() {
self.stats.increment_pool_stale_pick_total();
debug!(
conn_id,
writer_id = w.id,
writer_generation = w.generation,
current_generation = self.current_generation(),
"Selected stale ME writer for fallback bind"
);
}
return Ok(());
} else {
warn!(writer_id = w.id, "ME writer channel closed");
@@ -159,7 +169,7 @@ impl MePool {
}
let w = writers_snapshot[candidate_indices[start]].clone();
if w.draining.load(Ordering::Relaxed) {
if !self.writer_accepts_new_binding(&w) {
continue;
}
match w.tx.send(WriterCommand::Data(payload.clone())).await {
@@ -167,6 +177,9 @@ impl MePool {
self.registry
.bind_writer(conn_id, w.id, w.tx.clone(), meta.clone())
.await;
if w.generation < self.current_generation() {
self.stats.increment_pool_stale_pick_total();
}
return Ok(());
}
Err(_) => {
@@ -245,13 +258,13 @@ impl MePool {
if preferred.is_empty() {
return (0..writers.len())
.filter(|i| !writers[*i].draining.load(Ordering::Relaxed))
.filter(|i| self.writer_accepts_new_binding(&writers[*i]))
.collect();
}
let mut out = Vec::new();
for (idx, w) in writers.iter().enumerate() {
if w.draining.load(Ordering::Relaxed) {
if !self.writer_accepts_new_binding(w) {
continue;
}
if preferred.iter().any(|p| *p == w.addr) {
@@ -260,7 +273,7 @@ impl MePool {
}
if out.is_empty() {
return (0..writers.len())
.filter(|i| !writers[*i].draining.load(Ordering::Relaxed))
.filter(|i| self.writer_accepts_new_binding(&writers[*i]))
.collect();
}
out