From e589891706005daf8bbb8d2bb067432ca944ef64 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Mon, 2 Mar 2026 20:41:51 +0300 Subject: [PATCH] ME Dual-Trio Pool Drafts --- src/metrics.rs | 94 ++++++++++++++++++++ src/stats/mod.rs | 79 +++++++++++++++++ src/transport/middle_proxy/handshake.rs | 49 +++++++++++ src/transport/middle_proxy/pool.rs | 48 +++++++++- src/transport/middle_proxy/pool_refill.rs | 98 ++++++++++++++++++++- src/transport/middle_proxy/pool_reinit.rs | 102 ++++++++++++++++++++-- src/transport/middle_proxy/pool_writer.rs | 25 +++++- src/transport/middle_proxy/reader.rs | 1 + src/transport/middle_proxy/send.rs | 58 ++++++++++-- 9 files changed, 534 insertions(+), 20 deletions(-) diff --git a/src/metrics.rs b/src/metrics.rs index 35f29ca..fcbd03c 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -274,6 +274,43 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); + let _ = writeln!(out, "# HELP telemt_me_handshake_reject_total ME handshake rejects from upstream"); + let _ = writeln!(out, "# TYPE telemt_me_handshake_reject_total counter"); + let _ = writeln!( + out, + "telemt_me_handshake_reject_total {}", + if me_allows_normal { + stats.get_me_handshake_reject_total() + } else { + 0 + } + ); + + let _ = writeln!(out, "# HELP telemt_me_handshake_error_code_total ME handshake reject errors by code"); + let _ = writeln!(out, "# TYPE telemt_me_handshake_error_code_total counter"); + if me_allows_normal { + for (error_code, count) in stats.get_me_handshake_error_code_counts() { + let _ = writeln!( + out, + "telemt_me_handshake_error_code_total{{error_code=\"{}\"}} {}", + error_code, + count + ); + } + } + + let _ = writeln!(out, "# HELP telemt_me_reader_eof_total ME reader EOF terminations"); + let _ = writeln!(out, "# TYPE telemt_me_reader_eof_total counter"); + let _ = writeln!( + out, + "telemt_me_reader_eof_total {}", + if me_allows_normal { + stats.get_me_reader_eof_total() + } else { + 0 + } + ); + let _ = writeln!(out, "# HELP telemt_me_crc_mismatch_total ME CRC mismatches"); let _ = writeln!(out, "# TYPE telemt_me_crc_mismatch_total counter"); let _ = writeln!( @@ -385,6 +422,63 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp } ); + let _ = writeln!( + out, + "# HELP telemt_me_endpoint_quarantine_total ME endpoint quarantines due to rapid flaps" + ); + let _ = writeln!(out, "# TYPE telemt_me_endpoint_quarantine_total counter"); + let _ = writeln!( + out, + "telemt_me_endpoint_quarantine_total {}", + if me_allows_normal { + stats.get_me_endpoint_quarantine_total() + } else { + 0 + } + ); + + let _ = writeln!(out, "# HELP telemt_me_kdf_drift_total ME KDF input drift detections"); + let _ = writeln!(out, "# TYPE telemt_me_kdf_drift_total counter"); + let _ = writeln!( + out, + "telemt_me_kdf_drift_total {}", + if me_allows_normal { + stats.get_me_kdf_drift_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_hardswap_pending_reuse_total Hardswap cycles that reused an existing pending generation" + ); + let _ = writeln!(out, "# TYPE telemt_me_hardswap_pending_reuse_total counter"); + let _ = writeln!( + out, + "telemt_me_hardswap_pending_reuse_total {}", + if me_allows_debug { + stats.get_me_hardswap_pending_reuse_total() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_hardswap_pending_ttl_expired_total Pending hardswap generations reset by TTL expiration" + ); + let _ = writeln!(out, "# TYPE telemt_me_hardswap_pending_ttl_expired_total counter"); + let _ = writeln!( + out, + "telemt_me_hardswap_pending_ttl_expired_total {}", + if me_allows_normal { + stats.get_me_hardswap_pending_ttl_expired_total() + } else { + 0 + } + ); + let _ = writeln!(out, "# HELP telemt_secure_padding_invalid_total Invalid secure frame lengths"); let _ = writeln!(out, "# TYPE telemt_secure_padding_invalid_total counter"); let _ = writeln!( diff --git a/src/stats/mod.rs b/src/stats/mod.rs index f5aa2b7..453a73a 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -32,8 +32,15 @@ pub struct Stats { me_keepalive_timeout: AtomicU64, me_reconnect_attempts: AtomicU64, me_reconnect_success: AtomicU64, + me_handshake_reject_total: AtomicU64, + me_reader_eof_total: AtomicU64, me_crc_mismatch: AtomicU64, me_seq_mismatch: AtomicU64, + me_endpoint_quarantine_total: AtomicU64, + me_kdf_drift_total: AtomicU64, + me_hardswap_pending_reuse_total: AtomicU64, + me_hardswap_pending_ttl_expired_total: AtomicU64, + me_handshake_error_codes: DashMap, me_route_drop_no_conn: AtomicU64, me_route_drop_channel_closed: AtomicU64, me_route_drop_queue_full: AtomicU64, @@ -172,6 +179,26 @@ impl Stats { self.me_reconnect_success.fetch_add(1, Ordering::Relaxed); } } + pub fn increment_me_handshake_reject_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_handshake_reject_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_handshake_error_code(&self, code: i32) { + if !self.telemetry_me_allows_normal() { + return; + } + let entry = self + .me_handshake_error_codes + .entry(code) + .or_insert_with(|| AtomicU64::new(0)); + entry.fetch_add(1, Ordering::Relaxed); + } + pub fn increment_me_reader_eof_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_reader_eof_total.fetch_add(1, Ordering::Relaxed); + } + } pub fn increment_me_crc_mismatch(&self) { if self.telemetry_me_allows_normal() { self.me_crc_mismatch.fetch_add(1, Ordering::Relaxed); @@ -333,6 +360,29 @@ impl Stats { .fetch_add(1, Ordering::Relaxed); } } + pub fn increment_me_endpoint_quarantine_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_endpoint_quarantine_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_kdf_drift_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_kdf_drift_total.fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_hardswap_pending_reuse_total(&self) { + if self.telemetry_me_allows_debug() { + self.me_hardswap_pending_reuse_total + .fetch_add(1, Ordering::Relaxed); + } + } + pub fn increment_me_hardswap_pending_ttl_expired_total(&self) { + if self.telemetry_me_allows_normal() { + self.me_hardswap_pending_ttl_expired_total + .fetch_add(1, Ordering::Relaxed); + } + } pub fn get_connects_all(&self) -> u64 { self.connects_all.load(Ordering::Relaxed) } pub fn get_connects_bad(&self) -> u64 { self.connects_bad.load(Ordering::Relaxed) } pub fn get_me_keepalive_sent(&self) -> u64 { self.me_keepalive_sent.load(Ordering::Relaxed) } @@ -341,8 +391,37 @@ impl Stats { pub fn get_me_keepalive_timeout(&self) -> u64 { self.me_keepalive_timeout.load(Ordering::Relaxed) } pub fn get_me_reconnect_attempts(&self) -> u64 { self.me_reconnect_attempts.load(Ordering::Relaxed) } pub fn get_me_reconnect_success(&self) -> u64 { self.me_reconnect_success.load(Ordering::Relaxed) } + pub fn get_me_handshake_reject_total(&self) -> u64 { + self.me_handshake_reject_total.load(Ordering::Relaxed) + } + pub fn get_me_reader_eof_total(&self) -> u64 { + self.me_reader_eof_total.load(Ordering::Relaxed) + } pub fn get_me_crc_mismatch(&self) -> u64 { self.me_crc_mismatch.load(Ordering::Relaxed) } pub fn get_me_seq_mismatch(&self) -> u64 { self.me_seq_mismatch.load(Ordering::Relaxed) } + pub fn get_me_endpoint_quarantine_total(&self) -> u64 { + self.me_endpoint_quarantine_total.load(Ordering::Relaxed) + } + pub fn get_me_kdf_drift_total(&self) -> u64 { + self.me_kdf_drift_total.load(Ordering::Relaxed) + } + pub fn get_me_hardswap_pending_reuse_total(&self) -> u64 { + self.me_hardswap_pending_reuse_total + .load(Ordering::Relaxed) + } + pub fn get_me_hardswap_pending_ttl_expired_total(&self) -> u64 { + self.me_hardswap_pending_ttl_expired_total + .load(Ordering::Relaxed) + } + pub fn get_me_handshake_error_code_counts(&self) -> Vec<(i32, u64)> { + let mut out: Vec<(i32, u64)> = self + .me_handshake_error_codes + .iter() + .map(|entry| (*entry.key(), entry.value().load(Ordering::Relaxed))) + .collect(); + out.sort_by_key(|(code, _)| *code); + out + } pub fn get_me_route_drop_no_conn(&self) -> u64 { self.me_route_drop_no_conn.load(Ordering::Relaxed) } pub fn get_me_route_drop_channel_closed(&self) -> u64 { self.me_route_drop_channel_closed.load(Ordering::Relaxed) diff --git a/src/transport/middle_proxy/handshake.rs b/src/transport/middle_proxy/handshake.rs index 5daa460..251c911 100644 --- a/src/transport/middle_proxy/handshake.rs +++ b/src/transport/middle_proxy/handshake.rs @@ -1,6 +1,8 @@ use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::Ordering; use std::time::{Duration, Instant}; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; use socket2::{SockRef, TcpKeepalive}; #[cfg(target_os = "linux")] use libc; @@ -34,6 +36,8 @@ use super::codec::{ use super::wire::{extract_ip_material, IpMaterial}; use super::MePool; +const ME_KDF_DRIFT_STRICT: bool = false; + /// Result of a successful ME handshake with timings. pub(crate) struct HandshakeOutput { pub rd: ReadHalf, @@ -47,6 +51,22 @@ pub(crate) struct HandshakeOutput { } impl MePool { + fn kdf_material_fingerprint( + local_addr_nat: SocketAddr, + peer_addr_nat: SocketAddr, + client_port_for_kdf: u16, + reflected: Option, + socks_bound_addr: Option, + ) -> u64 { + let mut hasher = DefaultHasher::new(); + local_addr_nat.hash(&mut hasher); + peer_addr_nat.hash(&mut hasher); + client_port_for_kdf.hash(&mut hasher); + reflected.hash(&mut hasher); + socks_bound_addr.hash(&mut hasher); + hasher.finish() + } + async fn resolve_dc_idx_for_endpoint(&self, addr: SocketAddr) -> Option { if addr.is_ipv4() { let map = self.proxy_map_v4.read().await; @@ -343,6 +363,33 @@ impl MePool { .map(|bound| bound.port()) .filter(|port| *port != 0) .unwrap_or(local_addr_nat.port()); + let kdf_fingerprint = Self::kdf_material_fingerprint( + local_addr_nat, + peer_addr_nat, + client_port_for_kdf, + reflected, + socks_bound_addr, + ); + let mut kdf_fingerprint_guard = self.kdf_material_fingerprint.lock().await; + if let Some(prev_fingerprint) = kdf_fingerprint_guard.get(&peer_addr_nat).copied() + && prev_fingerprint != kdf_fingerprint + { + self.stats.increment_me_kdf_drift_total(); + warn!( + %peer_addr_nat, + %local_addr_nat, + client_port_for_kdf, + "ME KDF input drift detected for endpoint" + ); + if ME_KDF_DRIFT_STRICT { + return Err(ProxyError::InvalidHandshake( + "ME KDF input drift detected (strict mode)".to_string(), + )); + } + } + kdf_fingerprint_guard.insert(peer_addr_nat, kdf_fingerprint); + drop(kdf_fingerprint_guard); + let client_port_bytes = client_port_for_kdf.to_le_bytes(); let server_ip = extract_ip_material(peer_addr_nat); @@ -540,6 +587,8 @@ impl MePool { } else { -1 }; + self.stats.increment_me_handshake_reject_total(); + self.stats.increment_me_handshake_error_code(err_code); return Err(ProxyError::InvalidHandshake(format!( "ME rejected handshake (error={err_code})" ))); diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 92e83bc..4a5598a 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -16,11 +16,18 @@ use crate::transport::UpstreamManager; use super::ConnRegistry; use super::codec::WriterCommand; +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub(super) struct RefillDcKey { + pub dc: i32, + pub family: IpFamily, +} + #[derive(Clone)] pub struct MeWriter { pub id: u64, pub addr: SocketAddr, pub generation: u64, + pub contour: Arc, pub created_at: Instant, pub tx: mpsc::Sender, pub cancel: CancellationToken, @@ -30,6 +37,29 @@ pub struct MeWriter { pub allow_drain_fallback: Arc, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub(super) enum WriterContour { + Warm = 0, + Active = 1, + Draining = 2, +} + +impl WriterContour { + pub(super) fn as_u8(self) -> u8 { + self as u8 + } + + pub(super) fn from_u8(value: u8) -> Self { + match value { + 0 => Self::Warm, + 1 => Self::Active, + 2 => Self::Draining, + _ => Self::Draining, + } + } +} + #[derive(Debug, Clone)] pub struct SecretSnapshot { pub epoch: u64, @@ -80,12 +110,18 @@ pub struct MePool { pub(super) nat_reflection_cache: Arc>, pub(super) writer_available: Arc, pub(super) refill_inflight: Arc>>, + pub(super) refill_inflight_dc: Arc>>, pub(super) conn_count: AtomicUsize, pub(super) stats: Arc, pub(super) generation: AtomicU64, + pub(super) active_generation: AtomicU64, + pub(super) warm_generation: AtomicU64, pub(super) pending_hardswap_generation: AtomicU64, + pub(super) pending_hardswap_started_at_epoch_secs: AtomicU64, + pub(super) pending_hardswap_map_hash: AtomicU64, pub(super) hardswap: AtomicBool, pub(super) endpoint_quarantine: Arc>>, + pub(super) kdf_material_fingerprint: Arc>>, pub(super) me_pool_drain_ttl_secs: AtomicU64, pub(super) me_pool_force_close_secs: AtomicU64, pub(super) me_pool_min_fresh_ratio_permille: AtomicU32, @@ -233,11 +269,17 @@ impl MePool { nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())), writer_available: Arc::new(Notify::new()), refill_inflight: Arc::new(Mutex::new(HashSet::new())), + refill_inflight_dc: Arc::new(Mutex::new(HashSet::new())), conn_count: AtomicUsize::new(0), generation: AtomicU64::new(1), + active_generation: AtomicU64::new(1), + warm_generation: AtomicU64::new(0), pending_hardswap_generation: AtomicU64::new(0), + pending_hardswap_started_at_epoch_secs: AtomicU64::new(0), + pending_hardswap_map_hash: AtomicU64::new(0), hardswap: AtomicBool::new(hardswap), endpoint_quarantine: Arc::new(Mutex::new(HashMap::new())), + kdf_material_fingerprint: Arc::new(Mutex::new(HashMap::new())), me_pool_drain_ttl_secs: AtomicU64::new(me_pool_drain_ttl_secs), me_pool_force_close_secs: AtomicU64::new(me_pool_force_close_secs), me_pool_min_fresh_ratio_permille: AtomicU32::new(Self::ratio_to_permille( @@ -258,7 +300,11 @@ impl MePool { } pub fn current_generation(&self) -> u64 { - self.generation.load(Ordering::Relaxed) + self.active_generation.load(Ordering::Relaxed) + } + + pub(super) fn warm_generation(&self) -> u64 { + self.warm_generation.load(Ordering::Relaxed) } pub fn update_runtime_reinit_policy( diff --git a/src/transport/middle_proxy/pool_refill.rs b/src/transport/middle_proxy/pool_refill.rs index a286e65..92071bd 100644 --- a/src/transport/middle_proxy/pool_refill.rs +++ b/src/transport/middle_proxy/pool_refill.rs @@ -7,8 +7,9 @@ use std::time::{Duration, Instant}; use tracing::{debug, info, warn}; use crate::crypto::SecureRandom; +use crate::network::IpFamily; -use super::pool::MePool; +use super::pool::{MePool, RefillDcKey, WriterContour}; const ME_FLAP_UPTIME_THRESHOLD_SECS: u64 = 20; const ME_FLAP_QUARANTINE_SECS: u64 = 25; @@ -27,6 +28,7 @@ impl MePool { let mut guard = self.endpoint_quarantine.lock().await; guard.retain(|_, expiry| *expiry > Instant::now()); guard.insert(addr, until); + self.stats.increment_me_endpoint_quarantine_total(); warn!( %addr, uptime_ms = uptime.as_millis(), @@ -84,14 +86,76 @@ impl MePool { if endpoints.is_empty() { return false; } - let guard = self.refill_inflight.lock().await; - endpoints.iter().any(|addr| guard.contains(addr)) + + { + let guard = self.refill_inflight.lock().await; + if endpoints.iter().any(|addr| guard.contains(addr)) { + return true; + } + } + + let dc_keys = self.resolve_refill_dc_keys_for_endpoints(endpoints).await; + if dc_keys.is_empty() { + return false; + } + let guard = self.refill_inflight_dc.lock().await; + dc_keys.iter().any(|key| guard.contains(key)) + } + + async fn resolve_refill_dc_key_for_addr(&self, addr: SocketAddr) -> Option { + let family = if addr.is_ipv4() { + IpFamily::V4 + } else { + IpFamily::V6 + }; + let map = self.proxy_map_for_family(family).await; + for (dc, endpoints) in map { + if endpoints + .into_iter() + .any(|(ip, port)| SocketAddr::new(ip, port) == addr) + { + return Some(RefillDcKey { + dc: dc.abs(), + family, + }); + } + } + None + } + + async fn resolve_refill_dc_keys_for_endpoints( + &self, + endpoints: &[SocketAddr], + ) -> HashSet { + let mut out = HashSet::::new(); + for addr in endpoints { + if let Some(key) = self.resolve_refill_dc_key_for_addr(*addr).await { + out.insert(key); + } + } + out } pub(super) async fn connect_endpoints_round_robin( self: &Arc, endpoints: &[SocketAddr], rng: &SecureRandom, + ) -> bool { + self.connect_endpoints_round_robin_with_generation_contour( + endpoints, + rng, + self.current_generation(), + WriterContour::Active, + ) + .await + } + + pub(super) async fn connect_endpoints_round_robin_with_generation_contour( + self: &Arc, + endpoints: &[SocketAddr], + rng: &SecureRandom, + generation: u64, + contour: WriterContour, ) -> bool { let candidates = self.connectable_endpoints(endpoints).await; if candidates.is_empty() { @@ -101,7 +165,10 @@ impl MePool { for offset in 0..candidates.len() { let idx = (start + offset) % candidates.len(); let addr = candidates[idx]; - match self.connect_one(addr, rng).await { + match self + .connect_one_with_generation_contour(addr, rng, generation, contour) + .await + { Ok(()) => return true, Err(e) => debug!(%addr, error = %e, "ME connect failed during round-robin warmup"), } @@ -225,6 +292,9 @@ impl MePool { pub(crate) fn trigger_immediate_refill(self: &Arc, addr: SocketAddr) { let pool = Arc::clone(self); tokio::spawn(async move { + let dc_endpoints = pool.endpoints_for_same_dc(addr).await; + let dc_keys = pool.resolve_refill_dc_keys_for_endpoints(&dc_endpoints).await; + { let mut guard = pool.refill_inflight.lock().await; if !guard.insert(addr) { @@ -232,6 +302,19 @@ impl MePool { return; } } + + if !dc_keys.is_empty() { + let mut dc_guard = pool.refill_inflight_dc.lock().await; + if dc_keys.iter().any(|key| dc_guard.contains(key)) { + pool.stats.increment_me_refill_skipped_inflight_total(); + drop(dc_guard); + let mut guard = pool.refill_inflight.lock().await; + guard.remove(&addr); + return; + } + dc_guard.extend(dc_keys.iter().copied()); + } + pool.stats.increment_me_refill_triggered_total(); let restored = pool.refill_writer_after_loss(addr).await; @@ -241,6 +324,13 @@ impl MePool { let mut guard = pool.refill_inflight.lock().await; guard.remove(&addr); + drop(guard); + if !dc_keys.is_empty() { + let mut dc_guard = pool.refill_inflight_dc.lock().await; + for key in &dc_keys { + dc_guard.remove(key); + } + } }); } } diff --git a/src/transport/middle_proxy/pool_reinit.rs b/src/transport/middle_proxy/pool_reinit.rs index 5552fb6..33b8cc4 100644 --- a/src/transport/middle_proxy/pool_reinit.rs +++ b/src/transport/middle_proxy/pool_reinit.rs @@ -1,4 +1,5 @@ use std::collections::{HashMap, HashSet}; +use std::hash::{Hash, Hasher}; use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::Ordering; @@ -7,12 +8,58 @@ use std::time::Duration; use rand::Rng; use rand::seq::SliceRandom; use tracing::{debug, info, warn}; +use std::collections::hash_map::DefaultHasher; use crate::crypto::SecureRandom; -use super::pool::MePool; +use super::pool::{MePool, WriterContour}; + +const ME_HARDSWAP_PENDING_TTL_SECS: u64 = 1800; impl MePool { + fn desired_map_hash(desired_by_dc: &HashMap>) -> u64 { + let mut hasher = DefaultHasher::new(); + let mut dcs: Vec = desired_by_dc.keys().copied().collect(); + dcs.sort_unstable(); + for dc in dcs { + dc.hash(&mut hasher); + let mut endpoints: Vec = desired_by_dc + .get(&dc) + .map(|set| set.iter().copied().collect()) + .unwrap_or_default(); + endpoints.sort_unstable(); + for endpoint in endpoints { + endpoint.hash(&mut hasher); + } + } + hasher.finish() + } + + fn clear_pending_hardswap_state(&self) { + self.pending_hardswap_generation.store(0, Ordering::Relaxed); + self.pending_hardswap_started_at_epoch_secs + .store(0, Ordering::Relaxed); + self.pending_hardswap_map_hash.store(0, Ordering::Relaxed); + self.warm_generation.store(0, Ordering::Relaxed); + } + + async fn promote_warm_generation_to_active(&self, generation: u64) { + self.active_generation.store(generation, Ordering::Relaxed); + self.warm_generation.store(0, Ordering::Relaxed); + + let ws = self.writers.read().await; + for writer in ws.iter() { + if writer.draining.load(Ordering::Relaxed) { + continue; + } + if writer.generation == generation { + writer + .contour + .store(WriterContour::Active.as_u8(), Ordering::Relaxed); + } + } + } + fn coverage_ratio( desired_by_dc: &HashMap>, active_writer_addrs: &HashSet, @@ -202,7 +249,14 @@ impl MePool { let delay_ms = self.hardswap_warmup_connect_delay_ms(); tokio::time::sleep(Duration::from_millis(delay_ms)).await; - let connected = self.connect_endpoints_round_robin(&endpoint_list, rng).await; + let connected = self + .connect_endpoints_round_robin_with_generation_contour( + &endpoint_list, + rng, + generation, + WriterContour::Warm, + ) + .await; debug!( dc = *dc, pass = pass_idx + 1, @@ -265,29 +319,61 @@ impl MePool { return; } + let desired_map_hash = Self::desired_map_hash(&desired_by_dc); + let now_epoch_secs = Self::now_epoch_secs(); let previous_generation = self.current_generation(); let hardswap = self.hardswap.load(Ordering::Relaxed); let generation = if hardswap { let pending_generation = self.pending_hardswap_generation.load(Ordering::Relaxed); - if pending_generation != 0 && pending_generation >= previous_generation { + let pending_started_at = self + .pending_hardswap_started_at_epoch_secs + .load(Ordering::Relaxed); + let pending_map_hash = self.pending_hardswap_map_hash.load(Ordering::Relaxed); + let pending_age_secs = now_epoch_secs.saturating_sub(pending_started_at); + let pending_ttl_expired = pending_started_at > 0 && pending_age_secs > ME_HARDSWAP_PENDING_TTL_SECS; + let pending_matches_map = pending_map_hash != 0 && pending_map_hash == desired_map_hash; + + if pending_generation != 0 + && pending_generation >= previous_generation + && pending_matches_map + && !pending_ttl_expired + { + self.stats.increment_me_hardswap_pending_reuse_total(); debug!( previous_generation, generation = pending_generation, + pending_age_secs, "ME hardswap continues with pending generation" ); pending_generation } else { + if pending_generation != 0 && pending_ttl_expired { + self.stats.increment_me_hardswap_pending_ttl_expired_total(); + warn!( + previous_generation, + generation = pending_generation, + pending_age_secs, + pending_ttl_secs = ME_HARDSWAP_PENDING_TTL_SECS, + "ME hardswap pending generation expired by TTL; starting fresh generation" + ); + } let next_generation = self.generation.fetch_add(1, Ordering::Relaxed) + 1; self.pending_hardswap_generation .store(next_generation, Ordering::Relaxed); + self.pending_hardswap_started_at_epoch_secs + .store(now_epoch_secs, Ordering::Relaxed); + self.pending_hardswap_map_hash + .store(desired_map_hash, Ordering::Relaxed); + self.warm_generation.store(next_generation, Ordering::Relaxed); next_generation } } else { - self.pending_hardswap_generation.store(0, Ordering::Relaxed); + self.clear_pending_hardswap_state(); self.generation.fetch_add(1, Ordering::Relaxed) + 1 }; if hardswap { + self.warm_generation.store(generation, Ordering::Relaxed); self.warmup_generation_for_all_dcs(rng, generation, &desired_by_dc) .await; } else { @@ -352,6 +438,10 @@ impl MePool { return; } + if hardswap { + self.promote_warm_generation_to_active(generation).await; + } + let desired_addrs: HashSet = desired_by_dc .values() .flat_map(|set| set.iter().copied()) @@ -373,7 +463,7 @@ impl MePool { if stale_writer_ids.is_empty() { if hardswap { - self.pending_hardswap_generation.store(0, Ordering::Relaxed); + self.clear_pending_hardswap_state(); } debug!("ME reinit cycle completed with no stale writers"); return; @@ -397,7 +487,7 @@ impl MePool { .await; } if hardswap { - self.pending_hardswap_generation.store(0, Ordering::Relaxed); + self.clear_pending_hardswap_state(); } } diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index 77ab891..455757e 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -1,6 +1,6 @@ use std::net::SocketAddr; use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering}; use std::time::{Duration, Instant}; use bytes::BytesMut; @@ -15,7 +15,7 @@ use crate::error::{ProxyError, Result}; use crate::protocol::constants::RPC_PING_U32; use super::codec::{RpcWriter, WriterCommand}; -use super::pool::{MePool, MeWriter}; +use super::pool::{MePool, MeWriter, WriterContour}; use super::reader::reader_loop; use super::registry::BoundConn; @@ -43,6 +43,22 @@ impl MePool { } pub(crate) async fn connect_one(self: &Arc, addr: SocketAddr, rng: &SecureRandom) -> Result<()> { + self.connect_one_with_generation_contour( + addr, + rng, + self.current_generation(), + WriterContour::Active, + ) + .await + } + + pub(super) async fn connect_one_with_generation_contour( + self: &Arc, + addr: SocketAddr, + rng: &SecureRandom, + generation: u64, + contour: WriterContour, + ) -> Result<()> { let secret_len = self.proxy_secret.read().await.secret.len(); if secret_len < 32 { return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into())); @@ -52,7 +68,7 @@ impl MePool { let hs = self.handshake_only(stream, addr, upstream_egress, rng).await?; let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed); - let generation = self.current_generation(); + let contour = Arc::new(AtomicU8::new(contour.as_u8())); let cancel = CancellationToken::new(); let degraded = Arc::new(AtomicBool::new(false)); let draining = Arc::new(AtomicBool::new(false)); @@ -89,6 +105,7 @@ impl MePool { id: writer_id, addr, generation, + contour: contour.clone(), created_at: Instant::now(), tx: tx.clone(), cancel: cancel.clone(), @@ -305,6 +322,8 @@ impl MePool { if !already_draining { self.stats.increment_pool_drain_active(); } + w.contour + .store(WriterContour::Draining.as_u8(), Ordering::Relaxed); w.draining.store(true, Ordering::Relaxed); true } else { diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index 632e34a..e907d25 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -46,6 +46,7 @@ pub(crate) async fn reader_loop( _ = cancel.cancelled() => return Ok(()), }; if n == 0 { + stats.increment_me_reader_eof_total(); return Err(ProxyError::Io(std::io::Error::new( ErrorKind::UnexpectedEof, "ME socket closed by peer", diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index 25b8852..3b57c4c 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -13,6 +13,7 @@ use crate::protocol::constants::RPC_CLOSE_EXT_U32; use super::MePool; use super::codec::WriterCommand; +use super::pool::WriterContour; use super::wire::build_proxy_req_payload; use rand::seq::SliceRandom; use super::registry::ConnMeta; @@ -101,7 +102,14 @@ impl MePool { ws.clone() }; - let mut candidate_indices = self.candidate_indices_for_dc(&writers_snapshot, target_dc).await; + let mut candidate_indices = self + .candidate_indices_for_dc(&writers_snapshot, target_dc, false) + .await; + if candidate_indices.is_empty() { + candidate_indices = self + .candidate_indices_for_dc(&writers_snapshot, target_dc, true) + .await; + } if candidate_indices.is_empty() { // Emergency connect-on-demand if emergency_attempts >= 3 { @@ -127,7 +135,14 @@ impl MePool { let ws2 = self.writers.read().await; writers_snapshot = ws2.clone(); drop(ws2); - candidate_indices = self.candidate_indices_for_dc(&writers_snapshot, target_dc).await; + candidate_indices = self + .candidate_indices_for_dc(&writers_snapshot, target_dc, false) + .await; + if candidate_indices.is_empty() { + candidate_indices = self + .candidate_indices_for_dc(&writers_snapshot, target_dc, true) + .await; + } if !candidate_indices.is_empty() { break; } @@ -143,6 +158,7 @@ impl MePool { let left = &writers_snapshot[*lhs]; let right = &writers_snapshot[*rhs]; let left_key = ( + self.writer_contour_rank_for_selection(left), (left.generation < self.current_generation()) as usize, left.degraded.load(Ordering::Relaxed) as usize, Reverse(left.tx.capacity()), @@ -150,6 +166,7 @@ impl MePool { left.id, ); let right_key = ( + self.writer_contour_rank_for_selection(right), (right.generation < self.current_generation()) as usize, right.degraded.load(Ordering::Relaxed) as usize, Reverse(right.tx.capacity()), @@ -163,7 +180,12 @@ impl MePool { let w = &writers_snapshot[*idx]; let degraded = w.degraded.load(Ordering::Relaxed); let stale = (w.generation < self.current_generation()) as usize; - (stale, degraded as usize, Reverse(w.tx.capacity())) + ( + self.writer_contour_rank_for_selection(w), + stale, + degraded as usize, + Reverse(w.tx.capacity()), + ) }); } @@ -257,6 +279,7 @@ impl MePool { &self, writers: &[super::pool::MeWriter], target_dc: i16, + include_warm: bool, ) -> Vec { let key = target_dc as i32; let mut preferred = Vec::::new(); @@ -300,13 +323,13 @@ impl MePool { if preferred.is_empty() { return (0..writers.len()) - .filter(|i| self.writer_accepts_new_binding(&writers[*i])) + .filter(|i| self.writer_eligible_for_selection(&writers[*i], include_warm)) .collect(); } let mut out = Vec::new(); for (idx, w) in writers.iter().enumerate() { - if !self.writer_accepts_new_binding(w) { + if !self.writer_eligible_for_selection(w, include_warm) { continue; } if preferred.contains(&w.addr) { @@ -315,10 +338,33 @@ impl MePool { } if out.is_empty() { return (0..writers.len()) - .filter(|i| self.writer_accepts_new_binding(&writers[*i])) + .filter(|i| self.writer_eligible_for_selection(&writers[*i], include_warm)) .collect(); } out } + fn writer_eligible_for_selection( + &self, + writer: &super::pool::MeWriter, + include_warm: bool, + ) -> bool { + if !self.writer_accepts_new_binding(writer) { + return false; + } + + match WriterContour::from_u8(writer.contour.load(Ordering::Relaxed)) { + WriterContour::Active => true, + WriterContour::Warm => include_warm, + WriterContour::Draining => true, + } + } + + fn writer_contour_rank_for_selection(&self, writer: &super::pool::MeWriter) -> usize { + match WriterContour::from_u8(writer.contour.load(Ordering::Relaxed)) { + WriterContour::Active => 0, + WriterContour::Warm => 1, + WriterContour::Draining => 2, + } + } }