diff --git a/src/ip_tracker.rs b/src/ip_tracker.rs index 5da8222..d406d51 100644 --- a/src/ip_tracker.rs +++ b/src/ip_tracker.rs @@ -5,6 +5,7 @@ use std::collections::HashMap; use std::net::IpAddr; use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{Duration, Instant}; use tokio::sync::RwLock; @@ -18,6 +19,7 @@ pub struct UserIpTracker { max_ips: Arc>>, limit_mode: Arc>, limit_window: Arc>, + last_compact_epoch_secs: Arc, } impl UserIpTracker { @@ -28,6 +30,54 @@ impl UserIpTracker { max_ips: Arc::new(RwLock::new(HashMap::new())), limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)), limit_window: Arc::new(RwLock::new(Duration::from_secs(30))), + last_compact_epoch_secs: Arc::new(AtomicU64::new(0)), + } + } + + fn now_epoch_secs() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + } + + async fn maybe_compact_empty_users(&self) { + const COMPACT_INTERVAL_SECS: u64 = 60; + let now_epoch_secs = Self::now_epoch_secs(); + let last_compact_epoch_secs = self.last_compact_epoch_secs.load(Ordering::Relaxed); + if now_epoch_secs.saturating_sub(last_compact_epoch_secs) < COMPACT_INTERVAL_SECS { + return; + } + if self + .last_compact_epoch_secs + .compare_exchange( + last_compact_epoch_secs, + now_epoch_secs, + Ordering::AcqRel, + Ordering::Relaxed, + ) + .is_err() + { + return; + } + + let mut active_ips = self.active_ips.write().await; + let mut recent_ips = self.recent_ips.write().await; + let mut users = Vec::::with_capacity(active_ips.len().saturating_add(recent_ips.len())); + users.extend(active_ips.keys().cloned()); + for user in recent_ips.keys() { + if !active_ips.contains_key(user) { + users.push(user.clone()); + } + } + + for user in users { + let active_empty = active_ips.get(&user).map(|ips| ips.is_empty()).unwrap_or(true); + let recent_empty = recent_ips.get(&user).map(|ips| ips.is_empty()).unwrap_or(true); + if active_empty && recent_empty { + active_ips.remove(&user); + recent_ips.remove(&user); + } } } @@ -63,6 +113,7 @@ impl UserIpTracker { } pub async fn check_and_add(&self, username: &str, ip: IpAddr) -> Result<(), String> { + self.maybe_compact_empty_users().await; let limit = { let max_ips = self.max_ips.read().await; max_ips.get(username).copied() @@ -116,6 +167,7 @@ impl UserIpTracker { } pub async fn remove_ip(&self, username: &str, ip: IpAddr) { + self.maybe_compact_empty_users().await; let mut active_ips = self.active_ips.write().await; if let Some(user_ips) = active_ips.get_mut(username) { if let Some(count) = user_ips.get_mut(&ip) { diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 8384e32..707c8af 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -6,6 +6,7 @@ use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, Mutex, OnceLock}; use std::time::{Duration, Instant}; +use bytes::Bytes; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::{mpsc, oneshot}; use tracing::{debug, trace, warn}; @@ -20,7 +21,7 @@ use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; enum C2MeCommand { - Data { payload: Vec, flags: u32 }, + Data { payload: Bytes, flags: u32 }, Close, } @@ -283,7 +284,7 @@ where success.dc_idx, peer, translated_local_addr, - &payload, + payload.as_ref(), flags, effective_tag.as_deref(), ).await?; @@ -479,7 +480,7 @@ async fn read_client_payload( forensics: &RelayForensicsState, frame_counter: &mut u64, stats: &Stats, -) -> Result, bool)>> +) -> Result> where R: AsyncRead + Unpin + Send + 'static, { @@ -578,7 +579,7 @@ where payload.truncate(secure_payload_len); } *frame_counter += 1; - return Ok(Some((payload, quickack))); + return Ok(Some((Bytes::from(payload), quickack))); } } @@ -715,7 +716,7 @@ mod tests { enqueue_c2me_command( &tx, C2MeCommand::Data { - payload: vec![1, 2, 3], + payload: Bytes::from_static(&[1, 2, 3]), flags: 0, }, ) @@ -728,7 +729,7 @@ mod tests { .unwrap(); match recv { C2MeCommand::Data { payload, flags } => { - assert_eq!(payload, vec![1, 2, 3]); + assert_eq!(payload.as_ref(), &[1, 2, 3]); assert_eq!(flags, 0); } C2MeCommand::Close => panic!("unexpected close command"), @@ -739,7 +740,7 @@ mod tests { async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { let (tx, mut rx) = mpsc::channel::(1); tx.send(C2MeCommand::Data { - payload: vec![9], + payload: Bytes::from_static(&[9]), flags: 9, }) .await @@ -750,7 +751,7 @@ mod tests { enqueue_c2me_command( &tx2, C2MeCommand::Data { - payload: vec![7, 7], + payload: Bytes::from_static(&[7, 7]), flags: 7, }, ) @@ -769,7 +770,7 @@ mod tests { .unwrap(); match recv { C2MeCommand::Data { payload, flags } => { - assert_eq!(payload, vec![7, 7]); + assert_eq!(payload.as_ref(), &[7, 7]); assert_eq!(flags, 7); } C2MeCommand::Close => panic!("unexpected close command"), diff --git a/src/stats/mod.rs b/src/stats/mod.rs index b51c941..fbfc987 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -6,7 +6,7 @@ pub mod beobachten; pub mod telemetry; use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering}; -use std::time::{Instant, Duration}; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; use dashmap::DashMap; use parking_lot::Mutex; use lru::LruCache; @@ -119,6 +119,7 @@ pub struct Stats { telemetry_user_enabled: AtomicBool, telemetry_me_level: AtomicU8, user_stats: DashMap, + user_stats_last_cleanup_epoch_secs: AtomicU64, start_time: parking_lot::RwLock>, } @@ -130,6 +131,7 @@ pub struct UserStats { pub octets_to_client: AtomicU64, pub msgs_from_client: AtomicU64, pub msgs_to_client: AtomicU64, + pub last_seen_epoch_secs: AtomicU64, } impl Stats { @@ -178,6 +180,54 @@ impl Stats { } } + fn now_epoch_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + } + + fn touch_user_stats(stats: &UserStats) { + stats + .last_seen_epoch_secs + .store(Self::now_epoch_secs(), Ordering::Relaxed); + } + + fn maybe_cleanup_user_stats(&self) { + const USER_STATS_CLEANUP_INTERVAL_SECS: u64 = 60; + const USER_STATS_IDLE_TTL_SECS: u64 = 24 * 60 * 60; + + let now_epoch_secs = Self::now_epoch_secs(); + let last_cleanup_epoch_secs = self + .user_stats_last_cleanup_epoch_secs + .load(Ordering::Relaxed); + if now_epoch_secs.saturating_sub(last_cleanup_epoch_secs) + < USER_STATS_CLEANUP_INTERVAL_SECS + { + return; + } + if self + .user_stats_last_cleanup_epoch_secs + .compare_exchange( + last_cleanup_epoch_secs, + now_epoch_secs, + Ordering::AcqRel, + Ordering::Relaxed, + ) + .is_err() + { + return; + } + + self.user_stats.retain(|_, stats| { + if stats.curr_connects.load(Ordering::Relaxed) > 0 { + return true; + } + let last_seen_epoch_secs = stats.last_seen_epoch_secs.load(Ordering::Relaxed); + now_epoch_secs.saturating_sub(last_seen_epoch_secs) <= USER_STATS_IDLE_TTL_SECS + }); + } + pub fn apply_telemetry_policy(&self, policy: TelemetryPolicy) { self.telemetry_core_enabled .store(policy.core_enabled, Ordering::Relaxed); @@ -970,34 +1020,36 @@ impl Stats { if !self.telemetry_user_enabled() { return; } + self.maybe_cleanup_user_stats(); if let Some(stats) = self.user_stats.get(user) { + Self::touch_user_stats(stats.value()); stats.connects.fetch_add(1, Ordering::Relaxed); return; } - self.user_stats - .entry(user.to_string()) - .or_default() - .connects - .fetch_add(1, Ordering::Relaxed); + let stats = self.user_stats.entry(user.to_string()).or_default(); + Self::touch_user_stats(stats.value()); + stats.connects.fetch_add(1, Ordering::Relaxed); } pub fn increment_user_curr_connects(&self, user: &str) { if !self.telemetry_user_enabled() { return; } + self.maybe_cleanup_user_stats(); if let Some(stats) = self.user_stats.get(user) { + Self::touch_user_stats(stats.value()); stats.curr_connects.fetch_add(1, Ordering::Relaxed); return; } - self.user_stats - .entry(user.to_string()) - .or_default() - .curr_connects - .fetch_add(1, Ordering::Relaxed); + let stats = self.user_stats.entry(user.to_string()).or_default(); + Self::touch_user_stats(stats.value()); + stats.curr_connects.fetch_add(1, Ordering::Relaxed); } pub fn decrement_user_curr_connects(&self, user: &str) { + self.maybe_cleanup_user_stats(); if let Some(stats) = self.user_stats.get(user) { + Self::touch_user_stats(stats.value()); let counter = &stats.curr_connects; let mut current = counter.load(Ordering::Relaxed); loop { @@ -1027,60 +1079,60 @@ impl Stats { if !self.telemetry_user_enabled() { return; } + self.maybe_cleanup_user_stats(); if let Some(stats) = self.user_stats.get(user) { + Self::touch_user_stats(stats.value()); stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed); return; } - self.user_stats - .entry(user.to_string()) - .or_default() - .octets_from_client - .fetch_add(bytes, Ordering::Relaxed); + let stats = self.user_stats.entry(user.to_string()).or_default(); + Self::touch_user_stats(stats.value()); + stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed); } pub fn add_user_octets_to(&self, user: &str, bytes: u64) { if !self.telemetry_user_enabled() { return; } + self.maybe_cleanup_user_stats(); if let Some(stats) = self.user_stats.get(user) { + Self::touch_user_stats(stats.value()); stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); return; } - self.user_stats - .entry(user.to_string()) - .or_default() - .octets_to_client - .fetch_add(bytes, Ordering::Relaxed); + let stats = self.user_stats.entry(user.to_string()).or_default(); + Self::touch_user_stats(stats.value()); + stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); } pub fn increment_user_msgs_from(&self, user: &str) { if !self.telemetry_user_enabled() { return; } + self.maybe_cleanup_user_stats(); if let Some(stats) = self.user_stats.get(user) { + Self::touch_user_stats(stats.value()); stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); return; } - self.user_stats - .entry(user.to_string()) - .or_default() - .msgs_from_client - .fetch_add(1, Ordering::Relaxed); + let stats = self.user_stats.entry(user.to_string()).or_default(); + Self::touch_user_stats(stats.value()); + stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); } pub fn increment_user_msgs_to(&self, user: &str) { if !self.telemetry_user_enabled() { return; } + self.maybe_cleanup_user_stats(); if let Some(stats) = self.user_stats.get(user) { + Self::touch_user_stats(stats.value()); stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); return; } - self.user_stats - .entry(user.to_string()) - .or_default() - .msgs_to_client - .fetch_add(1, Ordering::Relaxed); + let stats = self.user_stats.entry(user.to_string()).or_default(); + Self::touch_user_stats(stats.value()); + stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); } pub fn get_user_total_octets(&self, user: &str) -> u64 { diff --git a/src/transport/middle_proxy/codec.rs b/src/transport/middle_proxy/codec.rs index 6df0466..7f51aaa 100644 --- a/src/transport/middle_proxy/codec.rs +++ b/src/transport/middle_proxy/codec.rs @@ -1,4 +1,5 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use bytes::Bytes; use crate::crypto::{AesCbc, crc32, crc32c}; use crate::error::{ProxyError, Result}; @@ -6,8 +7,8 @@ use crate::protocol::constants::*; /// Commands sent to dedicated writer tasks to avoid mutex contention on TCP writes. pub(crate) enum WriterCommand { - Data(Vec), - DataAndFlush(Vec), + Data(Bytes), + DataAndFlush(Bytes), Close, } diff --git a/src/transport/middle_proxy/handshake.rs b/src/transport/middle_proxy/handshake.rs index 948c999..f556b99 100644 --- a/src/transport/middle_proxy/handshake.rs +++ b/src/transport/middle_proxy/handshake.rs @@ -135,10 +135,15 @@ impl MePool { pub(crate) async fn connect_tcp( &self, addr: SocketAddr, + dc_idx_override: Option, ) -> Result<(TcpStream, f64, Option)> { let start = Instant::now(); let (stream, upstream_egress) = if let Some(upstream) = &self.upstream { - let dc_idx = self.resolve_dc_idx_for_endpoint(addr).await; + let dc_idx = if let Some(dc_idx) = dc_idx_override { + Some(dc_idx) + } else { + self.resolve_dc_idx_for_endpoint(addr).await + }; let (stream, egress) = upstream.connect_with_details(addr, dc_idx, None).await?; (stream, Some(egress)) } else { diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index 8f4ad95..4fcba39 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -60,6 +60,7 @@ pub async fn me_health_monitor(pool: Arc, rng: Arc, _min_c loop { tokio::time::sleep(Duration::from_secs(HEALTH_INTERVAL_SECS)).await; pool.prune_closed_writers().await; + reap_draining_writers(&pool).await; check_family( IpFamily::V4, &pool, @@ -95,6 +96,28 @@ pub async fn me_health_monitor(pool: Arc, rng: Arc, _min_c } } +async fn reap_draining_writers(pool: &Arc) { + let now_epoch_secs = MePool::now_epoch_secs(); + let writers = pool.writers.read().await.clone(); + for writer in writers { + if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) { + continue; + } + if pool.registry.is_writer_empty(writer.id).await { + pool.remove_writer_and_close_clients(writer.id).await; + continue; + } + let deadline_epoch_secs = writer + .drain_deadline_epoch_secs + .load(std::sync::atomic::Ordering::Relaxed); + if deadline_epoch_secs != 0 && now_epoch_secs >= deadline_epoch_secs { + warn!(writer_id = writer.id, "Drain timeout, force-closing"); + pool.stats.increment_pool_force_close_total(); + pool.remove_writer_and_close_clients(writer.id).await; + } + } +} + async fn check_family( family: IpFamily, pool: &Arc, @@ -153,12 +176,18 @@ async fn check_family( .push(writer.id); } let writer_idle_since = pool.registry.writer_idle_since_snapshot().await; + let bound_clients_by_writer = pool + .registry + .writer_activity_snapshot() + .await + .bound_clients_by_writer; let floor_plan = build_family_floor_plan( pool, family, &dc_endpoints, &live_addr_counts, &live_writer_ids_by_addr, + &bound_clients_by_writer, adaptive_idle_since, adaptive_recover_until, ) @@ -241,6 +270,7 @@ async fn check_family( required, &live_writer_ids_by_addr, &writer_idle_since, + &bound_clients_by_writer, idle_refresh_next_attempt, ) .await; @@ -254,6 +284,7 @@ async fn check_family( alive, required, &live_writer_ids_by_addr, + &bound_clients_by_writer, shadow_rotate_deadline, ) .await; @@ -320,6 +351,7 @@ async fn check_family( &endpoints, &live_writer_ids_by_addr, &writer_idle_since, + &bound_clients_by_writer, ) .await; if swapped { @@ -470,6 +502,7 @@ async fn build_family_floor_plan( dc_endpoints: &HashMap>, live_addr_counts: &HashMap, live_writer_ids_by_addr: &HashMap>, + bound_clients_by_writer: &HashMap, adaptive_idle_since: &mut HashMap<(i32, IpFamily), Instant>, adaptive_recover_until: &mut HashMap<(i32, IpFamily), Instant>, ) -> FamilyFloorPlan { @@ -491,6 +524,7 @@ async fn build_family_floor_plan( key, endpoints, live_writer_ids_by_addr, + bound_clients_by_writer, adaptive_idle_since, adaptive_recover_until, ) @@ -521,7 +555,7 @@ async fn build_family_floor_plan( .sum::(); family_active_total = family_active_total.saturating_add(alive); let writer_ids = list_writer_ids_for_endpoints(endpoints, live_writer_ids_by_addr); - let has_bound_clients = has_bound_clients_on_endpoint(pool, &writer_ids).await; + let has_bound_clients = has_bound_clients_on_endpoint(&writer_ids, bound_clients_by_writer); entries.push(DcFloorPlanEntry { dc: *dc, @@ -622,6 +656,7 @@ async fn maybe_swap_idle_writer_for_cap( endpoints: &[SocketAddr], live_writer_ids_by_addr: &HashMap>, writer_idle_since: &HashMap, + bound_clients_by_writer: &HashMap, ) -> bool { let now_epoch_secs = MePool::now_epoch_secs(); let mut candidate: Option<(u64, SocketAddr, u64)> = None; @@ -630,7 +665,7 @@ async fn maybe_swap_idle_writer_for_cap( continue; }; for writer_id in writer_ids { - if !pool.registry.is_writer_empty(*writer_id).await { + if bound_clients_by_writer.get(writer_id).copied().unwrap_or(0) > 0 { continue; } let Some(idle_since_epoch_secs) = writer_idle_since.get(writer_id).copied() else { @@ -705,6 +740,7 @@ async fn maybe_refresh_idle_writer_for_dc( required: usize, live_writer_ids_by_addr: &HashMap>, writer_idle_since: &HashMap, + bound_clients_by_writer: &HashMap, idle_refresh_next_attempt: &mut HashMap<(i32, IpFamily), Instant>, ) { if alive < required { @@ -725,6 +761,9 @@ async fn maybe_refresh_idle_writer_for_dc( continue; }; for writer_id in writer_ids { + if bound_clients_by_writer.get(writer_id).copied().unwrap_or(0) > 0 { + continue; + } let Some(idle_since_epoch_secs) = writer_idle_since.get(writer_id).copied() else { continue; }; @@ -806,6 +845,7 @@ async fn should_reduce_floor_for_idle( key: (i32, IpFamily), endpoints: &[SocketAddr], live_writer_ids_by_addr: &HashMap>, + bound_clients_by_writer: &HashMap, adaptive_idle_since: &mut HashMap<(i32, IpFamily), Instant>, adaptive_recover_until: &mut HashMap<(i32, IpFamily), Instant>, ) -> bool { @@ -817,7 +857,7 @@ async fn should_reduce_floor_for_idle( let now = Instant::now(); let writer_ids = list_writer_ids_for_endpoints(endpoints, live_writer_ids_by_addr); - let has_bound_clients = has_bound_clients_on_endpoint(pool, &writer_ids).await; + let has_bound_clients = has_bound_clients_on_endpoint(&writer_ids, bound_clients_by_writer); if has_bound_clients { adaptive_idle_since.remove(&key); adaptive_recover_until.insert(key, now + pool.adaptive_floor_recover_grace_duration()); @@ -836,13 +876,13 @@ async fn should_reduce_floor_for_idle( now.saturating_duration_since(*idle_since) >= pool.adaptive_floor_idle_duration() } -async fn has_bound_clients_on_endpoint(pool: &Arc, writer_ids: &[u64]) -> bool { - for writer_id in writer_ids { - if !pool.registry.is_writer_empty(*writer_id).await { - return true; - } - } - false +fn has_bound_clients_on_endpoint( + writer_ids: &[u64], + bound_clients_by_writer: &HashMap, +) -> bool { + writer_ids + .iter() + .any(|writer_id| bound_clients_by_writer.get(writer_id).copied().unwrap_or(0) > 0) } async fn recover_single_endpoint_outage( @@ -973,6 +1013,7 @@ async fn maybe_rotate_single_endpoint_shadow( alive: usize, required: usize, live_writer_ids_by_addr: &HashMap>, + bound_clients_by_writer: &HashMap, shadow_rotate_deadline: &mut HashMap<(i32, IpFamily), Instant>, ) { if endpoints.len() != 1 || alive < required { @@ -1011,7 +1052,7 @@ async fn maybe_rotate_single_endpoint_shadow( let mut candidate_writer_id = None; for writer_id in writer_ids { - if pool.registry.is_writer_empty(*writer_id).await { + if bound_clients_by_writer.get(writer_id).copied().unwrap_or(0) == 0 { candidate_writer_id = Some(*writer_id); break; } diff --git a/src/transport/middle_proxy/ping.rs b/src/transport/middle_proxy/ping.rs index b9f0836..2c76592 100644 --- a/src/transport/middle_proxy/ping.rs +++ b/src/transport/middle_proxy/ping.rs @@ -331,7 +331,7 @@ pub async fn run_me_ping(pool: &Arc, rng: &SecureRandom) -> Vec { connect_ms = Some(conn_rtt); route = route_from_egress(upstream_egress); diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index b3d8dc6..236a12a 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -22,10 +22,17 @@ pub(super) struct RefillDcKey { pub family: IpFamily, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub(super) struct RefillEndpointKey { + pub dc: i32, + pub addr: SocketAddr, +} + #[derive(Clone)] pub struct MeWriter { pub id: u64, pub addr: SocketAddr, + pub writer_dc: i32, pub generation: u64, pub contour: Arc, pub created_at: Instant, @@ -34,6 +41,7 @@ pub struct MeWriter { pub degraded: Arc, pub draining: Arc, pub draining_started_at_epoch_secs: Arc, + pub drain_deadline_epoch_secs: Arc, pub allow_drain_fallback: Arc, } @@ -128,12 +136,13 @@ pub struct MePool { pub(super) default_dc: AtomicI32, pub(super) next_writer_id: AtomicU64, pub(super) ping_tracker: Arc>>, + pub(super) ping_tracker_last_cleanup_epoch_ms: AtomicU64, pub(super) rtt_stats: Arc>>, pub(super) nat_reflection_cache: Arc>, pub(super) nat_reflection_singleflight_v4: Arc>, pub(super) nat_reflection_singleflight_v6: Arc>, pub(super) writer_available: Arc, - pub(super) refill_inflight: Arc>>, + pub(super) refill_inflight: Arc>>, pub(super) refill_inflight_dc: Arc>>, pub(super) conn_count: AtomicUsize, pub(super) stats: Arc, @@ -361,6 +370,7 @@ impl MePool { default_dc: AtomicI32::new(default_dc.unwrap_or(2)), next_writer_id: AtomicU64::new(1), ping_tracker: Arc::new(Mutex::new(HashMap::new())), + ping_tracker_last_cleanup_epoch_ms: AtomicU64::new(0), rtt_stats: Arc::new(Mutex::new(HashMap::new())), nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())), nat_reflection_singleflight_v4: Arc::new(Mutex::new(())), @@ -779,6 +789,36 @@ impl MePool { if dc == 0 { 2 } else { dc } } + pub(super) async fn has_configured_endpoints_for_dc(&self, dc: i32) -> bool { + if self.decision.ipv4_me { + let map = self.proxy_map_v4.read().await; + if map.get(&dc).is_some_and(|endpoints| !endpoints.is_empty()) { + return true; + } + } + + if self.decision.ipv6_me { + let map = self.proxy_map_v6.read().await; + if map.get(&dc).is_some_and(|endpoints| !endpoints.is_empty()) { + return true; + } + } + + false + } + + pub(super) async fn resolve_target_dc_for_routing(&self, target_dc: i32) -> (i32, bool) { + if target_dc == 0 { + return (self.default_dc_for_routing(), true); + } + + if self.has_configured_endpoints_for_dc(target_dc).await { + return (target_dc, false); + } + + (self.default_dc_for_routing(), true) + } + pub(super) fn dc_lookup_chain_for_target(&self, target_dc: i32) -> Vec { let mut out = Vec::with_capacity(1); if target_dc != 0 { diff --git a/src/transport/middle_proxy/pool_init.rs b/src/transport/middle_proxy/pool_init.rs index 668cfda..52cbc68 100644 --- a/src/transport/middle_proxy/pool_init.rs +++ b/src/transport/middle_proxy/pool_init.rs @@ -55,7 +55,11 @@ impl MePool { .iter() .map(|(ip, port)| SocketAddr::new(*ip, *port)) .collect(); - if self.active_writer_count_for_endpoints(&endpoints).await >= target_writers { + if self + .active_writer_count_for_dc_endpoints(dc, &endpoints) + .await + >= target_writers + { continue; } let pool = Arc::clone(self); @@ -79,7 +83,7 @@ impl MePool { .iter() .map(|(ip, port)| SocketAddr::new(*ip, *port)) .collect(); - if self.active_writer_count_for_endpoints(&endpoints).await == 0 { + if self.active_writer_count_for_dc_endpoints(*dc, &endpoints).await == 0 { missing_dcs.push(*dc); } } @@ -156,7 +160,9 @@ impl MePool { let endpoint_set: HashSet = endpoints.iter().copied().collect(); loop { - let alive = self.active_writer_count_for_endpoints(&endpoint_set).await; + let alive = self + .active_writer_count_for_dc_endpoints(dc, &endpoint_set) + .await; if alive >= target_writers { info!( dc = %dc, @@ -175,7 +181,7 @@ impl MePool { let rng_clone = Arc::clone(&rng); let endpoints_clone = endpoints.clone(); join.spawn(async move { - pool.connect_endpoints_round_robin(&endpoints_clone, rng_clone.as_ref()) + pool.connect_endpoints_round_robin(dc, &endpoints_clone, rng_clone.as_ref()) .await }); } @@ -193,7 +199,9 @@ impl MePool { } } - let alive_after = self.active_writer_count_for_endpoints(&endpoint_set).await; + let alive_after = self + .active_writer_count_for_dc_endpoints(dc, &endpoint_set) + .await; if alive_after >= target_writers { info!( dc = %dc, diff --git a/src/transport/middle_proxy/pool_refill.rs b/src/transport/middle_proxy/pool_refill.rs index 7da6acc..3c8b0bb 100644 --- a/src/transport/middle_proxy/pool_refill.rs +++ b/src/transport/middle_proxy/pool_refill.rs @@ -9,7 +9,7 @@ use tracing::{debug, info, warn}; use crate::crypto::SecureRandom; use crate::network::IpFamily; -use super::pool::{MePool, RefillDcKey, WriterContour}; +use super::pool::{MePool, RefillDcKey, RefillEndpointKey, WriterContour}; const ME_FLAP_UPTIME_THRESHOLD_SECS: u64 = 20; const ME_FLAP_QUARANTINE_SECS: u64 = 25; @@ -82,57 +82,19 @@ impl MePool { Vec::new() } - pub(super) async fn has_refill_inflight_for_endpoints(&self, endpoints: &[SocketAddr]) -> bool { - if endpoints.is_empty() { - return false; - } - - { - let guard = self.refill_inflight.lock().await; - if endpoints.iter().any(|addr| guard.contains(addr)) { - return true; - } - } - - let dc_keys = self.resolve_refill_dc_keys_for_endpoints(endpoints).await; - if dc_keys.is_empty() { - return false; - } + pub(super) async fn has_refill_inflight_for_dc_key(&self, key: RefillDcKey) -> bool { let guard = self.refill_inflight_dc.lock().await; - dc_keys.iter().any(|key| guard.contains(key)) - } - - async fn resolve_refill_dc_key_for_addr(&self, addr: SocketAddr) -> Option { - let family = if addr.is_ipv4() { - IpFamily::V4 - } else { - IpFamily::V6 - }; - Some(RefillDcKey { - dc: self.resolve_dc_for_endpoint(addr).await, - family, - }) - } - - async fn resolve_refill_dc_keys_for_endpoints( - &self, - endpoints: &[SocketAddr], - ) -> HashSet { - let mut out = HashSet::::new(); - for addr in endpoints { - if let Some(key) = self.resolve_refill_dc_key_for_addr(*addr).await { - out.insert(key); - } - } - out + guard.contains(&key) } pub(super) async fn connect_endpoints_round_robin( self: &Arc, + dc: i32, endpoints: &[SocketAddr], rng: &SecureRandom, ) -> bool { self.connect_endpoints_round_robin_with_generation_contour( + dc, endpoints, rng, self.current_generation(), @@ -143,6 +105,7 @@ impl MePool { pub(super) async fn connect_endpoints_round_robin_with_generation_contour( self: &Arc, + dc: i32, endpoints: &[SocketAddr], rng: &SecureRandom, generation: u64, @@ -157,7 +120,7 @@ impl MePool { let idx = (start + offset) % candidates.len(); let addr = candidates[idx]; match self - .connect_one_with_generation_contour(addr, rng, generation, contour) + .connect_one_with_generation_contour_for_dc(addr, rng, generation, contour, dc) .await { Ok(()) => return true, @@ -167,9 +130,8 @@ impl MePool { false } - async fn endpoints_for_same_dc(&self, addr: SocketAddr) -> Vec { + async fn endpoints_for_dc(&self, target_dc: i32) -> Vec { let mut endpoints = HashSet::::new(); - let target_dc = self.resolve_dc_for_endpoint(addr).await; if self.decision.ipv4_me { let map = self.proxy_map_v4.read().await; @@ -194,14 +156,14 @@ impl MePool { sorted } - async fn refill_writer_after_loss(self: &Arc, addr: SocketAddr) -> bool { + async fn refill_writer_after_loss(self: &Arc, addr: SocketAddr, writer_dc: i32) -> bool { let fast_retries = self.me_reconnect_fast_retry_count.max(1); let same_endpoint_quarantined = self.is_endpoint_quarantined(addr).await; if !same_endpoint_quarantined { for attempt in 0..fast_retries { self.stats.increment_me_reconnect_attempt(); - match self.connect_one(addr, self.rng.as_ref()).await { + match self.connect_one_for_dc(addr, writer_dc, self.rng.as_ref()).await { Ok(()) => { self.stats.increment_me_reconnect_success(); self.stats.increment_me_writer_restored_same_endpoint_total(); @@ -229,7 +191,7 @@ impl MePool { ); } - let dc_endpoints = self.endpoints_for_same_dc(addr).await; + let dc_endpoints = self.endpoints_for_dc(writer_dc).await; if dc_endpoints.is_empty() { self.stats.increment_me_refill_failed_total(); return false; @@ -238,7 +200,7 @@ impl MePool { for attempt in 0..fast_retries { self.stats.increment_me_reconnect_attempt(); if self - .connect_endpoints_round_robin(&dc_endpoints, self.rng.as_ref()) + .connect_endpoints_round_robin(writer_dc, &dc_endpoints, self.rng.as_ref()) .await { self.stats.increment_me_reconnect_success(); @@ -259,45 +221,69 @@ impl MePool { pub(crate) fn trigger_immediate_refill(self: &Arc, addr: SocketAddr) { let pool = Arc::clone(self); tokio::spawn(async move { - let dc_endpoints = pool.endpoints_for_same_dc(addr).await; - let dc_keys = pool.resolve_refill_dc_keys_for_endpoints(&dc_endpoints).await; + let writer_dc = pool.resolve_dc_for_endpoint(addr).await; + pool.trigger_immediate_refill_for_dc(addr, writer_dc); + }); + } - { + pub(crate) fn trigger_immediate_refill_for_dc(self: &Arc, addr: SocketAddr, writer_dc: i32) { + let endpoint_key = RefillEndpointKey { + dc: writer_dc, + addr, + }; + let pre_inserted = if let Ok(mut guard) = self.refill_inflight.try_lock() { + if !guard.insert(endpoint_key) { + self.stats.increment_me_refill_skipped_inflight_total(); + return; + } + true + } else { + false + }; + + let pool = Arc::clone(self); + tokio::spawn(async move { + let dc_endpoints = pool.endpoints_for_dc(writer_dc).await; + let dc_key = RefillDcKey { + dc: writer_dc, + family: if addr.is_ipv4() { + IpFamily::V4 + } else { + IpFamily::V6 + }, + }; + + if !pre_inserted { let mut guard = pool.refill_inflight.lock().await; - if !guard.insert(addr) { + if !guard.insert(endpoint_key) { pool.stats.increment_me_refill_skipped_inflight_total(); return; } } - if !dc_keys.is_empty() { + { let mut dc_guard = pool.refill_inflight_dc.lock().await; - if dc_keys.iter().any(|key| dc_guard.contains(key)) { + if dc_guard.contains(&dc_key) { pool.stats.increment_me_refill_skipped_inflight_total(); drop(dc_guard); let mut guard = pool.refill_inflight.lock().await; - guard.remove(&addr); + guard.remove(&endpoint_key); return; } - dc_guard.extend(dc_keys.iter().copied()); + dc_guard.insert(dc_key); } pool.stats.increment_me_refill_triggered_total(); - - let restored = pool.refill_writer_after_loss(addr).await; + let restored = pool.refill_writer_after_loss(addr, writer_dc).await; if !restored { - warn!(%addr, "ME immediate refill failed"); + warn!(%addr, dc = writer_dc, "ME immediate refill failed"); } let mut guard = pool.refill_inflight.lock().await; - guard.remove(&addr); + guard.remove(&endpoint_key); drop(guard); - if !dc_keys.is_empty() { - let mut dc_guard = pool.refill_inflight_dc.lock().await; - for key in &dc_keys { - dc_guard.remove(key); - } - } + let mut dc_guard = pool.refill_inflight_dc.lock().await; + dc_guard.remove(&dc_key); }); } } diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index 1e86ea3..17ef331 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -4,6 +4,7 @@ use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering}; use std::time::{Duration, Instant}; use std::io::ErrorKind; +use bytes::Bytes; use bytes::BytesMut; use rand::Rng; use tokio::sync::mpsc; @@ -50,11 +51,22 @@ impl MePool { } pub(crate) async fn connect_one(self: &Arc, addr: SocketAddr, rng: &SecureRandom) -> Result<()> { + let writer_dc = self.resolve_dc_for_endpoint(addr).await; + self.connect_one_for_dc(addr, writer_dc, rng).await + } + + pub(crate) async fn connect_one_for_dc( + self: &Arc, + addr: SocketAddr, + writer_dc: i32, + rng: &SecureRandom, + ) -> Result<()> { self.connect_one_with_generation_contour( addr, rng, self.current_generation(), WriterContour::Active, + writer_dc, ) .await } @@ -65,13 +77,27 @@ impl MePool { rng: &SecureRandom, generation: u64, contour: WriterContour, + writer_dc: i32, + ) -> Result<()> { + self.connect_one_with_generation_contour_for_dc(addr, rng, generation, contour, writer_dc) + .await + } + + pub(super) async fn connect_one_with_generation_contour_for_dc( + self: &Arc, + addr: SocketAddr, + rng: &SecureRandom, + generation: u64, + contour: WriterContour, + writer_dc: i32, ) -> Result<()> { let secret_len = self.proxy_secret.read().await.secret.len(); if secret_len < 32 { return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into())); } - let (stream, _connect_ms, upstream_egress) = self.connect_tcp(addr).await?; + let dc_idx = i16::try_from(writer_dc).ok(); + let (stream, _connect_ms, upstream_egress) = self.connect_tcp(addr, dc_idx).await?; let hs = self.handshake_only(stream, addr, upstream_egress, rng).await?; let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed); @@ -80,6 +106,7 @@ impl MePool { let degraded = Arc::new(AtomicBool::new(false)); let draining = Arc::new(AtomicBool::new(false)); let draining_started_at_epoch_secs = Arc::new(AtomicU64::new(0)); + let drain_deadline_epoch_secs = Arc::new(AtomicU64::new(0)); let allow_drain_fallback = Arc::new(AtomicBool::new(false)); let (tx, mut rx) = mpsc::channel::(4096); let mut rpc_writer = RpcWriter { @@ -111,6 +138,7 @@ impl MePool { let writer = MeWriter { id: writer_id, addr, + writer_dc, generation, contour: contour.clone(), created_at: Instant::now(), @@ -119,6 +147,7 @@ impl MePool { degraded: degraded.clone(), draining: draining.clone(), draining_started_at_epoch_secs: draining_started_at_epoch_secs.clone(), + drain_deadline_epoch_secs: drain_deadline_epoch_secs.clone(), allow_drain_fallback: allow_drain_fallback.clone(), }; self.writers.write().await.push(writer.clone()); @@ -254,17 +283,47 @@ impl MePool { p.extend_from_slice(&sent_id.to_le_bytes()); { let mut tracker = ping_tracker_ping.lock().await; - let before = tracker.len(); - tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120)); - let expired = before.saturating_sub(tracker.len()); - if expired > 0 { - stats_ping.increment_me_keepalive_timeout_by(expired as u64); + let now_epoch_ms = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as u64; + let mut run_cleanup = false; + if let Some(pool) = pool_ping.upgrade() { + let last_cleanup_ms = pool + .ping_tracker_last_cleanup_epoch_ms + .load(Ordering::Relaxed); + if now_epoch_ms.saturating_sub(last_cleanup_ms) >= 30_000 + && pool + .ping_tracker_last_cleanup_epoch_ms + .compare_exchange( + last_cleanup_ms, + now_epoch_ms, + Ordering::AcqRel, + Ordering::Relaxed, + ) + .is_ok() + { + run_cleanup = true; + } + } + + if run_cleanup { + let before = tracker.len(); + tracker.retain(|_, (ts, _)| ts.elapsed() < Duration::from_secs(120)); + let expired = before.saturating_sub(tracker.len()); + if expired > 0 { + stats_ping.increment_me_keepalive_timeout_by(expired as u64); + } } tracker.insert(sent_id, (std::time::Instant::now(), writer_id)); } ping_id = ping_id.wrapping_add(1); stats_ping.increment_me_keepalive_sent(); - if tx_ping.send(WriterCommand::DataAndFlush(p)).await.is_err() { + if tx_ping + .send(WriterCommand::DataAndFlush(Bytes::from(p))) + .await + .is_err() + { stats_ping.increment_me_keepalive_failed(); debug!("ME ping failed, removing dead writer"); cancel_ping.cancel(); @@ -338,7 +397,11 @@ impl MePool { meta.proto_flags, ); - if tx_signal.send(WriterCommand::DataAndFlush(payload)).await.is_err() { + if tx_signal + .send(WriterCommand::DataAndFlush(payload)) + .await + .is_err() + { stats_signal.increment_me_rpc_proxy_req_signal_failed_total(); let _ = pool.registry.unregister(conn_id).await; cancel_signal.cancel(); @@ -369,7 +432,7 @@ impl MePool { close_payload.extend_from_slice(&conn_id.to_le_bytes()); if tx_signal - .send(WriterCommand::DataAndFlush(close_payload)) + .send(WriterCommand::DataAndFlush(Bytes::from(close_payload))) .await .is_err() { @@ -404,6 +467,7 @@ impl MePool { async fn remove_writer_only(self: &Arc, writer_id: u64) -> Vec { let mut close_tx: Option> = None; let mut removed_addr: Option = None; + let mut removed_dc: Option = None; let mut removed_uptime: Option = None; let mut trigger_refill = false; { @@ -417,6 +481,7 @@ impl MePool { self.stats.increment_me_writer_removed_total(); w.cancel.cancel(); removed_addr = Some(w.addr); + removed_dc = Some(w.writer_dc); removed_uptime = Some(w.created_at.elapsed()); trigger_refill = !was_draining; if trigger_refill { @@ -431,11 +496,12 @@ impl MePool { } if trigger_refill && let Some(addr) = removed_addr + && let Some(writer_dc) = removed_dc { if let Some(uptime) = removed_uptime { self.maybe_quarantine_flapping_endpoint(addr, uptime).await; } - self.trigger_immediate_refill(addr); + self.trigger_immediate_refill_for_dc(addr, writer_dc); } self.rtt_stats.lock().await.remove(&writer_id); self.registry.writer_lost(writer_id).await @@ -454,8 +520,14 @@ impl MePool { let already_draining = w.draining.swap(true, Ordering::Relaxed); w.allow_drain_fallback .store(allow_drain_fallback, Ordering::Relaxed); + let now_epoch_secs = Self::now_epoch_secs(); w.draining_started_at_epoch_secs - .store(Self::now_epoch_secs(), Ordering::Relaxed); + .store(now_epoch_secs, Ordering::Relaxed); + let drain_deadline_epoch_secs = timeout + .map(|duration| now_epoch_secs.saturating_add(duration.as_secs())) + .unwrap_or(0); + w.drain_deadline_epoch_secs + .store(drain_deadline_epoch_secs, Ordering::Relaxed); if !already_draining { self.stats.increment_pool_drain_active(); } @@ -479,26 +551,6 @@ impl MePool { allow_drain_fallback, "ME writer marked draining" ); - - let pool = Arc::downgrade(self); - tokio::spawn(async move { - let deadline = timeout.map(|t| Instant::now() + t); - while let Some(p) = pool.upgrade() { - if let Some(deadline_at) = deadline - && Instant::now() >= deadline_at - { - warn!(writer_id, "Drain timeout, force-closing"); - p.stats.increment_pool_force_close_total(); - let _ = p.remove_writer_and_close_clients(writer_id).await; - break; - } - if p.registry.is_writer_empty(writer_id).await { - let _ = p.remove_writer_only(writer_id).await; - break; - } - tokio::time::sleep(Duration::from_secs(1)).await; - } - }); } pub(crate) async fn mark_writer_draining(self: &Arc, writer_id: u64) { diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index 2a99164..61bd69c 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -181,7 +181,11 @@ pub(crate) async fn reader_loop( let mut pong = Vec::with_capacity(12); pong.extend_from_slice(&RPC_PONG_U32.to_le_bytes()); pong.extend_from_slice(&ping_id.to_le_bytes()); - if tx.send(WriterCommand::DataAndFlush(pong)).await.is_err() { + if tx + .send(WriterCommand::DataAndFlush(Bytes::from(pong))) + .await + .is_err() + { warn!("PONG send failed"); break; } @@ -222,5 +226,5 @@ async fn send_close_conn(tx: &mpsc::Sender, conn_id: u64) { p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes()); p.extend_from_slice(&conn_id.to_le_bytes()); - let _ = tx.send(WriterCommand::DataAndFlush(p)).await; + let _ = tx.send(WriterCommand::DataAndFlush(Bytes::from(p))).await; } diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index b437885..0ee81e0 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -264,6 +264,20 @@ impl ConnRegistry { inner.writer_idle_since_epoch_secs.clone() } + pub async fn writer_idle_since_for_writer_ids( + &self, + writer_ids: &[u64], + ) -> HashMap { + let inner = self.inner.read().await; + let mut out = HashMap::::with_capacity(writer_ids.len()); + for writer_id in writer_ids { + if let Some(idle_since) = inner.writer_idle_since_epoch_secs.get(writer_id).copied() { + out.insert(*writer_id, idle_since); + } + } + out + } + pub(super) async fn writer_activity_snapshot(&self) -> WriterActivitySnapshot { let inner = self.inner.read().await; let mut bound_clients_by_writer = HashMap::::new(); diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index 9ffcc8e..ccaad4a 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use std::sync::atomic::Ordering; use std::time::{Duration, Instant}; +use bytes::Bytes; use tokio::sync::mpsc::error::TrySendError; use tracing::{debug, warn}; @@ -59,6 +60,7 @@ impl MePool { let mut hybrid_recovery_round = 0u32; let mut hybrid_last_recovery_at: Option = None; let hybrid_wait_step = self.me_route_no_writer_wait.max(Duration::from_millis(50)); + let mut hybrid_wait_current = hybrid_wait_step; loop { if let Some(current) = self.registry.get_writer(conn_id).await { @@ -147,11 +149,14 @@ impl MePool { target_dc, &mut hybrid_recovery_round, &mut hybrid_last_recovery_at, - hybrid_wait_step, + hybrid_wait_current, ) .await; - let deadline = Instant::now() + hybrid_wait_step; + let deadline = Instant::now() + hybrid_wait_current; let _ = self.wait_for_writer_until(deadline).await; + hybrid_wait_current = + (hybrid_wait_current.saturating_mul(2)) + .min(Duration::from_millis(400)); continue; } } @@ -223,16 +228,26 @@ impl MePool { target_dc, &mut hybrid_recovery_round, &mut hybrid_last_recovery_at, - hybrid_wait_step, + hybrid_wait_current, ) .await; - let deadline = Instant::now() + hybrid_wait_step; + let deadline = Instant::now() + hybrid_wait_current; let _ = self.wait_for_candidate_until(target_dc, deadline).await; + hybrid_wait_current = (hybrid_wait_current.saturating_mul(2)) + .min(Duration::from_millis(400)); continue; } } } - let writer_idle_since = self.registry.writer_idle_since_snapshot().await; + hybrid_wait_current = hybrid_wait_step; + let writer_ids: Vec = candidate_indices + .iter() + .map(|idx| writers_snapshot[*idx].id) + .collect(); + let writer_idle_since = self + .registry + .writer_idle_since_for_writer_ids(&writer_ids) + .await; let now_epoch_secs = Self::now_epoch_secs(); if self.me_deterministic_writer_sort.load(Ordering::Relaxed) { @@ -507,7 +522,11 @@ impl MePool { let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_CLOSE_EXT_U32.to_le_bytes()); p.extend_from_slice(&conn_id.to_le_bytes()); - if w.tx.send(WriterCommand::DataAndFlush(p)).await.is_err() { + if w.tx + .send(WriterCommand::DataAndFlush(Bytes::from(p))) + .await + .is_err() + { debug!("ME close write failed"); self.remove_writer_and_close_clients(w.writer_id).await; } @@ -524,7 +543,7 @@ impl MePool { let mut p = Vec::with_capacity(12); p.extend_from_slice(&RPC_CLOSE_CONN_U32.to_le_bytes()); p.extend_from_slice(&conn_id.to_le_bytes()); - match w.tx.try_send(WriterCommand::DataAndFlush(p)) { + match w.tx.try_send(WriterCommand::DataAndFlush(Bytes::from(p))) { Ok(()) => {} Err(TrySendError::Full(cmd)) => { let _ = tokio::time::timeout(Duration::from_millis(50), w.tx.send(cmd)).await; diff --git a/src/transport/middle_proxy/wire.rs b/src/transport/middle_proxy/wire.rs index 3f78f20..7667646 100644 --- a/src/transport/middle_proxy/wire.rs +++ b/src/transport/middle_proxy/wire.rs @@ -1,4 +1,5 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use bytes::Bytes; use crate::protocol::constants::*; @@ -48,7 +49,7 @@ pub(crate) fn build_proxy_req_payload( data: &[u8], proxy_tag: Option<&[u8]>, proto_flags: u32, -) -> Vec { +) -> Bytes { let mut b = Vec::with_capacity(128 + data.len()); b.extend_from_slice(&RPC_PROXY_REQ_U32.to_le_bytes()); @@ -85,7 +86,7 @@ pub(crate) fn build_proxy_req_payload( } b.extend_from_slice(data); - b + Bytes::from(b) } pub fn proto_flags_for_tag(tag: crate::protocol::constants::ProtoTag, has_proxy_tag: bool) -> u32 { diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs index 84c6fdf..2424f9c 100644 --- a/src/transport/upstream.rs +++ b/src/transport/upstream.rs @@ -7,7 +7,7 @@ use std::collections::{BTreeSet, HashMap}; use std::net::{SocketAddr, IpAddr}; use std::sync::Arc; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::time::Duration; use tokio::net::TcpStream; use tokio::sync::RwLock; @@ -237,6 +237,8 @@ pub struct UpstreamManager { connect_budget: Duration, unhealthy_fail_threshold: u32, connect_failfast_hard_errors: bool, + no_upstreams_warn_epoch_ms: Arc, + no_healthy_warn_epoch_ms: Arc, stats: Arc, } @@ -262,10 +264,35 @@ impl UpstreamManager { connect_budget: Duration::from_millis(connect_budget_ms.max(1)), unhealthy_fail_threshold: unhealthy_fail_threshold.max(1), connect_failfast_hard_errors, + no_upstreams_warn_epoch_ms: Arc::new(AtomicU64::new(0)), + no_healthy_warn_epoch_ms: Arc::new(AtomicU64::new(0)), stats, } } + fn now_epoch_ms() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as u64 + } + + fn should_emit_warn(last_epoch_ms: &AtomicU64, cooldown_ms: u64) -> bool { + let now_epoch_ms = Self::now_epoch_ms(); + let previous_epoch_ms = last_epoch_ms.load(Ordering::Relaxed); + if now_epoch_ms.saturating_sub(previous_epoch_ms) < cooldown_ms { + return false; + } + last_epoch_ms + .compare_exchange( + previous_epoch_ms, + now_epoch_ms, + Ordering::AcqRel, + Ordering::Relaxed, + ) + .is_ok() + } + pub fn try_api_snapshot(&self) -> Option { let guard = self.upstreams.try_read().ok()?; let now = std::time::Instant::now(); @@ -533,12 +560,22 @@ impl UpstreamManager { .collect(); if filtered_upstreams.is_empty() { - warn!(scope = scope, "No upstreams available! Using first (direct?)"); + if Self::should_emit_warn( + self.no_upstreams_warn_epoch_ms.as_ref(), + 5_000, + ) { + warn!(scope = scope, "No upstreams available! Using first (direct?)"); + } return None; } if healthy.is_empty() { - warn!(scope = scope, "No healthy upstreams available! Using random."); + if Self::should_emit_warn( + self.no_healthy_warn_epoch_ms.as_ref(), + 5_000, + ) { + warn!(scope = scope, "No healthy upstreams available! Using random."); + } return Some(filtered_upstreams[rand::rng().gen_range(0..filtered_upstreams.len())]); }