ME Strict Writers

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
Alexey 2026-03-07 13:32:02 +03:00
parent 26323dbebf
commit 27e6dec018
No known key found for this signature in database
16 changed files with 487 additions and 174 deletions

View File

@ -5,6 +5,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::net::IpAddr; use std::net::IpAddr;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio::sync::RwLock; use tokio::sync::RwLock;
@ -18,6 +19,7 @@ pub struct UserIpTracker {
max_ips: Arc<RwLock<HashMap<String, usize>>>, max_ips: Arc<RwLock<HashMap<String, usize>>>,
limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>, limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>,
limit_window: Arc<RwLock<Duration>>, limit_window: Arc<RwLock<Duration>>,
last_compact_epoch_secs: Arc<AtomicU64>,
} }
impl UserIpTracker { impl UserIpTracker {
@ -28,6 +30,54 @@ impl UserIpTracker {
max_ips: Arc::new(RwLock::new(HashMap::new())), max_ips: Arc::new(RwLock::new(HashMap::new())),
limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)), limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)),
limit_window: Arc::new(RwLock::new(Duration::from_secs(30))), limit_window: Arc::new(RwLock::new(Duration::from_secs(30))),
last_compact_epoch_secs: Arc::new(AtomicU64::new(0)),
}
}
fn now_epoch_secs() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
async fn maybe_compact_empty_users(&self) {
const COMPACT_INTERVAL_SECS: u64 = 60;
let now_epoch_secs = Self::now_epoch_secs();
let last_compact_epoch_secs = self.last_compact_epoch_secs.load(Ordering::Relaxed);
if now_epoch_secs.saturating_sub(last_compact_epoch_secs) < COMPACT_INTERVAL_SECS {
return;
}
if self
.last_compact_epoch_secs
.compare_exchange(
last_compact_epoch_secs,
now_epoch_secs,
Ordering::AcqRel,
Ordering::Relaxed,
)
.is_err()
{
return;
}
let mut active_ips = self.active_ips.write().await;
let mut recent_ips = self.recent_ips.write().await;
let mut users = Vec::<String>::with_capacity(active_ips.len().saturating_add(recent_ips.len()));
users.extend(active_ips.keys().cloned());
for user in recent_ips.keys() {
if !active_ips.contains_key(user) {
users.push(user.clone());
}
}
for user in users {
let active_empty = active_ips.get(&user).map(|ips| ips.is_empty()).unwrap_or(true);
let recent_empty = recent_ips.get(&user).map(|ips| ips.is_empty()).unwrap_or(true);
if active_empty && recent_empty {
active_ips.remove(&user);
recent_ips.remove(&user);
}
} }
} }
@ -63,6 +113,7 @@ impl UserIpTracker {
} }
pub async fn check_and_add(&self, username: &str, ip: IpAddr) -> Result<(), String> { pub async fn check_and_add(&self, username: &str, ip: IpAddr) -> Result<(), String> {
self.maybe_compact_empty_users().await;
let limit = { let limit = {
let max_ips = self.max_ips.read().await; let max_ips = self.max_ips.read().await;
max_ips.get(username).copied() max_ips.get(username).copied()
@ -116,6 +167,7 @@ impl UserIpTracker {
} }
pub async fn remove_ip(&self, username: &str, ip: IpAddr) { pub async fn remove_ip(&self, username: &str, ip: IpAddr) {
self.maybe_compact_empty_users().await;
let mut active_ips = self.active_ips.write().await; let mut active_ips = self.active_ips.write().await;
if let Some(user_ips) = active_ips.get_mut(username) { if let Some(user_ips) = active_ips.get_mut(username) {
if let Some(count) = user_ips.get_mut(&ip) { if let Some(count) = user_ips.get_mut(&ip) {

View File

@ -6,6 +6,7 @@ use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock}; use std::sync::{Arc, Mutex, OnceLock};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use bytes::Bytes;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use tracing::{debug, trace, warn}; use tracing::{debug, trace, warn};
@ -20,7 +21,7 @@ use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
enum C2MeCommand { enum C2MeCommand {
Data { payload: Vec<u8>, flags: u32 }, Data { payload: Bytes, flags: u32 },
Close, Close,
} }
@ -283,7 +284,7 @@ where
success.dc_idx, success.dc_idx,
peer, peer,
translated_local_addr, translated_local_addr,
&payload, payload.as_ref(),
flags, flags,
effective_tag.as_deref(), effective_tag.as_deref(),
).await?; ).await?;
@ -479,7 +480,7 @@ async fn read_client_payload<R>(
forensics: &RelayForensicsState, forensics: &RelayForensicsState,
frame_counter: &mut u64, frame_counter: &mut u64,
stats: &Stats, stats: &Stats,
) -> Result<Option<(Vec<u8>, bool)>> ) -> Result<Option<(Bytes, bool)>>
where where
R: AsyncRead + Unpin + Send + 'static, R: AsyncRead + Unpin + Send + 'static,
{ {
@ -578,7 +579,7 @@ where
payload.truncate(secure_payload_len); payload.truncate(secure_payload_len);
} }
*frame_counter += 1; *frame_counter += 1;
return Ok(Some((payload, quickack))); return Ok(Some((Bytes::from(payload), quickack)));
} }
} }
@ -715,7 +716,7 @@ mod tests {
enqueue_c2me_command( enqueue_c2me_command(
&tx, &tx,
C2MeCommand::Data { C2MeCommand::Data {
payload: vec![1, 2, 3], payload: Bytes::from_static(&[1, 2, 3]),
flags: 0, flags: 0,
}, },
) )
@ -728,7 +729,7 @@ mod tests {
.unwrap(); .unwrap();
match recv { match recv {
C2MeCommand::Data { payload, flags } => { C2MeCommand::Data { payload, flags } => {
assert_eq!(payload, vec![1, 2, 3]); assert_eq!(payload.as_ref(), &[1, 2, 3]);
assert_eq!(flags, 0); assert_eq!(flags, 0);
} }
C2MeCommand::Close => panic!("unexpected close command"), C2MeCommand::Close => panic!("unexpected close command"),
@ -739,7 +740,7 @@ mod tests {
async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1); let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
tx.send(C2MeCommand::Data { tx.send(C2MeCommand::Data {
payload: vec![9], payload: Bytes::from_static(&[9]),
flags: 9, flags: 9,
}) })
.await .await
@ -750,7 +751,7 @@ mod tests {
enqueue_c2me_command( enqueue_c2me_command(
&tx2, &tx2,
C2MeCommand::Data { C2MeCommand::Data {
payload: vec![7, 7], payload: Bytes::from_static(&[7, 7]),
flags: 7, flags: 7,
}, },
) )
@ -769,7 +770,7 @@ mod tests {
.unwrap(); .unwrap();
match recv { match recv {
C2MeCommand::Data { payload, flags } => { C2MeCommand::Data { payload, flags } => {
assert_eq!(payload, vec![7, 7]); assert_eq!(payload.as_ref(), &[7, 7]);
assert_eq!(flags, 7); assert_eq!(flags, 7);
} }
C2MeCommand::Close => panic!("unexpected close command"), C2MeCommand::Close => panic!("unexpected close command"),

View File

@ -6,7 +6,7 @@ pub mod beobachten;
pub mod telemetry; pub mod telemetry;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering}; use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering};
use std::time::{Instant, Duration}; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use dashmap::DashMap; use dashmap::DashMap;
use parking_lot::Mutex; use parking_lot::Mutex;
use lru::LruCache; use lru::LruCache;
@ -119,6 +119,7 @@ pub struct Stats {
telemetry_user_enabled: AtomicBool, telemetry_user_enabled: AtomicBool,
telemetry_me_level: AtomicU8, telemetry_me_level: AtomicU8,
user_stats: DashMap<String, UserStats>, user_stats: DashMap<String, UserStats>,
user_stats_last_cleanup_epoch_secs: AtomicU64,
start_time: parking_lot::RwLock<Option<Instant>>, start_time: parking_lot::RwLock<Option<Instant>>,
} }
@ -130,6 +131,7 @@ pub struct UserStats {
pub octets_to_client: AtomicU64, pub octets_to_client: AtomicU64,
pub msgs_from_client: AtomicU64, pub msgs_from_client: AtomicU64,
pub msgs_to_client: AtomicU64, pub msgs_to_client: AtomicU64,
pub last_seen_epoch_secs: AtomicU64,
} }
impl Stats { impl Stats {
@ -178,6 +180,54 @@ impl Stats {
} }
} }
fn now_epoch_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
fn touch_user_stats(stats: &UserStats) {
stats
.last_seen_epoch_secs
.store(Self::now_epoch_secs(), Ordering::Relaxed);
}
fn maybe_cleanup_user_stats(&self) {
const USER_STATS_CLEANUP_INTERVAL_SECS: u64 = 60;
const USER_STATS_IDLE_TTL_SECS: u64 = 24 * 60 * 60;
let now_epoch_secs = Self::now_epoch_secs();
let last_cleanup_epoch_secs = self
.user_stats_last_cleanup_epoch_secs
.load(Ordering::Relaxed);
if now_epoch_secs.saturating_sub(last_cleanup_epoch_secs)
< USER_STATS_CLEANUP_INTERVAL_SECS
{
return;
}
if self
.user_stats_last_cleanup_epoch_secs
.compare_exchange(
last_cleanup_epoch_secs,
now_epoch_secs,
Ordering::AcqRel,
Ordering::Relaxed,
)
.is_err()
{
return;
}
self.user_stats.retain(|_, stats| {
if stats.curr_connects.load(Ordering::Relaxed) > 0 {
return true;
}
let last_seen_epoch_secs = stats.last_seen_epoch_secs.load(Ordering::Relaxed);
now_epoch_secs.saturating_sub(last_seen_epoch_secs) <= USER_STATS_IDLE_TTL_SECS
});
}
pub fn apply_telemetry_policy(&self, policy: TelemetryPolicy) { pub fn apply_telemetry_policy(&self, policy: TelemetryPolicy) {
self.telemetry_core_enabled self.telemetry_core_enabled
.store(policy.core_enabled, Ordering::Relaxed); .store(policy.core_enabled, Ordering::Relaxed);
@ -970,34 +1020,36 @@ impl Stats {
if !self.telemetry_user_enabled() { if !self.telemetry_user_enabled() {
return; return;
} }
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) { if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.connects.fetch_add(1, Ordering::Relaxed); stats.connects.fetch_add(1, Ordering::Relaxed);
return; return;
} }
self.user_stats let stats = self.user_stats.entry(user.to_string()).or_default();
.entry(user.to_string()) Self::touch_user_stats(stats.value());
.or_default() stats.connects.fetch_add(1, Ordering::Relaxed);
.connects
.fetch_add(1, Ordering::Relaxed);
} }
pub fn increment_user_curr_connects(&self, user: &str) { pub fn increment_user_curr_connects(&self, user: &str) {
if !self.telemetry_user_enabled() { if !self.telemetry_user_enabled() {
return; return;
} }
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) { if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.curr_connects.fetch_add(1, Ordering::Relaxed); stats.curr_connects.fetch_add(1, Ordering::Relaxed);
return; return;
} }
self.user_stats let stats = self.user_stats.entry(user.to_string()).or_default();
.entry(user.to_string()) Self::touch_user_stats(stats.value());
.or_default() stats.curr_connects.fetch_add(1, Ordering::Relaxed);
.curr_connects
.fetch_add(1, Ordering::Relaxed);
} }
pub fn decrement_user_curr_connects(&self, user: &str) { pub fn decrement_user_curr_connects(&self, user: &str) {
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) { if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
let counter = &stats.curr_connects; let counter = &stats.curr_connects;
let mut current = counter.load(Ordering::Relaxed); let mut current = counter.load(Ordering::Relaxed);
loop { loop {
@ -1027,60 +1079,60 @@ impl Stats {
if !self.telemetry_user_enabled() { if !self.telemetry_user_enabled() {
return; return;
} }
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) { if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed); stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
return; return;
} }
self.user_stats let stats = self.user_stats.entry(user.to_string()).or_default();
.entry(user.to_string()) Self::touch_user_stats(stats.value());
.or_default() stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
.octets_from_client
.fetch_add(bytes, Ordering::Relaxed);
} }
pub fn add_user_octets_to(&self, user: &str, bytes: u64) { pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
if !self.telemetry_user_enabled() { if !self.telemetry_user_enabled() {
return; return;
} }
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) { if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
return; return;
} }
self.user_stats let stats = self.user_stats.entry(user.to_string()).or_default();
.entry(user.to_string()) Self::touch_user_stats(stats.value());
.or_default() stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
.octets_to_client
.fetch_add(bytes, Ordering::Relaxed);
} }
pub fn increment_user_msgs_from(&self, user: &str) { pub fn increment_user_msgs_from(&self, user: &str) {
if !self.telemetry_user_enabled() { if !self.telemetry_user_enabled() {
return; return;
} }
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) { if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
return; return;
} }
self.user_stats let stats = self.user_stats.entry(user.to_string()).or_default();
.entry(user.to_string()) Self::touch_user_stats(stats.value());
.or_default() stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
.msgs_from_client
.fetch_add(1, Ordering::Relaxed);
} }
pub fn increment_user_msgs_to(&self, user: &str) { pub fn increment_user_msgs_to(&self, user: &str) {
if !self.telemetry_user_enabled() { if !self.telemetry_user_enabled() {
return; return;
} }
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) { if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
return; return;
} }
self.user_stats let stats = self.user_stats.entry(user.to_string()).or_default();
.entry(user.to_string()) Self::touch_user_stats(stats.value());
.or_default() stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
.msgs_to_client
.fetch_add(1, Ordering::Relaxed);
} }
pub fn get_user_total_octets(&self, user: &str) -> u64 { pub fn get_user_total_octets(&self, user: &str) -> u64 {

View File

@ -1,4 +1,5 @@
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use bytes::Bytes;
use crate::crypto::{AesCbc, crc32, crc32c}; use crate::crypto::{AesCbc, crc32, crc32c};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
@ -6,8 +7,8 @@ use crate::protocol::constants::*;
/// Commands sent to dedicated writer tasks to avoid mutex contention on TCP writes. /// Commands sent to dedicated writer tasks to avoid mutex contention on TCP writes.
pub(crate) enum WriterCommand { pub(crate) enum WriterCommand {
Data(Vec<u8>), Data(Bytes),
DataAndFlush(Vec<u8>), DataAndFlush(Bytes),
Close, Close,
} }

View File

@ -135,10 +135,15 @@ impl MePool {
pub(crate) async fn connect_tcp( pub(crate) async fn connect_tcp(
&self, &self,
addr: SocketAddr, addr: SocketAddr,
dc_idx_override: Option<i16>,
) -> Result<(TcpStream, f64, Option<UpstreamEgressInfo>)> { ) -> Result<(TcpStream, f64, Option<UpstreamEgressInfo>)> {
let start = Instant::now(); let start = Instant::now();
let (stream, upstream_egress) = if let Some(upstream) = &self.upstream { let (stream, upstream_egress) = if let Some(upstream) = &self.upstream {
let dc_idx = self.resolve_dc_idx_for_endpoint(addr).await; let dc_idx = if let Some(dc_idx) = dc_idx_override {
Some(dc_idx)
} else {
self.resolve_dc_idx_for_endpoint(addr).await
};
let (stream, egress) = upstream.connect_with_details(addr, dc_idx, None).await?; let (stream, egress) = upstream.connect_with_details(addr, dc_idx, None).await?;
(stream, Some(egress)) (stream, Some(egress))
} else { } else {

View File

@ -60,6 +60,7 @@ pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_c
loop { loop {
tokio::time::sleep(Duration::from_secs(HEALTH_INTERVAL_SECS)).await; tokio::time::sleep(Duration::from_secs(HEALTH_INTERVAL_SECS)).await;
pool.prune_closed_writers().await; pool.prune_closed_writers().await;
reap_draining_writers(&pool).await;
check_family( check_family(
IpFamily::V4, IpFamily::V4,
&pool, &pool,
@ -95,6 +96,28 @@ pub async fn me_health_monitor(pool: Arc<MePool>, rng: Arc<SecureRandom>, _min_c
} }
} }
async fn reap_draining_writers(pool: &Arc<MePool>) {
let now_epoch_secs = MePool::now_epoch_secs();
let writers = pool.writers.read().await.clone();
for writer in writers {
if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) {
continue;
}
if pool.registry.is_writer_empty(writer.id).await {
pool.remove_writer_and_close_clients(writer.id).await;
continue;
}
let deadline_epoch_secs = writer
.drain_deadline_epoch_secs
.load(std::sync::atomic::Ordering::Relaxed);
if deadline_epoch_secs != 0 && now_epoch_secs >= deadline_epoch_secs {
warn!(writer_id = writer.id, "Drain timeout, force-closing");
pool.stats.increment_pool_force_close_total();
pool.remove_writer_and_close_clients(writer.id).await;
}
}
}
async fn check_family( async fn check_family(
family: IpFamily, family: IpFamily,
pool: &Arc<MePool>, pool: &Arc<MePool>,
@ -153,12 +176,18 @@ async fn check_family(
.push(writer.id); .push(writer.id);
} }
let writer_idle_since = pool.registry.writer_idle_since_snapshot().await; let writer_idle_since = pool.registry.writer_idle_since_snapshot().await;
let bound_clients_by_writer = pool
.registry
.writer_activity_snapshot()
.await
.bound_clients_by_writer;
let floor_plan = build_family_floor_plan( let floor_plan = build_family_floor_plan(
pool, pool,
family, family,
&dc_endpoints, &dc_endpoints,
&live_addr_counts, &live_addr_counts,
&live_writer_ids_by_addr, &live_writer_ids_by_addr,
&bound_clients_by_writer,
adaptive_idle_since, adaptive_idle_since,
adaptive_recover_until, adaptive_recover_until,
) )
@ -241,6 +270,7 @@ async fn check_family(
required, required,
&live_writer_ids_by_addr, &live_writer_ids_by_addr,
&writer_idle_since, &writer_idle_since,
&bound_clients_by_writer,
idle_refresh_next_attempt, idle_refresh_next_attempt,
) )
.await; .await;
@ -254,6 +284,7 @@ async fn check_family(
alive, alive,
required, required,
&live_writer_ids_by_addr, &live_writer_ids_by_addr,
&bound_clients_by_writer,
shadow_rotate_deadline, shadow_rotate_deadline,
) )
.await; .await;
@ -320,6 +351,7 @@ async fn check_family(
&endpoints, &endpoints,
&live_writer_ids_by_addr, &live_writer_ids_by_addr,
&writer_idle_since, &writer_idle_since,
&bound_clients_by_writer,
) )
.await; .await;
if swapped { if swapped {
@ -470,6 +502,7 @@ async fn build_family_floor_plan(
dc_endpoints: &HashMap<i32, Vec<SocketAddr>>, dc_endpoints: &HashMap<i32, Vec<SocketAddr>>,
live_addr_counts: &HashMap<SocketAddr, usize>, live_addr_counts: &HashMap<SocketAddr, usize>,
live_writer_ids_by_addr: &HashMap<SocketAddr, Vec<u64>>, live_writer_ids_by_addr: &HashMap<SocketAddr, Vec<u64>>,
bound_clients_by_writer: &HashMap<u64, usize>,
adaptive_idle_since: &mut HashMap<(i32, IpFamily), Instant>, adaptive_idle_since: &mut HashMap<(i32, IpFamily), Instant>,
adaptive_recover_until: &mut HashMap<(i32, IpFamily), Instant>, adaptive_recover_until: &mut HashMap<(i32, IpFamily), Instant>,
) -> FamilyFloorPlan { ) -> FamilyFloorPlan {
@ -491,6 +524,7 @@ async fn build_family_floor_plan(
key, key,
endpoints, endpoints,
live_writer_ids_by_addr, live_writer_ids_by_addr,
bound_clients_by_writer,
adaptive_idle_since, adaptive_idle_since,
adaptive_recover_until, adaptive_recover_until,
) )
@ -521,7 +555,7 @@ async fn build_family_floor_plan(
.sum::<usize>(); .sum::<usize>();
family_active_total = family_active_total.saturating_add(alive); family_active_total = family_active_total.saturating_add(alive);
let writer_ids = list_writer_ids_for_endpoints(endpoints, live_writer_ids_by_addr); let writer_ids = list_writer_ids_for_endpoints(endpoints, live_writer_ids_by_addr);
let has_bound_clients = has_bound_clients_on_endpoint(pool, &writer_ids).await; let has_bound_clients = has_bound_clients_on_endpoint(&writer_ids, bound_clients_by_writer);
entries.push(DcFloorPlanEntry { entries.push(DcFloorPlanEntry {
dc: *dc, dc: *dc,
@ -622,6 +656,7 @@ async fn maybe_swap_idle_writer_for_cap(
endpoints: &[SocketAddr], endpoints: &[SocketAddr],
live_writer_ids_by_addr: &HashMap<SocketAddr, Vec<u64>>, live_writer_ids_by_addr: &HashMap<SocketAddr, Vec<u64>>,
writer_idle_since: &HashMap<u64, u64>, writer_idle_since: &HashMap<u64, u64>,
bound_clients_by_writer: &HashMap<u64, usize>,
) -> bool { ) -> bool {
let now_epoch_secs = MePool::now_epoch_secs(); let now_epoch_secs = MePool::now_epoch_secs();
let mut candidate: Option<(u64, SocketAddr, u64)> = None; let mut candidate: Option<(u64, SocketAddr, u64)> = None;
@ -630,7 +665,7 @@ async fn maybe_swap_idle_writer_for_cap(
continue; continue;
}; };
for writer_id in writer_ids { for writer_id in writer_ids {
if !pool.registry.is_writer_empty(*writer_id).await { if bound_clients_by_writer.get(writer_id).copied().unwrap_or(0) > 0 {
continue; continue;
} }
let Some(idle_since_epoch_secs) = writer_idle_since.get(writer_id).copied() else { let Some(idle_since_epoch_secs) = writer_idle_since.get(writer_id).copied() else {
@ -705,6 +740,7 @@ async fn maybe_refresh_idle_writer_for_dc(
required: usize, required: usize,
live_writer_ids_by_addr: &HashMap<SocketAddr, Vec<u64>>, live_writer_ids_by_addr: &HashMap<SocketAddr, Vec<u64>>,
writer_idle_since: &HashMap<u64, u64>, writer_idle_since: &HashMap<u64, u64>,
bound_clients_by_writer: &HashMap<u64, usize>,
idle_refresh_next_attempt: &mut HashMap<(i32, IpFamily), Instant>, idle_refresh_next_attempt: &mut HashMap<(i32, IpFamily), Instant>,
) { ) {
if alive < required { if alive < required {
@ -725,6 +761,9 @@ async fn maybe_refresh_idle_writer_for_dc(
continue; continue;
}; };
for writer_id in writer_ids { for writer_id in writer_ids {
if bound_clients_by_writer.get(writer_id).copied().unwrap_or(0) > 0 {
continue;
}
let Some(idle_since_epoch_secs) = writer_idle_since.get(writer_id).copied() else { let Some(idle_since_epoch_secs) = writer_idle_since.get(writer_id).copied() else {
continue; continue;
}; };
@ -806,6 +845,7 @@ async fn should_reduce_floor_for_idle(
key: (i32, IpFamily), key: (i32, IpFamily),
endpoints: &[SocketAddr], endpoints: &[SocketAddr],
live_writer_ids_by_addr: &HashMap<SocketAddr, Vec<u64>>, live_writer_ids_by_addr: &HashMap<SocketAddr, Vec<u64>>,
bound_clients_by_writer: &HashMap<u64, usize>,
adaptive_idle_since: &mut HashMap<(i32, IpFamily), Instant>, adaptive_idle_since: &mut HashMap<(i32, IpFamily), Instant>,
adaptive_recover_until: &mut HashMap<(i32, IpFamily), Instant>, adaptive_recover_until: &mut HashMap<(i32, IpFamily), Instant>,
) -> bool { ) -> bool {
@ -817,7 +857,7 @@ async fn should_reduce_floor_for_idle(
let now = Instant::now(); let now = Instant::now();
let writer_ids = list_writer_ids_for_endpoints(endpoints, live_writer_ids_by_addr); let writer_ids = list_writer_ids_for_endpoints(endpoints, live_writer_ids_by_addr);
let has_bound_clients = has_bound_clients_on_endpoint(pool, &writer_ids).await; let has_bound_clients = has_bound_clients_on_endpoint(&writer_ids, bound_clients_by_writer);
if has_bound_clients { if has_bound_clients {
adaptive_idle_since.remove(&key); adaptive_idle_since.remove(&key);
adaptive_recover_until.insert(key, now + pool.adaptive_floor_recover_grace_duration()); adaptive_recover_until.insert(key, now + pool.adaptive_floor_recover_grace_duration());
@ -836,13 +876,13 @@ async fn should_reduce_floor_for_idle(
now.saturating_duration_since(*idle_since) >= pool.adaptive_floor_idle_duration() now.saturating_duration_since(*idle_since) >= pool.adaptive_floor_idle_duration()
} }
async fn has_bound_clients_on_endpoint(pool: &Arc<MePool>, writer_ids: &[u64]) -> bool { fn has_bound_clients_on_endpoint(
for writer_id in writer_ids { writer_ids: &[u64],
if !pool.registry.is_writer_empty(*writer_id).await { bound_clients_by_writer: &HashMap<u64, usize>,
return true; ) -> bool {
} writer_ids
} .iter()
false .any(|writer_id| bound_clients_by_writer.get(writer_id).copied().unwrap_or(0) > 0)
} }
async fn recover_single_endpoint_outage( async fn recover_single_endpoint_outage(
@ -973,6 +1013,7 @@ async fn maybe_rotate_single_endpoint_shadow(
alive: usize, alive: usize,
required: usize, required: usize,
live_writer_ids_by_addr: &HashMap<SocketAddr, Vec<u64>>, live_writer_ids_by_addr: &HashMap<SocketAddr, Vec<u64>>,
bound_clients_by_writer: &HashMap<u64, usize>,
shadow_rotate_deadline: &mut HashMap<(i32, IpFamily), Instant>, shadow_rotate_deadline: &mut HashMap<(i32, IpFamily), Instant>,
) { ) {
if endpoints.len() != 1 || alive < required { if endpoints.len() != 1 || alive < required {
@ -1011,7 +1052,7 @@ async fn maybe_rotate_single_endpoint_shadow(
let mut candidate_writer_id = None; let mut candidate_writer_id = None;
for writer_id in writer_ids { for writer_id in writer_ids {
if pool.registry.is_writer_empty(*writer_id).await { if bound_clients_by_writer.get(writer_id).copied().unwrap_or(0) == 0 {
candidate_writer_id = Some(*writer_id); candidate_writer_id = Some(*writer_id);
break; break;
} }

View File

@ -331,7 +331,7 @@ pub async fn run_me_ping(pool: &Arc<MePool>, rng: &SecureRandom) -> Vec<MePingRe
let mut error = None; let mut error = None;
let mut route = None; let mut route = None;
match pool.connect_tcp(addr).await { match pool.connect_tcp(addr, None).await {
Ok((stream, conn_rtt, upstream_egress)) => { Ok((stream, conn_rtt, upstream_egress)) => {
connect_ms = Some(conn_rtt); connect_ms = Some(conn_rtt);
route = route_from_egress(upstream_egress); route = route_from_egress(upstream_egress);

View File

@ -22,10 +22,17 @@ pub(super) struct RefillDcKey {
pub family: IpFamily, pub family: IpFamily,
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(super) struct RefillEndpointKey {
pub dc: i32,
pub addr: SocketAddr,
}
#[derive(Clone)] #[derive(Clone)]
pub struct MeWriter { pub struct MeWriter {
pub id: u64, pub id: u64,
pub addr: SocketAddr, pub addr: SocketAddr,
pub writer_dc: i32,
pub generation: u64, pub generation: u64,
pub contour: Arc<AtomicU8>, pub contour: Arc<AtomicU8>,
pub created_at: Instant, pub created_at: Instant,
@ -34,6 +41,7 @@ pub struct MeWriter {
pub degraded: Arc<AtomicBool>, pub degraded: Arc<AtomicBool>,
pub draining: Arc<AtomicBool>, pub draining: Arc<AtomicBool>,
pub draining_started_at_epoch_secs: Arc<AtomicU64>, pub draining_started_at_epoch_secs: Arc<AtomicU64>,
pub drain_deadline_epoch_secs: Arc<AtomicU64>,
pub allow_drain_fallback: Arc<AtomicBool>, pub allow_drain_fallback: Arc<AtomicBool>,
} }
@ -128,12 +136,13 @@ pub struct MePool {
pub(super) default_dc: AtomicI32, pub(super) default_dc: AtomicI32,
pub(super) next_writer_id: AtomicU64, pub(super) next_writer_id: AtomicU64,
pub(super) ping_tracker: Arc<Mutex<HashMap<i64, (std::time::Instant, u64)>>>, pub(super) ping_tracker: Arc<Mutex<HashMap<i64, (std::time::Instant, u64)>>>,
pub(super) ping_tracker_last_cleanup_epoch_ms: AtomicU64,
pub(super) rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>, pub(super) rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
pub(super) nat_reflection_cache: Arc<Mutex<NatReflectionCache>>, pub(super) nat_reflection_cache: Arc<Mutex<NatReflectionCache>>,
pub(super) nat_reflection_singleflight_v4: Arc<Mutex<()>>, pub(super) nat_reflection_singleflight_v4: Arc<Mutex<()>>,
pub(super) nat_reflection_singleflight_v6: Arc<Mutex<()>>, pub(super) nat_reflection_singleflight_v6: Arc<Mutex<()>>,
pub(super) writer_available: Arc<Notify>, pub(super) writer_available: Arc<Notify>,
pub(super) refill_inflight: Arc<Mutex<HashSet<SocketAddr>>>, pub(super) refill_inflight: Arc<Mutex<HashSet<RefillEndpointKey>>>,
pub(super) refill_inflight_dc: Arc<Mutex<HashSet<RefillDcKey>>>, pub(super) refill_inflight_dc: Arc<Mutex<HashSet<RefillDcKey>>>,
pub(super) conn_count: AtomicUsize, pub(super) conn_count: AtomicUsize,
pub(super) stats: Arc<crate::stats::Stats>, pub(super) stats: Arc<crate::stats::Stats>,
@ -361,6 +370,7 @@ impl MePool {
default_dc: AtomicI32::new(default_dc.unwrap_or(2)), default_dc: AtomicI32::new(default_dc.unwrap_or(2)),
next_writer_id: AtomicU64::new(1), next_writer_id: AtomicU64::new(1),
ping_tracker: Arc::new(Mutex::new(HashMap::new())), ping_tracker: Arc::new(Mutex::new(HashMap::new())),
ping_tracker_last_cleanup_epoch_ms: AtomicU64::new(0),
rtt_stats: Arc::new(Mutex::new(HashMap::new())), rtt_stats: Arc::new(Mutex::new(HashMap::new())),
nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())), nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())),
nat_reflection_singleflight_v4: Arc::new(Mutex::new(())), nat_reflection_singleflight_v4: Arc::new(Mutex::new(())),
@ -779,6 +789,36 @@ impl MePool {
if dc == 0 { 2 } else { dc } if dc == 0 { 2 } else { dc }
} }
pub(super) async fn has_configured_endpoints_for_dc(&self, dc: i32) -> bool {
if self.decision.ipv4_me {
let map = self.proxy_map_v4.read().await;
if map.get(&dc).is_some_and(|endpoints| !endpoints.is_empty()) {
return true;
}
}
if self.decision.ipv6_me {
let map = self.proxy_map_v6.read().await;
if map.get(&dc).is_some_and(|endpoints| !endpoints.is_empty()) {
return true;
}
}
false
}
pub(super) async fn resolve_target_dc_for_routing(&self, target_dc: i32) -> (i32, bool) {
if target_dc == 0 {
return (self.default_dc_for_routing(), true);
}
if self.has_configured_endpoints_for_dc(target_dc).await {
return (target_dc, false);
}
(self.default_dc_for_routing(), true)
}
pub(super) fn dc_lookup_chain_for_target(&self, target_dc: i32) -> Vec<i32> { pub(super) fn dc_lookup_chain_for_target(&self, target_dc: i32) -> Vec<i32> {
let mut out = Vec::with_capacity(1); let mut out = Vec::with_capacity(1);
if target_dc != 0 { if target_dc != 0 {

View File

@ -55,7 +55,11 @@ impl MePool {
.iter() .iter()
.map(|(ip, port)| SocketAddr::new(*ip, *port)) .map(|(ip, port)| SocketAddr::new(*ip, *port))
.collect(); .collect();
if self.active_writer_count_for_endpoints(&endpoints).await >= target_writers { if self
.active_writer_count_for_dc_endpoints(dc, &endpoints)
.await
>= target_writers
{
continue; continue;
} }
let pool = Arc::clone(self); let pool = Arc::clone(self);
@ -79,7 +83,7 @@ impl MePool {
.iter() .iter()
.map(|(ip, port)| SocketAddr::new(*ip, *port)) .map(|(ip, port)| SocketAddr::new(*ip, *port))
.collect(); .collect();
if self.active_writer_count_for_endpoints(&endpoints).await == 0 { if self.active_writer_count_for_dc_endpoints(*dc, &endpoints).await == 0 {
missing_dcs.push(*dc); missing_dcs.push(*dc);
} }
} }
@ -156,7 +160,9 @@ impl MePool {
let endpoint_set: HashSet<SocketAddr> = endpoints.iter().copied().collect(); let endpoint_set: HashSet<SocketAddr> = endpoints.iter().copied().collect();
loop { loop {
let alive = self.active_writer_count_for_endpoints(&endpoint_set).await; let alive = self
.active_writer_count_for_dc_endpoints(dc, &endpoint_set)
.await;
if alive >= target_writers { if alive >= target_writers {
info!( info!(
dc = %dc, dc = %dc,
@ -175,7 +181,7 @@ impl MePool {
let rng_clone = Arc::clone(&rng); let rng_clone = Arc::clone(&rng);
let endpoints_clone = endpoints.clone(); let endpoints_clone = endpoints.clone();
join.spawn(async move { join.spawn(async move {
pool.connect_endpoints_round_robin(&endpoints_clone, rng_clone.as_ref()) pool.connect_endpoints_round_robin(dc, &endpoints_clone, rng_clone.as_ref())
.await .await
}); });
} }
@ -193,7 +199,9 @@ impl MePool {
} }
} }
let alive_after = self.active_writer_count_for_endpoints(&endpoint_set).await; let alive_after = self
.active_writer_count_for_dc_endpoints(dc, &endpoint_set)
.await;
if alive_after >= target_writers { if alive_after >= target_writers {
info!( info!(
dc = %dc, dc = %dc,

View File

@ -9,7 +9,7 @@ use tracing::{debug, info, warn};
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::network::IpFamily; use crate::network::IpFamily;
use super::pool::{MePool, RefillDcKey, WriterContour}; use super::pool::{MePool, RefillDcKey, RefillEndpointKey, WriterContour};
const ME_FLAP_UPTIME_THRESHOLD_SECS: u64 = 20; const ME_FLAP_UPTIME_THRESHOLD_SECS: u64 = 20;
const ME_FLAP_QUARANTINE_SECS: u64 = 25; const ME_FLAP_QUARANTINE_SECS: u64 = 25;
@ -82,57 +82,19 @@ impl MePool {
Vec::new() Vec::new()
} }
pub(super) async fn has_refill_inflight_for_endpoints(&self, endpoints: &[SocketAddr]) -> bool { pub(super) async fn has_refill_inflight_for_dc_key(&self, key: RefillDcKey) -> bool {
if endpoints.is_empty() {
return false;
}
{
let guard = self.refill_inflight.lock().await;
if endpoints.iter().any(|addr| guard.contains(addr)) {
return true;
}
}
let dc_keys = self.resolve_refill_dc_keys_for_endpoints(endpoints).await;
if dc_keys.is_empty() {
return false;
}
let guard = self.refill_inflight_dc.lock().await; let guard = self.refill_inflight_dc.lock().await;
dc_keys.iter().any(|key| guard.contains(key)) guard.contains(&key)
}
async fn resolve_refill_dc_key_for_addr(&self, addr: SocketAddr) -> Option<RefillDcKey> {
let family = if addr.is_ipv4() {
IpFamily::V4
} else {
IpFamily::V6
};
Some(RefillDcKey {
dc: self.resolve_dc_for_endpoint(addr).await,
family,
})
}
async fn resolve_refill_dc_keys_for_endpoints(
&self,
endpoints: &[SocketAddr],
) -> HashSet<RefillDcKey> {
let mut out = HashSet::<RefillDcKey>::new();
for addr in endpoints {
if let Some(key) = self.resolve_refill_dc_key_for_addr(*addr).await {
out.insert(key);
}
}
out
} }
pub(super) async fn connect_endpoints_round_robin( pub(super) async fn connect_endpoints_round_robin(
self: &Arc<Self>, self: &Arc<Self>,
dc: i32,
endpoints: &[SocketAddr], endpoints: &[SocketAddr],
rng: &SecureRandom, rng: &SecureRandom,
) -> bool { ) -> bool {
self.connect_endpoints_round_robin_with_generation_contour( self.connect_endpoints_round_robin_with_generation_contour(
dc,
endpoints, endpoints,
rng, rng,
self.current_generation(), self.current_generation(),
@ -143,6 +105,7 @@ impl MePool {
pub(super) async fn connect_endpoints_round_robin_with_generation_contour( pub(super) async fn connect_endpoints_round_robin_with_generation_contour(
self: &Arc<Self>, self: &Arc<Self>,
dc: i32,
endpoints: &[SocketAddr], endpoints: &[SocketAddr],
rng: &SecureRandom, rng: &SecureRandom,
generation: u64, generation: u64,
@ -157,7 +120,7 @@ impl MePool {
let idx = (start + offset) % candidates.len(); let idx = (start + offset) % candidates.len();
let addr = candidates[idx]; let addr = candidates[idx];
match self match self
.connect_one_with_generation_contour(addr, rng, generation, contour) .connect_one_with_generation_contour_for_dc(addr, rng, generation, contour, dc)
.await .await
{ {
Ok(()) => return true, Ok(()) => return true,
@ -167,9 +130,8 @@ impl MePool {
false false
} }
async fn endpoints_for_same_dc(&self, addr: SocketAddr) -> Vec<SocketAddr> { async fn endpoints_for_dc(&self, target_dc: i32) -> Vec<SocketAddr> {
let mut endpoints = HashSet::<SocketAddr>::new(); let mut endpoints = HashSet::<SocketAddr>::new();
let target_dc = self.resolve_dc_for_endpoint(addr).await;
if self.decision.ipv4_me { if self.decision.ipv4_me {
let map = self.proxy_map_v4.read().await; let map = self.proxy_map_v4.read().await;
@ -194,14 +156,14 @@ impl MePool {
sorted sorted
} }
async fn refill_writer_after_loss(self: &Arc<Self>, addr: SocketAddr) -> bool { async fn refill_writer_after_loss(self: &Arc<Self>, addr: SocketAddr, writer_dc: i32) -> bool {
let fast_retries = self.me_reconnect_fast_retry_count.max(1); let fast_retries = self.me_reconnect_fast_retry_count.max(1);
let same_endpoint_quarantined = self.is_endpoint_quarantined(addr).await; let same_endpoint_quarantined = self.is_endpoint_quarantined(addr).await;
if !same_endpoint_quarantined { if !same_endpoint_quarantined {
for attempt in 0..fast_retries { for attempt in 0..fast_retries {
self.stats.increment_me_reconnect_attempt(); self.stats.increment_me_reconnect_attempt();
match self.connect_one(addr, self.rng.as_ref()).await { match self.connect_one_for_dc(addr, writer_dc, self.rng.as_ref()).await {
Ok(()) => { Ok(()) => {
self.stats.increment_me_reconnect_success(); self.stats.increment_me_reconnect_success();
self.stats.increment_me_writer_restored_same_endpoint_total(); self.stats.increment_me_writer_restored_same_endpoint_total();
@ -229,7 +191,7 @@ impl MePool {
); );
} }
let dc_endpoints = self.endpoints_for_same_dc(addr).await; let dc_endpoints = self.endpoints_for_dc(writer_dc).await;
if dc_endpoints.is_empty() { if dc_endpoints.is_empty() {
self.stats.increment_me_refill_failed_total(); self.stats.increment_me_refill_failed_total();
return false; return false;
@ -238,7 +200,7 @@ impl MePool {
for attempt in 0..fast_retries { for attempt in 0..fast_retries {
self.stats.increment_me_reconnect_attempt(); self.stats.increment_me_reconnect_attempt();
if self if self
.connect_endpoints_round_robin(&dc_endpoints, self.rng.as_ref()) .connect_endpoints_round_robin(writer_dc, &dc_endpoints, self.rng.as_ref())
.await .await
{ {
self.stats.increment_me_reconnect_success(); self.stats.increment_me_reconnect_success();
@ -259,45 +221,69 @@ impl MePool {
pub(crate) fn trigger_immediate_refill(self: &Arc<Self>, addr: SocketAddr) { pub(crate) fn trigger_immediate_refill(self: &Arc<Self>, addr: SocketAddr) {
let pool = Arc::clone(self); let pool = Arc::clone(self);
tokio::spawn(async move { tokio::spawn(async move {
let dc_endpoints = pool.endpoints_for_same_dc(addr).await; let writer_dc = pool.resolve_dc_for_endpoint(addr).await;
let dc_keys = pool.resolve_refill_dc_keys_for_endpoints(&dc_endpoints).await; pool.trigger_immediate_refill_for_dc(addr, writer_dc);
});
}
{ pub(crate) fn trigger_immediate_refill_for_dc(self: &Arc<Self>, addr: SocketAddr, writer_dc: i32) {
let endpoint_key = RefillEndpointKey {
dc: writer_dc,
addr,
};
let pre_inserted = if let Ok(mut guard) = self.refill_inflight.try_lock() {
if !guard.insert(endpoint_key) {
self.stats.increment_me_refill_skipped_inflight_total();
return;
}
true
} else {
false
};
let pool = Arc::clone(self);
tokio::spawn(async move {
let dc_endpoints = pool.endpoints_for_dc(writer_dc).await;
let dc_key = RefillDcKey {
dc: writer_dc,
family: if addr.is_ipv4() {
IpFamily::V4
} else {
IpFamily::V6
},
};
if !pre_inserted {
let mut guard = pool.refill_inflight.lock().await; let mut guard = pool.refill_inflight.lock().await;
if !guard.insert(addr) { if !guard.insert(endpoint_key) {
pool.stats.increment_me_refill_skipped_inflight_total(); pool.stats.increment_me_refill_skipped_inflight_total();
return; return;
} }
} }
if !dc_keys.is_empty() { {
let mut dc_guard = pool.refill_inflight_dc.lock().await; let mut dc_guard = pool.refill_inflight_dc.lock().await;
if dc_keys.iter().any(|key| dc_guard.contains(key)) { if dc_guard.contains(&dc_key) {
pool.stats.increment_me_refill_skipped_inflight_total(); pool.stats.increment_me_refill_skipped_inflight_total();
drop(dc_guard); drop(dc_guard);
let mut guard = pool.refill_inflight.lock().await; let mut guard = pool.refill_inflight.lock().await;
guard.remove(&addr); guard.remove(&endpoint_key);
return; return;
} }
dc_guard.extend(dc_keys.iter().copied()); dc_guard.insert(dc_key);
} }
pool.stats.increment_me_refill_triggered_total(); pool.stats.increment_me_refill_triggered_total();
let restored = pool.refill_writer_after_loss(addr, writer_dc).await;
let restored = pool.refill_writer_after_loss(addr).await;
if !restored { if !restored {
warn!(%addr, "ME immediate refill failed"); warn!(%addr, dc = writer_dc, "ME immediate refill failed");
} }
let mut guard = pool.refill_inflight.lock().await; let mut guard = pool.refill_inflight.lock().await;
guard.remove(&addr); guard.remove(&endpoint_key);
drop(guard); drop(guard);
if !dc_keys.is_empty() {
let mut dc_guard = pool.refill_inflight_dc.lock().await; let mut dc_guard = pool.refill_inflight_dc.lock().await;
for key in &dc_keys { dc_guard.remove(&dc_key);
dc_guard.remove(key);
}
}
}); });
} }
} }

View File

@ -4,6 +4,7 @@ use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::io::ErrorKind; use std::io::ErrorKind;
use bytes::Bytes;
use bytes::BytesMut; use bytes::BytesMut;
use rand::Rng; use rand::Rng;
use tokio::sync::mpsc; use tokio::sync::mpsc;
@ -50,11 +51,22 @@ impl MePool {
} }
pub(crate) async fn connect_one(self: &Arc<Self>, addr: SocketAddr, rng: &SecureRandom) -> Result<()> { pub(crate) async fn connect_one(self: &Arc<Self>, addr: SocketAddr, rng: &SecureRandom) -> Result<()> {
let writer_dc = self.resolve_dc_for_endpoint(addr).await;
self.connect_one_for_dc(addr, writer_dc, rng).await
}
pub(crate) async fn connect_one_for_dc(
self: &Arc<Self>,
addr: SocketAddr,
writer_dc: i32,
rng: &SecureRandom,
) -> Result<()> {
self.connect_one_with_generation_contour( self.connect_one_with_generation_contour(
addr, addr,
rng, rng,
self.current_generation(), self.current_generation(),
WriterContour::Active, WriterContour::Active,
writer_dc,
) )
.await .await
} }
@ -65,13 +77,27 @@ impl MePool {
rng: &SecureRandom, rng: &SecureRandom,
generation: u64, generation: u64,
contour: WriterContour, contour: WriterContour,
writer_dc: i32,
) -> Result<()> {
self.connect_one_with_generation_contour_for_dc(addr, rng, generation, contour, writer_dc)
.await
}
pub(super) async fn connect_one_with_generation_contour_for_dc(
self: &Arc<Self>,
addr: SocketAddr,
rng: &SecureRandom,
generation: u64,
contour: WriterContour,
writer_dc: i32,
) -> Result<()> { ) -> Result<()> {
let secret_len = self.proxy_secret.read().await.secret.len(); let secret_len = self.proxy_secret.read().await.secret.len();
if secret_len < 32 { if secret_len < 32 {
return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into())); return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into()));
} }
let (stream, _connect_ms, upstream_egress) = self.connect_tcp(addr).await?; let dc_idx = i16::try_from(writer_dc).ok();
let (stream, _connect_ms, upstream_egress) = self.connect_tcp(addr, dc_idx).await?;
let hs = self.handshake_only(stream, addr, upstream_egress, rng).await?; let hs = self.handshake_only(stream, addr, upstream_egress, rng).await?;
let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed); let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed);
@ -80,6 +106,7 @@ impl MePool {
let degraded = Arc::new(AtomicBool::new(false)); let degraded = Arc::new(AtomicBool::new(false));
let draining = Arc::new(AtomicBool::new(false)); let draining = Arc::new(AtomicBool::new(false));
let draining_started_at_epoch_secs = Arc::new(AtomicU64::new(0)); let draining_started_at_epoch_secs = Arc::new(AtomicU64::new(0));
let drain_deadline_epoch_secs = Arc::new(AtomicU64::new(0));
let allow_drain_fallback = Arc::new(AtomicBool::new(false)); let allow_drain_fallback = Arc::new(AtomicBool::new(false));
let (tx, mut rx) = mpsc::channel::<WriterCommand>(4096); let (tx, mut rx) = mpsc::channel::<WriterCommand>(4096);
let mut rpc_writer = RpcWriter { let mut rpc_writer = RpcWriter {
@ -111,6 +138,7 @@ impl MePool {
let writer = MeWriter { let writer = MeWriter {
id: writer_id, id: writer_id,
addr, addr,
writer_dc,
generation, generation,
contour: contour.clone(), contour: contour.clone(),
created_at: Instant::now(), created_at: Instant::now(),
@ -119,6 +147,7 @@ impl MePool {
degraded: degraded.clone(), degraded: degraded.clone(),
draining: draining.clone(), draining: draining.clone(),
draining_started_at_epoch_secs: draining_started_at_epoch_secs.clone(), draining_started_at_epoch_secs: draining_started_at_epoch_secs.clone(),
drain_deadline_epoch_secs: drain_deadline_epoch_secs.clone(),
allow_drain_fallback: allow_drain_fallback.clone(), allow_drain_fallback: allow_drain_fallback.clone(),
}; };
self.writers.write().await.push(writer.clone()); self.writers.write().await.push(writer.clone());
@ -254,17 +283,47 @@ impl MePool {
p.extend_from_slice(&sent_id.to_le_bytes()); p.extend_from_slice(&sent_id.to_le_bytes());
{ {
let mut tracker = ping_tracker_ping.lock().await; let mut tracker = ping_tracker_ping.lock().await;
let now_epoch_ms = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
let mut run_cleanup = false;
if let Some(pool) = pool_ping.upgrade() {
let last_cleanup_ms = pool
.ping_tracker_last_cleanup_epoch_ms
.load(Ordering::Relaxed);
if now_epoch_ms.saturating_sub(last_cleanup_ms) >= 30_000
&& pool
.ping_tracker_last_cleanup_epoch_ms
.compare_exchange(
last_cleanup_ms,
now_epoch_ms,
Ordering::AcqRel,
Ordering::Relaxed,
)
.is_ok()
{
run_cleanup = true;
}
}
if run_cleanup {
let before = tracker.len(); let before = tracker.len();
tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120)); tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120));
let expired = before.saturating_sub(tracker.len()); let expired = before.saturating_sub(tracker.len());
if expired > 0 { if expired > 0 {
stats_ping.increment_me_keepalive_timeout_by(expired as u64); stats_ping.increment_me_keepalive_timeout_by(expired as u64);
} }
}
tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); tracker.insert(sent_id, (std::time::Instant::now(), writer_id));
} }
ping_id = ping_id.wrapping_add(1); ping_id = ping_id.wrapping_add(1);
stats_ping.increment_me_keepalive_sent(); stats_ping.increment_me_keepalive_sent();
if tx_ping.send(WriterCommand::DataAndFlush(p)).await.is_err() { if tx_ping
.send(WriterCommand::DataAndFlush(Bytes::from(p)))
.await
.is_err()
{
stats_ping.increment_me_keepalive_failed(); stats_ping.increment_me_keepalive_failed();
debug!("ME ping failed, removing dead writer"); debug!("ME ping failed, removing dead writer");
cancel_ping.cancel(); cancel_ping.cancel();
@ -338,7 +397,11 @@ impl MePool {
meta.proto_flags, meta.proto_flags,
); );
if tx_signal.send(WriterCommand::DataAndFlush(payload)).await.is_err() { if tx_signal
.send(WriterCommand::DataAndFlush(payload))
.await
.is_err()
{
stats_signal.increment_me_rpc_proxy_req_signal_failed_total(); stats_signal.increment_me_rpc_proxy_req_signal_failed_total();
let _ = pool.registry.unregister(conn_id).await; let _ = pool.registry.unregister(conn_id).await;
cancel_signal.cancel(); cancel_signal.cancel();
@ -369,7 +432,7 @@ impl MePool {
close_payload.extend_from_slice(&conn_id.to_le_bytes()); close_payload.extend_from_slice(&conn_id.to_le_bytes());
if tx_signal if tx_signal
.send(WriterCommand::DataAndFlush(close_payload)) .send(WriterCommand::DataAndFlush(Bytes::from(close_payload)))
.await .await
.is_err() .is_err()
{ {
@ -404,6 +467,7 @@ impl MePool {
async fn remove_writer_only(self: &Arc<Self>, writer_id: u64) -> Vec<BoundConn> { async fn remove_writer_only(self: &Arc<Self>, writer_id: u64) -> Vec<BoundConn> {
let mut close_tx: Option<mpsc::Sender<WriterCommand>> = None; let mut close_tx: Option<mpsc::Sender<WriterCommand>> = None;
let mut removed_addr: Option<SocketAddr> = None; let mut removed_addr: Option<SocketAddr> = None;
let mut removed_dc: Option<i32> = None;
let mut removed_uptime: Option<Duration> = None; let mut removed_uptime: Option<Duration> = None;
let mut trigger_refill = false; let mut trigger_refill = false;
{ {
@ -417,6 +481,7 @@ impl MePool {
self.stats.increment_me_writer_removed_total(); self.stats.increment_me_writer_removed_total();
w.cancel.cancel(); w.cancel.cancel();
removed_addr = Some(w.addr); removed_addr = Some(w.addr);
removed_dc = Some(w.writer_dc);
removed_uptime = Some(w.created_at.elapsed()); removed_uptime = Some(w.created_at.elapsed());
trigger_refill = !was_draining; trigger_refill = !was_draining;
if trigger_refill { if trigger_refill {
@ -431,11 +496,12 @@ impl MePool {
} }
if trigger_refill if trigger_refill
&& let Some(addr) = removed_addr && let Some(addr) = removed_addr
&& let Some(writer_dc) = removed_dc
{ {
if let Some(uptime) = removed_uptime { if let Some(uptime) = removed_uptime {
self.maybe_quarantine_flapping_endpoint(addr, uptime).await; self.maybe_quarantine_flapping_endpoint(addr, uptime).await;
} }
self.trigger_immediate_refill(addr); self.trigger_immediate_refill_for_dc(addr, writer_dc);
} }
self.rtt_stats.lock().await.remove(&writer_id); self.rtt_stats.lock().await.remove(&writer_id);
self.registry.writer_lost(writer_id).await self.registry.writer_lost(writer_id).await
@ -454,8 +520,14 @@ impl MePool {
let already_draining = w.draining.swap(true, Ordering::Relaxed); let already_draining = w.draining.swap(true, Ordering::Relaxed);
w.allow_drain_fallback w.allow_drain_fallback
.store(allow_drain_fallback, Ordering::Relaxed); .store(allow_drain_fallback, Ordering::Relaxed);
let now_epoch_secs = Self::now_epoch_secs();
w.draining_started_at_epoch_secs w.draining_started_at_epoch_secs
.store(Self::now_epoch_secs(), Ordering::Relaxed); .store(now_epoch_secs, Ordering::Relaxed);
let drain_deadline_epoch_secs = timeout
.map(|duration| now_epoch_secs.saturating_add(duration.as_secs()))
.unwrap_or(0);
w.drain_deadline_epoch_secs
.store(drain_deadline_epoch_secs, Ordering::Relaxed);
if !already_draining { if !already_draining {
self.stats.increment_pool_drain_active(); self.stats.increment_pool_drain_active();
} }
@ -479,26 +551,6 @@ impl MePool {
allow_drain_fallback, allow_drain_fallback,
"ME writer marked draining" "ME writer marked draining"
); );
let pool = Arc::downgrade(self);
tokio::spawn(async move {
let deadline = timeout.map(|t| Instant::now() + t);
while let Some(p) = pool.upgrade() {
if let Some(deadline_at) = deadline
&& Instant::now() >= deadline_at
{
warn!(writer_id, "Drain timeout, force-closing");
p.stats.increment_pool_force_close_total();
let _ = p.remove_writer_and_close_clients(writer_id).await;
break;
}
if p.registry.is_writer_empty(writer_id).await {
let _ = p.remove_writer_only(writer_id).await;
break;
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
});
} }
pub(crate) async fn mark_writer_draining(self: &Arc<Self>, writer_id: u64) { pub(crate) async fn mark_writer_draining(self: &Arc<Self>, writer_id: u64) {

View File

@ -181,7 +181,11 @@ pub(crate) async fn reader_loop(
let mut pong = Vec::with_capacity(12); let mut pong = Vec::with_capacity(12);
pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes()); pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes());
pong.extend_from_slice(&ping_id.to_le_bytes()); pong.extend_from_slice(&ping_id.to_le_bytes());
if tx.send(WriterCommand::DataAndFlush(pong)).await.is_err() { if tx
.send(WriterCommand::DataAndFlush(Bytes::from(pong)))
.await
.is_err()
{
warn!("PONG send failed"); warn!("PONG send failed");
break; break;
} }
@ -222,5 +226,5 @@ async fn send_close_conn(tx: &mpsc::Sender<WriterCommand>, conn_id: u64) {
p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes()); p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes());
p.extend_from_slice(&conn_id.to_le_bytes()); p.extend_from_slice(&conn_id.to_le_bytes());
let _ = tx.send(WriterCommand::DataAndFlush(p)).await; let _ = tx.send(WriterCommand::DataAndFlush(Bytes::from(p))).await;
} }

View File

@ -264,6 +264,20 @@ impl ConnRegistry {
inner.writer_idle_since_epoch_secs.clone() inner.writer_idle_since_epoch_secs.clone()
} }
pub async fn writer_idle_since_for_writer_ids(
&self,
writer_ids: &[u64],
) -> HashMap<u64, u64> {
let inner = self.inner.read().await;
let mut out = HashMap::<u64, u64>::with_capacity(writer_ids.len());
for writer_id in writer_ids {
if let Some(idle_since) = inner.writer_idle_since_epoch_secs.get(writer_id).copied() {
out.insert(*writer_id, idle_since);
}
}
out
}
pub(super) async fn writer_activity_snapshot(&self) -> WriterActivitySnapshot { pub(super) async fn writer_activity_snapshot(&self) -> WriterActivitySnapshot {
let inner = self.inner.read().await; let inner = self.inner.read().await;
let mut bound_clients_by_writer = HashMap::<u64, usize>::new(); let mut bound_clients_by_writer = HashMap::<u64, usize>::new();

View File

@ -5,6 +5,7 @@ use std::sync::Arc;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use bytes::Bytes;
use tokio::sync::mpsc::error::TrySendError; use tokio::sync::mpsc::error::TrySendError;
use tracing::{debug, warn}; use tracing::{debug, warn};
@ -59,6 +60,7 @@ impl MePool {
let mut hybrid_recovery_round = 0u32; let mut hybrid_recovery_round = 0u32;
let mut hybrid_last_recovery_at: Option<Instant> = None; let mut hybrid_last_recovery_at: Option<Instant> = None;
let hybrid_wait_step = self.me_route_no_writer_wait.max(Duration::from_millis(50)); let hybrid_wait_step = self.me_route_no_writer_wait.max(Duration::from_millis(50));
let mut hybrid_wait_current = hybrid_wait_step;
loop { loop {
if let Some(current) = self.registry.get_writer(conn_id).await { if let Some(current) = self.registry.get_writer(conn_id).await {
@ -147,11 +149,14 @@ impl MePool {
target_dc, target_dc,
&mut hybrid_recovery_round, &mut hybrid_recovery_round,
&mut hybrid_last_recovery_at, &mut hybrid_last_recovery_at,
hybrid_wait_step, hybrid_wait_current,
) )
.await; .await;
let deadline = Instant::now() + hybrid_wait_step; let deadline = Instant::now() + hybrid_wait_current;
let _ = self.wait_for_writer_until(deadline).await; let _ = self.wait_for_writer_until(deadline).await;
hybrid_wait_current =
(hybrid_wait_current.saturating_mul(2))
.min(Duration::from_millis(400));
continue; continue;
} }
} }
@ -223,16 +228,26 @@ impl MePool {
target_dc, target_dc,
&mut hybrid_recovery_round, &mut hybrid_recovery_round,
&mut hybrid_last_recovery_at, &mut hybrid_last_recovery_at,
hybrid_wait_step, hybrid_wait_current,
) )
.await; .await;
let deadline = Instant::now() + hybrid_wait_step; let deadline = Instant::now() + hybrid_wait_current;
let _ = self.wait_for_candidate_until(target_dc, deadline).await; let _ = self.wait_for_candidate_until(target_dc, deadline).await;
hybrid_wait_current = (hybrid_wait_current.saturating_mul(2))
.min(Duration::from_millis(400));
continue; continue;
} }
} }
} }
let writer_idle_since = self.registry.writer_idle_since_snapshot().await; hybrid_wait_current = hybrid_wait_step;
let writer_ids: Vec<u64> = candidate_indices
.iter()
.map(|idx| writers_snapshot[*idx].id)
.collect();
let writer_idle_since = self
.registry
.writer_idle_since_for_writer_ids(&writer_ids)
.await;
let now_epoch_secs = Self::now_epoch_secs(); let now_epoch_secs = Self::now_epoch_secs();
if self.me_deterministic_writer_sort.load(Ordering::Relaxed) { if self.me_deterministic_writer_sort.load(Ordering::Relaxed) {
@ -507,7 +522,11 @@ impl MePool {
let mut p = Vec::with_capacity(12); let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes()); p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes());
p.extend_from_slice(&conn_id.to_le_bytes()); p.extend_from_slice(&conn_id.to_le_bytes());
if w.tx.send(WriterCommand::DataAndFlush(p)).await.is_err() { if w.tx
.send(WriterCommand::DataAndFlush(Bytes::from(p)))
.await
.is_err()
{
debug!("ME close write failed"); debug!("ME close write failed");
self.remove_writer_and_close_clients(w.writer_id).await; self.remove_writer_and_close_clients(w.writer_id).await;
} }
@ -524,7 +543,7 @@ impl MePool {
let mut p = Vec::with_capacity(12); let mut p = Vec::with_capacity(12);
p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes()); p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes());
p.extend_from_slice(&conn_id.to_le_bytes()); p.extend_from_slice(&conn_id.to_le_bytes());
match w.tx.try_send(WriterCommand::DataAndFlush(p)) { match w.tx.try_send(WriterCommand::DataAndFlush(Bytes::from(p))) {
Ok(()) => {} Ok(()) => {}
Err(TrySendError::Full(cmd)) => { Err(TrySendError::Full(cmd)) => {
let _ = tokio::time::timeout(Duration::from_millis(50), w.tx.send(cmd)).await; let _ = tokio::time::timeout(Duration::from_millis(50), w.tx.send(cmd)).await;

View File

@ -1,4 +1,5 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use bytes::Bytes;
use crate::protocol::constants::*; use crate::protocol::constants::*;
@ -48,7 +49,7 @@ pub(crate) fn build_proxy_req_payload(
data: &[u8], data: &[u8],
proxy_tag: Option<&[u8]>, proxy_tag: Option<&[u8]>,
proto_flags: u32, proto_flags: u32,
) -> Vec<u8> { ) -> Bytes {
let mut b = Vec::with_capacity(128 + data.len()); let mut b = Vec::with_capacity(128 + data.len());
b.extend_from_slice(&RPC_PROXY_REQ_U32.to_le_bytes()); b.extend_from_slice(&RPC_PROXY_REQ_U32.to_le_bytes());
@ -85,7 +86,7 @@ pub(crate) fn build_proxy_req_payload(
} }
b.extend_from_slice(data); b.extend_from_slice(data);
b Bytes::from(b)
} }
pub fn proto_flags_for_tag(tag: crate::protocol::constants::ProtoTag, has_proxy_tag: bool) -> u32 { pub fn proto_flags_for_tag(tag: crate::protocol::constants::ProtoTag, has_proxy_tag: bool) -> u32 {

View File

@ -7,7 +7,7 @@
use std::collections::{BTreeSet, HashMap}; use std::collections::{BTreeSet, HashMap};
use std::net::{SocketAddr, IpAddr}; use std::net::{SocketAddr, IpAddr};
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::Duration; use std::time::Duration;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::RwLock; use tokio::sync::RwLock;
@ -237,6 +237,8 @@ pub struct UpstreamManager {
connect_budget: Duration, connect_budget: Duration,
unhealthy_fail_threshold: u32, unhealthy_fail_threshold: u32,
connect_failfast_hard_errors: bool, connect_failfast_hard_errors: bool,
no_upstreams_warn_epoch_ms: Arc<AtomicU64>,
no_healthy_warn_epoch_ms: Arc<AtomicU64>,
stats: Arc<Stats>, stats: Arc<Stats>,
} }
@ -262,10 +264,35 @@ impl UpstreamManager {
connect_budget: Duration::from_millis(connect_budget_ms.max(1)), connect_budget: Duration::from_millis(connect_budget_ms.max(1)),
unhealthy_fail_threshold: unhealthy_fail_threshold.max(1), unhealthy_fail_threshold: unhealthy_fail_threshold.max(1),
connect_failfast_hard_errors, connect_failfast_hard_errors,
no_upstreams_warn_epoch_ms: Arc::new(AtomicU64::new(0)),
no_healthy_warn_epoch_ms: Arc::new(AtomicU64::new(0)),
stats, stats,
} }
} }
fn now_epoch_ms() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64
}
fn should_emit_warn(last_epoch_ms: &AtomicU64, cooldown_ms: u64) -> bool {
let now_epoch_ms = Self::now_epoch_ms();
let previous_epoch_ms = last_epoch_ms.load(Ordering::Relaxed);
if now_epoch_ms.saturating_sub(previous_epoch_ms) < cooldown_ms {
return false;
}
last_epoch_ms
.compare_exchange(
previous_epoch_ms,
now_epoch_ms,
Ordering::AcqRel,
Ordering::Relaxed,
)
.is_ok()
}
pub fn try_api_snapshot(&self) -> Option<UpstreamApiSnapshot> { pub fn try_api_snapshot(&self) -> Option<UpstreamApiSnapshot> {
let guard = self.upstreams.try_read().ok()?; let guard = self.upstreams.try_read().ok()?;
let now = std::time::Instant::now(); let now = std::time::Instant::now();
@ -533,12 +560,22 @@ impl UpstreamManager {
.collect(); .collect();
if filtered_upstreams.is_empty() { if filtered_upstreams.is_empty() {
if Self::should_emit_warn(
self.no_upstreams_warn_epoch_ms.as_ref(),
5_000,
) {
warn!(scope = scope, "No upstreams available! Using first (direct?)"); warn!(scope = scope, "No upstreams available! Using first (direct?)");
}
return None; return None;
} }
if healthy.is_empty() { if healthy.is_empty() {
if Self::should_emit_warn(
self.no_healthy_warn_epoch_ms.as_ref(),
5_000,
) {
warn!(scope = scope, "No healthy upstreams available! Using random."); warn!(scope = scope, "No healthy upstreams available! Using random.");
}
return Some(filtered_upstreams[rand::rng().gen_range(0..filtered_upstreams.len())]); return Some(filtered_upstreams[rand::rng().gen_range(0..filtered_upstreams.len())]);
} }