diff --git a/src/transport/middle_proxy/health.rs b/src/transport/middle_proxy/health.rs index 4fcba39..b0536cc 100644 --- a/src/transport/middle_proxy/health.rs +++ b/src/transport/middle_proxy/health.rs @@ -147,7 +147,7 @@ async fn check_family( IpFamily::V6 => pool.proxy_map_v6.read().await, }; for (dc, addrs) in map_guard.iter() { - let entry = dc_endpoints.entry(dc.abs()).or_default(); + let entry = dc_endpoints.entry(*dc).or_default(); for (ip, port) in addrs.iter().copied() { entry.push(SocketAddr::new(ip, port)); } @@ -164,14 +164,15 @@ async fn check_family( adaptive_recover_until.clear(); } - let mut live_addr_counts = HashMap::::new(); - let mut live_writer_ids_by_addr = HashMap::>::new(); + let mut live_addr_counts = HashMap::<(i32, SocketAddr), usize>::new(); + let mut live_writer_ids_by_addr = HashMap::<(i32, SocketAddr), Vec>::new(); for writer in pool.writers.read().await.iter().filter(|w| { !w.draining.load(std::sync::atomic::Ordering::Relaxed) }) { - *live_addr_counts.entry(writer.addr).or_insert(0) += 1; + let key = (writer.writer_dc, writer.addr); + *live_addr_counts.entry(key).or_insert(0) += 1; live_writer_ids_by_addr - .entry(writer.addr) + .entry(key) .or_default() .push(writer.id); } @@ -211,7 +212,7 @@ async fn check_family( }); let alive = endpoints .iter() - .map(|addr| *live_addr_counts.get(addr).unwrap_or(&0)) + .map(|addr| *live_addr_counts.get(&(dc, *addr)).unwrap_or(&0)) .sum::(); if endpoints.len() == 1 && pool.single_endpoint_outage_mode_enabled() && alive == 0 { @@ -321,7 +322,10 @@ async fn check_family( if *inflight.get(&key).unwrap_or(&0) >= max_concurrent { continue; } - if pool.has_refill_inflight_for_endpoints(&endpoints).await { + if pool + .has_refill_inflight_for_dc_key(super::pool::RefillDcKey { dc, family }) + .await + { debug!( dc = %dc, ?family, @@ -373,7 +377,7 @@ async fn check_family( } let res = tokio::time::timeout( pool.me_one_timeout, - pool.connect_endpoints_round_robin(&endpoints, rng.as_ref()), + pool.connect_endpoints_round_robin(dc, &endpoints, rng.as_ref()), ) .await; match res { @@ -484,12 +488,13 @@ fn adaptive_floor_class_max( } fn list_writer_ids_for_endpoints( + dc: i32, endpoints: &[SocketAddr], - live_writer_ids_by_addr: &HashMap>, + live_writer_ids_by_addr: &HashMap<(i32, SocketAddr), Vec>, ) -> Vec { let mut out = Vec::::new(); for endpoint in endpoints { - if let Some(ids) = live_writer_ids_by_addr.get(endpoint) { + if let Some(ids) = live_writer_ids_by_addr.get(&(dc, *endpoint)) { out.extend(ids.iter().copied()); } } @@ -500,8 +505,8 @@ async fn build_family_floor_plan( pool: &Arc, family: IpFamily, dc_endpoints: &HashMap>, - live_addr_counts: &HashMap, - live_writer_ids_by_addr: &HashMap>, + live_addr_counts: &HashMap<(i32, SocketAddr), usize>, + live_writer_ids_by_addr: &HashMap<(i32, SocketAddr), Vec>, bound_clients_by_writer: &HashMap, adaptive_idle_since: &mut HashMap<(i32, IpFamily), Instant>, adaptive_recover_until: &mut HashMap<(i32, IpFamily), Instant>, @@ -522,6 +527,7 @@ async fn build_family_floor_plan( let reduce_for_idle = should_reduce_floor_for_idle( pool, key, + *dc, endpoints, live_writer_ids_by_addr, bound_clients_by_writer, @@ -551,10 +557,10 @@ async fn build_family_floor_plan( let target_required = desired_raw.clamp(min_required, max_required); let alive = endpoints .iter() - .map(|endpoint| live_addr_counts.get(endpoint).copied().unwrap_or(0)) + .map(|endpoint| live_addr_counts.get(&(*dc, *endpoint)).copied().unwrap_or(0)) .sum::(); family_active_total = family_active_total.saturating_add(alive); - let writer_ids = list_writer_ids_for_endpoints(endpoints, live_writer_ids_by_addr); + let writer_ids = list_writer_ids_for_endpoints(*dc, endpoints, live_writer_ids_by_addr); let has_bound_clients = has_bound_clients_on_endpoint(&writer_ids, bound_clients_by_writer); entries.push(DcFloorPlanEntry { @@ -654,14 +660,14 @@ async fn maybe_swap_idle_writer_for_cap( dc: i32, family: IpFamily, endpoints: &[SocketAddr], - live_writer_ids_by_addr: &HashMap>, + live_writer_ids_by_addr: &HashMap<(i32, SocketAddr), Vec>, writer_idle_since: &HashMap, bound_clients_by_writer: &HashMap, ) -> bool { let now_epoch_secs = MePool::now_epoch_secs(); let mut candidate: Option<(u64, SocketAddr, u64)> = None; for endpoint in endpoints { - let Some(writer_ids) = live_writer_ids_by_addr.get(endpoint) else { + let Some(writer_ids) = live_writer_ids_by_addr.get(&(dc, *endpoint)) else { continue; }; for writer_id in writer_ids { @@ -686,7 +692,12 @@ async fn maybe_swap_idle_writer_for_cap( return false; }; - let connected = match tokio::time::timeout(pool.me_one_timeout, pool.connect_one(endpoint, rng.as_ref())).await { + let connected = match tokio::time::timeout( + pool.me_one_timeout, + pool.connect_one_for_dc(endpoint, dc, rng.as_ref()), + ) + .await + { Ok(Ok(())) => true, Ok(Err(error)) => { debug!( @@ -738,7 +749,7 @@ async fn maybe_refresh_idle_writer_for_dc( endpoints: &[SocketAddr], alive: usize, required: usize, - live_writer_ids_by_addr: &HashMap>, + live_writer_ids_by_addr: &HashMap<(i32, SocketAddr), Vec>, writer_idle_since: &HashMap, bound_clients_by_writer: &HashMap, idle_refresh_next_attempt: &mut HashMap<(i32, IpFamily), Instant>, @@ -757,7 +768,7 @@ async fn maybe_refresh_idle_writer_for_dc( let now_epoch_secs = MePool::now_epoch_secs(); let mut candidate: Option<(u64, SocketAddr, u64, u64)> = None; for endpoint in endpoints { - let Some(writer_ids) = live_writer_ids_by_addr.get(endpoint) else { + let Some(writer_ids) = live_writer_ids_by_addr.get(&(dc, *endpoint)) else { continue; }; for writer_id in writer_ids { @@ -787,7 +798,12 @@ async fn maybe_refresh_idle_writer_for_dc( return; }; - let rotate_ok = match tokio::time::timeout(pool.me_one_timeout, pool.connect_one(endpoint, rng.as_ref())).await { + let rotate_ok = match tokio::time::timeout( + pool.me_one_timeout, + pool.connect_one_for_dc(endpoint, dc, rng.as_ref()), + ) + .await + { Ok(Ok(())) => true, Ok(Err(error)) => { debug!( @@ -843,8 +859,9 @@ async fn maybe_refresh_idle_writer_for_dc( async fn should_reduce_floor_for_idle( pool: &Arc, key: (i32, IpFamily), + dc: i32, endpoints: &[SocketAddr], - live_writer_ids_by_addr: &HashMap>, + live_writer_ids_by_addr: &HashMap<(i32, SocketAddr), Vec>, bound_clients_by_writer: &HashMap, adaptive_idle_since: &mut HashMap<(i32, IpFamily), Instant>, adaptive_recover_until: &mut HashMap<(i32, IpFamily), Instant>, @@ -856,7 +873,7 @@ async fn should_reduce_floor_for_idle( } let now = Instant::now(); - let writer_ids = list_writer_ids_for_endpoints(endpoints, live_writer_ids_by_addr); + let writer_ids = list_writer_ids_for_endpoints(dc, endpoints, live_writer_ids_by_addr); let has_bound_clients = has_bound_clients_on_endpoint(&writer_ids, bound_clients_by_writer); if has_bound_clients { adaptive_idle_since.remove(&key); @@ -922,7 +939,12 @@ async fn recover_single_endpoint_outage( let attempt_ok = if bypass_quarantine { pool.stats .increment_me_single_endpoint_quarantine_bypass_total(); - match tokio::time::timeout(pool.me_one_timeout, pool.connect_one(endpoint, rng.as_ref())).await { + match tokio::time::timeout( + pool.me_one_timeout, + pool.connect_one_for_dc(endpoint, key.0, rng.as_ref()), + ) + .await + { Ok(Ok(())) => true, Ok(Err(e)) => { debug!( @@ -948,7 +970,7 @@ async fn recover_single_endpoint_outage( let one_endpoint = [endpoint]; match tokio::time::timeout( pool.me_one_timeout, - pool.connect_endpoints_round_robin(&one_endpoint, rng.as_ref()), + pool.connect_endpoints_round_robin(key.0, &one_endpoint, rng.as_ref()), ) .await { @@ -1012,7 +1034,7 @@ async fn maybe_rotate_single_endpoint_shadow( endpoints: &[SocketAddr], alive: usize, required: usize, - live_writer_ids_by_addr: &HashMap>, + live_writer_ids_by_addr: &HashMap<(i32, SocketAddr), Vec>, bound_clients_by_writer: &HashMap, shadow_rotate_deadline: &mut HashMap<(i32, IpFamily), Instant>, ) { @@ -1045,7 +1067,7 @@ async fn maybe_rotate_single_endpoint_shadow( return; } - let Some(writer_ids) = live_writer_ids_by_addr.get(&endpoint) else { + let Some(writer_ids) = live_writer_ids_by_addr.get(&(dc, endpoint)) else { shadow_rotate_deadline.insert(key, now + Duration::from_secs(SHADOW_ROTATE_RETRY_SECS)); return; }; @@ -1071,7 +1093,12 @@ async fn maybe_rotate_single_endpoint_shadow( return; }; - let rotate_ok = match tokio::time::timeout(pool.me_one_timeout, pool.connect_one(endpoint, rng.as_ref())).await { + let rotate_ok = match tokio::time::timeout( + pool.me_one_timeout, + pool.connect_one_for_dc(endpoint, dc, rng.as_ref()), + ) + .await + { Ok(Ok(())) => true, Ok(Err(e)) => { debug!( diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 236a12a..13259bb 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -819,18 +819,6 @@ impl MePool { (self.default_dc_for_routing(), true) } - pub(super) fn dc_lookup_chain_for_target(&self, target_dc: i32) -> Vec { - let mut out = Vec::with_capacity(1); - if target_dc != 0 { - out.push(target_dc); - } else { - // Use default DC only when target DC is unknown and pinning is not established. - let fallback_dc = self.default_dc_for_routing(); - out.push(fallback_dc); - } - out - } - pub(super) async fn resolve_dc_for_endpoint(&self, addr: SocketAddr) -> i32 { if let Some(cached) = self.endpoint_dc_map.read().await.get(&addr).copied() && let Some(dc) = cached diff --git a/src/transport/middle_proxy/pool_config.rs b/src/transport/middle_proxy/pool_config.rs index a43f9bf..66752bf 100644 --- a/src/transport/middle_proxy/pool_config.rs +++ b/src/transport/middle_proxy/pool_config.rs @@ -110,7 +110,10 @@ impl MePool { pub async fn reconnect_all(self: &Arc) { let ws = self.writers.read().await.clone(); for w in ws { - if let Ok(()) = self.connect_one(w.addr, self.rng.as_ref()).await { + if let Ok(()) = self + .connect_one_for_dc(w.addr, w.writer_dc, self.rng.as_ref()) + .await + { self.mark_writer_draining(w.id).await; tokio::time::sleep(Duration::from_secs(2)).await; } diff --git a/src/transport/middle_proxy/pool_refill.rs b/src/transport/middle_proxy/pool_refill.rs index 3c8b0bb..316f3ff 100644 --- a/src/transport/middle_proxy/pool_refill.rs +++ b/src/transport/middle_proxy/pool_refill.rs @@ -218,14 +218,6 @@ impl MePool { false } - pub(crate) fn trigger_immediate_refill(self: &Arc, addr: SocketAddr) { - let pool = Arc::clone(self); - tokio::spawn(async move { - let writer_dc = pool.resolve_dc_for_endpoint(addr).await; - pool.trigger_immediate_refill_for_dc(addr, writer_dc); - }); - } - pub(crate) fn trigger_immediate_refill_for_dc(self: &Arc, addr: SocketAddr, writer_dc: i32) { let endpoint_key = RefillEndpointKey { dc: writer_dc, @@ -243,7 +235,6 @@ impl MePool { let pool = Arc::clone(self); tokio::spawn(async move { - let dc_endpoints = pool.endpoints_for_dc(writer_dc).await; let dc_key = RefillDcKey { dc: writer_dc, family: if addr.is_ipv4() { diff --git a/src/transport/middle_proxy/pool_reinit.rs b/src/transport/middle_proxy/pool_reinit.rs index 39944ba..625ccf0 100644 --- a/src/transport/middle_proxy/pool_reinit.rs +++ b/src/transport/middle_proxy/pool_reinit.rs @@ -62,7 +62,7 @@ impl MePool { fn coverage_ratio( desired_by_dc: &HashMap>, - active_writer_addrs: &HashSet, + active_writer_addrs: &HashSet<(i32, SocketAddr)>, ) -> (f32, Vec) { if desired_by_dc.is_empty() { return (1.0, Vec::new()); @@ -76,7 +76,7 @@ impl MePool { } if endpoints .iter() - .any(|addr| active_writer_addrs.contains(addr)) + .any(|addr| active_writer_addrs.contains(&(*dc, *addr))) { covered += 1; } else { @@ -91,32 +91,25 @@ impl MePool { } pub async fn reconcile_connections(self: &Arc, rng: &SecureRandom) { - let writers = self.writers.read().await; - let current: HashSet = writers - .iter() - .filter(|w| !w.draining.load(Ordering::Relaxed)) - .map(|w| w.addr) - .collect(); - drop(writers); - for family in self.family_order() { let map = self.proxy_map_for_family(family).await; - for (_dc, addrs) in &map { + for (dc, addrs) in &map { let dc_addrs: Vec = addrs .iter() .map(|(ip, port)| SocketAddr::new(*ip, *port)) .collect(); - if !dc_addrs.iter().any(|a| current.contains(a)) { + let dc_endpoints: HashSet = dc_addrs.iter().copied().collect(); + if self.active_writer_count_for_dc_endpoints(*dc, &dc_endpoints).await == 0 { let mut shuffled = dc_addrs.clone(); shuffled.shuffle(&mut rand::rng()); for addr in shuffled { - if self.connect_one(addr, rng).await.is_ok() { + if self.connect_one_for_dc(addr, *dc, rng).await.is_ok() { break; } } } } - if !self.decision.effective_multipath && !current.is_empty() { + if !self.decision.effective_multipath && self.connection_count() > 0 { break; } } @@ -174,26 +167,30 @@ impl MePool { core.saturating_add(rand::rng().random_range(0..=jitter)) } - async fn fresh_writer_count_for_endpoints( + async fn fresh_writer_count_for_dc_endpoints( &self, generation: u64, + dc: i32, endpoints: &HashSet, ) -> usize { let ws = self.writers.read().await; ws.iter() .filter(|w| !w.draining.load(Ordering::Relaxed)) .filter(|w| w.generation == generation) + .filter(|w| w.writer_dc == dc) .filter(|w| endpoints.contains(&w.addr)) .count() } - pub(super) async fn active_writer_count_for_endpoints( + pub(super) async fn active_writer_count_for_dc_endpoints( &self, + dc: i32, endpoints: &HashSet, ) -> usize { let ws = self.writers.read().await; ws.iter() .filter(|w| !w.draining.load(Ordering::Relaxed)) + .filter(|w| w.writer_dc == dc) .filter(|w| endpoints.contains(&w.addr)) .count() } @@ -220,7 +217,7 @@ impl MePool { let required = self.required_writers_for_dc(endpoint_list.len()); let mut completed = false; let mut last_fresh_count = self - .fresh_writer_count_for_endpoints(generation, endpoints) + .fresh_writer_count_for_dc_endpoints(generation, *dc, endpoints) .await; for pass_idx in 0..total_passes { @@ -247,6 +244,7 @@ impl MePool { let connected = self .connect_endpoints_round_robin_with_generation_contour( + *dc, &endpoint_list, rng, generation, @@ -265,7 +263,7 @@ impl MePool { } last_fresh_count = self - .fresh_writer_count_for_endpoints(generation, endpoints) + .fresh_writer_count_for_dc_endpoints(generation, *dc, endpoints) .await; if last_fresh_count >= required { completed = true; @@ -377,10 +375,10 @@ impl MePool { } let writers = self.writers.read().await; - let active_writer_addrs: HashSet = writers + let active_writer_addrs: HashSet<(i32, SocketAddr)> = writers .iter() .filter(|w| !w.draining.load(Ordering::Relaxed)) - .map(|w| w.addr) + .map(|w| (w.writer_dc, w.addr)) .collect(); let min_ratio = Self::permille_to_ratio( self.me_pool_min_fresh_ratio_permille @@ -410,6 +408,7 @@ impl MePool { .iter() .filter(|w| !w.draining.load(Ordering::Relaxed)) .filter(|w| w.generation == generation) + .filter(|w| w.writer_dc == *dc) .filter(|w| endpoints.contains(&w.addr)) .count(); if fresh_count < required { @@ -438,9 +437,9 @@ impl MePool { self.promote_warm_generation_to_active(generation).await; } - let desired_addrs: HashSet = desired_by_dc - .values() - .flat_map(|set| set.iter().copied()) + let desired_addrs: HashSet<(i32, SocketAddr)> = desired_by_dc + .iter() + .flat_map(|(dc, set)| set.iter().copied().map(|addr| (*dc, addr))) .collect(); let stale_writer_ids: Vec = writers @@ -450,7 +449,7 @@ impl MePool { if hardswap { w.generation < generation } else { - !desired_addrs.contains(&w.addr) + !desired_addrs.contains(&(w.writer_dc, w.addr)) } }) .map(|w| w.id) diff --git a/src/transport/middle_proxy/pool_status.rs b/src/transport/middle_proxy/pool_status.rs index 46346b5..2922ed8 100644 --- a/src/transport/middle_proxy/pool_status.rs +++ b/src/transport/middle_proxy/pool_status.rs @@ -130,19 +130,18 @@ impl MePool { } let writers = self.writers.read().await.clone(); - let mut live_writers_by_endpoint = HashMap::::new(); + let mut live_writers_by_dc = HashMap::::new(); for writer in writers { if writer.draining.load(Ordering::Relaxed) { continue; } - *live_writers_by_endpoint.entry(writer.addr).or_insert(0) += 1; + if let Ok(dc) = i16::try_from(writer.writer_dc) { + *live_writers_by_dc.entry(dc).or_insert(0) += 1; + } } - for endpoints in endpoints_by_dc.values() { - let alive: usize = endpoints - .iter() - .map(|endpoint| live_writers_by_endpoint.get(endpoint).copied().unwrap_or(0)) - .sum(); + for dc in endpoints_by_dc.keys() { + let alive = live_writers_by_dc.get(dc).copied().unwrap_or(0); if alive == 0 { return false; } @@ -168,24 +167,23 @@ impl MePool { } let writers = self.writers.read().await.clone(); - let mut live_writers_by_endpoint = HashMap::::new(); + let mut live_writers_by_dc = HashMap::::new(); for writer in writers { if writer.draining.load(Ordering::Relaxed) { continue; } - *live_writers_by_endpoint.entry(writer.addr).or_insert(0) += 1; + if let Ok(dc) = i16::try_from(writer.writer_dc) { + *live_writers_by_dc.entry(dc).or_insert(0) += 1; + } } - for endpoints in endpoints_by_dc.values() { + for (dc, endpoints) in endpoints_by_dc { let endpoint_count = endpoints.len(); if endpoint_count == 0 { return false; } let required = self.required_writers_for_dc_with_floor_mode(endpoint_count, false); - let alive: usize = endpoints - .iter() - .map(|endpoint| live_writers_by_endpoint.get(endpoint).copied().unwrap_or(0)) - .sum(); + let alive = live_writers_by_dc.get(&dc).copied().unwrap_or(0); if alive < required { return false; } @@ -207,13 +205,6 @@ impl MePool { extend_signed_endpoints(&mut endpoints_by_dc, map); } - let mut endpoint_to_dc = HashMap::>::new(); - for (dc, endpoints) in &endpoints_by_dc { - for endpoint in endpoints { - endpoint_to_dc.entry(*endpoint).or_default().insert(*dc); - } - } - let configured_dc_groups = endpoints_by_dc.len(); let configured_endpoints = endpoints_by_dc.values().map(BTreeSet::len).sum(); @@ -227,20 +218,14 @@ impl MePool { let rtt = self.rtt_stats.lock().await.clone(); let writers = self.writers.read().await.clone(); - let mut live_writers_by_endpoint = HashMap::::new(); + let mut live_writers_by_dc_endpoint = HashMap::<(i16, SocketAddr), usize>::new(); let mut live_writers_by_dc = HashMap::::new(); let mut dc_rtt_agg = HashMap::::new(); let mut writer_rows = Vec::::with_capacity(writers.len()); for writer in writers { let endpoint = writer.addr; - let dc = endpoint_to_dc.get(&endpoint).and_then(|dcs| { - if dcs.len() == 1 { - dcs.iter().next().copied() - } else { - None - } - }); + let dc = i16::try_from(writer.writer_dc).ok(); let draining = writer.draining.load(Ordering::Relaxed); let degraded = writer.degraded.load(Ordering::Relaxed); let bound_clients = activity @@ -259,8 +244,10 @@ impl MePool { }; if !draining { - *live_writers_by_endpoint.entry(endpoint).or_insert(0) += 1; if let Some(dc_idx) = dc { + *live_writers_by_dc_endpoint + .entry((dc_idx, endpoint)) + .or_insert(0) += 1; *live_writers_by_dc.entry(dc_idx).or_insert(0) += 1; if let Some(ema_ms) = rtt_ema_ms { let entry = dc_rtt_agg.entry(dc_idx).or_insert((0.0, 0)); @@ -298,7 +285,7 @@ impl MePool { let endpoint_count = endpoints.len(); let dc_available_endpoints = endpoints .iter() - .filter(|endpoint| live_writers_by_endpoint.contains_key(endpoint)) + .filter(|endpoint| live_writers_by_dc_endpoint.contains_key(&(dc, **endpoint))) .count(); let base_required = self.required_writers_for_dc(endpoint_count); let dc_required_writers = diff --git a/src/transport/middle_proxy/pool_writer.rs b/src/transport/middle_proxy/pool_writer.rs index 17ef331..244a08e 100644 --- a/src/transport/middle_proxy/pool_writer.rs +++ b/src/transport/middle_proxy/pool_writer.rs @@ -50,11 +50,6 @@ impl MePool { } } - pub(crate) async fn connect_one(self: &Arc, addr: SocketAddr, rng: &SecureRandom) -> Result<()> { - let writer_dc = self.resolve_dc_for_endpoint(addr).await; - self.connect_one_for_dc(addr, writer_dc, rng).await - } - pub(crate) async fn connect_one_for_dc( self: &Arc, addr: SocketAddr, diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index ccaad4a..07d39f6 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -54,6 +54,9 @@ impl MePool { }; let no_writer_mode = MeRouteNoWriterMode::from_u8(self.me_route_no_writer_mode.load(Ordering::Relaxed)); + let (routed_dc, unknown_target_dc) = self + .resolve_target_dc_for_routing(target_dc as i32) + .await; let mut no_writer_deadline: Option = None; let mut emergency_attempts = 0u32; let mut async_recovery_triggered = false; @@ -91,9 +94,9 @@ impl MePool { let deadline = *no_writer_deadline.get_or_insert_with(|| { Instant::now() + self.me_route_no_writer_wait }); - if !async_recovery_triggered { + if !async_recovery_triggered && !unknown_target_dc { let triggered = - self.trigger_async_recovery_for_target_dc(target_dc).await; + self.trigger_async_recovery_for_target_dc(routed_dc).await; if !triggered { self.trigger_async_recovery_global().await; } @@ -109,31 +112,34 @@ impl MePool { } MeRouteNoWriterMode::InlineRecoveryLegacy => { self.stats.increment_me_inline_recovery_total(); - for _ in 0..self.me_route_inline_recovery_attempts.max(1) { - for family in self.family_order() { - let map = match family { - IpFamily::V4 => self.proxy_map_v4.read().await.clone(), - IpFamily::V6 => self.proxy_map_v6.read().await.clone(), - }; - for (_dc, addrs) in &map { - for (ip, port) in addrs { - let addr = SocketAddr::new(*ip, *port); - let _ = self.connect_one(addr, self.rng.as_ref()).await; + if !unknown_target_dc { + for _ in 0..self.me_route_inline_recovery_attempts.max(1) { + for family in self.family_order() { + let map = match family { + IpFamily::V4 => self.proxy_map_v4.read().await.clone(), + IpFamily::V6 => self.proxy_map_v6.read().await.clone(), + }; + for (dc, addrs) in &map { + for (ip, port) in addrs { + let addr = SocketAddr::new(*ip, *port); + let _ = self + .connect_one_for_dc(addr, *dc, self.rng.as_ref()) + .await; + } } } - } - if !self.writers.read().await.is_empty() { - break; + if !self.writers.read().await.is_empty() { + break; + } } } + if !self.writers.read().await.is_empty() { continue; } - let waiter = self.writer_available.notified(); - if tokio::time::timeout(self.me_route_inline_recovery_wait, waiter) - .await - .is_err() - { + let deadline = *no_writer_deadline + .get_or_insert_with(|| Instant::now() + self.me_route_inline_recovery_wait); + if !self.wait_for_writer_until(deadline).await { if !self.writers.read().await.is_empty() { continue; } @@ -145,13 +151,15 @@ impl MePool { continue; } MeRouteNoWriterMode::HybridAsyncPersistent => { - self.maybe_trigger_hybrid_recovery( - target_dc, - &mut hybrid_recovery_round, - &mut hybrid_last_recovery_at, - hybrid_wait_current, - ) - .await; + if !unknown_target_dc { + self.maybe_trigger_hybrid_recovery( + routed_dc, + &mut hybrid_recovery_round, + &mut hybrid_last_recovery_at, + hybrid_wait_current, + ) + .await; + } let deadline = Instant::now() + hybrid_wait_current; let _ = self.wait_for_writer_until(deadline).await; hybrid_wait_current = @@ -165,11 +173,11 @@ impl MePool { }; let mut candidate_indices = self - .candidate_indices_for_dc(&writers_snapshot, target_dc, false) + .candidate_indices_for_dc(&writers_snapshot, routed_dc, false) .await; if candidate_indices.is_empty() { candidate_indices = self - .candidate_indices_for_dc(&writers_snapshot, target_dc, true) + .candidate_indices_for_dc(&writers_snapshot, routed_dc, true) .await; } if candidate_indices.is_empty() { @@ -178,14 +186,14 @@ impl MePool { let deadline = *no_writer_deadline.get_or_insert_with(|| { Instant::now() + self.me_route_no_writer_wait }); - if !async_recovery_triggered { - let triggered = self.trigger_async_recovery_for_target_dc(target_dc).await; + if !async_recovery_triggered && !unknown_target_dc { + let triggered = self.trigger_async_recovery_for_target_dc(routed_dc).await; if !triggered { self.trigger_async_recovery_global().await; } async_recovery_triggered = true; } - if self.wait_for_candidate_until(target_dc, deadline).await { + if self.wait_for_candidate_until(routed_dc, deadline).await { continue; } self.stats.increment_me_no_writer_failfast_total(); @@ -195,15 +203,24 @@ impl MePool { } MeRouteNoWriterMode::InlineRecoveryLegacy => { self.stats.increment_me_inline_recovery_total(); + if unknown_target_dc { + let deadline = *no_writer_deadline + .get_or_insert_with(|| Instant::now() + self.me_route_inline_recovery_wait); + if self.wait_for_candidate_until(routed_dc, deadline).await { + continue; + } + self.stats.increment_me_no_writer_failfast_total(); + return Err(ProxyError::Proxy("No ME writers available for target DC".into())); + } if emergency_attempts >= self.me_route_inline_recovery_attempts.max(1) { self.stats.increment_me_no_writer_failfast_total(); return Err(ProxyError::Proxy("No ME writers available for target DC".into())); } emergency_attempts += 1; - let mut endpoints = self.endpoint_candidates_for_target_dc(target_dc).await; + let mut endpoints = self.endpoint_candidates_for_target_dc(routed_dc).await; endpoints.shuffle(&mut rand::rng()); for addr in endpoints { - if self.connect_one(addr, self.rng.as_ref()).await.is_ok() { + if self.connect_one_for_dc(addr, routed_dc, self.rng.as_ref()).await.is_ok() { break; } } @@ -212,11 +229,11 @@ impl MePool { writers_snapshot = ws2.clone(); drop(ws2); candidate_indices = self - .candidate_indices_for_dc(&writers_snapshot, target_dc, false) + .candidate_indices_for_dc(&writers_snapshot, routed_dc, false) .await; if candidate_indices.is_empty() { candidate_indices = self - .candidate_indices_for_dc(&writers_snapshot, target_dc, true) + .candidate_indices_for_dc(&writers_snapshot, routed_dc, true) .await; } if candidate_indices.is_empty() { @@ -224,15 +241,17 @@ impl MePool { } } MeRouteNoWriterMode::HybridAsyncPersistent => { - self.maybe_trigger_hybrid_recovery( - target_dc, - &mut hybrid_recovery_round, - &mut hybrid_last_recovery_at, - hybrid_wait_current, - ) - .await; + if !unknown_target_dc { + self.maybe_trigger_hybrid_recovery( + routed_dc, + &mut hybrid_recovery_round, + &mut hybrid_last_recovery_at, + hybrid_wait_current, + ) + .await; + } let deadline = Instant::now() + hybrid_wait_current; - let _ = self.wait_for_candidate_until(target_dc, deadline).await; + let _ = self.wait_for_candidate_until(routed_dc, deadline).await; hybrid_wait_current = (hybrid_wait_current.saturating_mul(2)) .min(Duration::from_millis(400)); continue; @@ -382,32 +401,32 @@ impl MePool { !self.writers.read().await.is_empty() } - async fn wait_for_candidate_until(&self, target_dc: i16, deadline: Instant) -> bool { + async fn wait_for_candidate_until(&self, routed_dc: i32, deadline: Instant) -> bool { loop { - if self.has_candidate_for_target_dc(target_dc).await { + if self.has_candidate_for_target_dc(routed_dc).await { return true; } let now = Instant::now(); if now >= deadline { - return self.has_candidate_for_target_dc(target_dc).await; + return self.has_candidate_for_target_dc(routed_dc).await; } let waiter = self.writer_available.notified(); - if self.has_candidate_for_target_dc(target_dc).await { + if self.has_candidate_for_target_dc(routed_dc).await { return true; } let remaining = deadline.saturating_duration_since(Instant::now()); if remaining.is_zero() { - return self.has_candidate_for_target_dc(target_dc).await; + return self.has_candidate_for_target_dc(routed_dc).await; } if tokio::time::timeout(remaining, waiter).await.is_err() { - return self.has_candidate_for_target_dc(target_dc).await; + return self.has_candidate_for_target_dc(routed_dc).await; } } } - async fn has_candidate_for_target_dc(&self, target_dc: i16) -> bool { + async fn has_candidate_for_target_dc(&self, routed_dc: i32) -> bool { let writers_snapshot = { let ws = self.writers.read().await; if ws.is_empty() { @@ -416,41 +435,41 @@ impl MePool { ws.clone() }; let mut candidate_indices = self - .candidate_indices_for_dc(&writers_snapshot, target_dc, false) + .candidate_indices_for_dc(&writers_snapshot, routed_dc, false) .await; if candidate_indices.is_empty() { candidate_indices = self - .candidate_indices_for_dc(&writers_snapshot, target_dc, true) + .candidate_indices_for_dc(&writers_snapshot, routed_dc, true) .await; } !candidate_indices.is_empty() } - async fn trigger_async_recovery_for_target_dc(self: &Arc, target_dc: i16) -> bool { - let endpoints = self.endpoint_candidates_for_target_dc(target_dc).await; + async fn trigger_async_recovery_for_target_dc(self: &Arc, routed_dc: i32) -> bool { + let endpoints = self.endpoint_candidates_for_target_dc(routed_dc).await; if endpoints.is_empty() { return false; } self.stats.increment_me_async_recovery_trigger_total(); for addr in endpoints.into_iter().take(8) { - self.trigger_immediate_refill(addr); + self.trigger_immediate_refill_for_dc(addr, routed_dc); } true } async fn trigger_async_recovery_global(self: &Arc) { self.stats.increment_me_async_recovery_trigger_total(); - let mut seen = HashSet::::new(); + let mut seen = HashSet::<(i32, SocketAddr)>::new(); for family in self.family_order() { let map_guard = match family { IpFamily::V4 => self.proxy_map_v4.read().await, IpFamily::V6 => self.proxy_map_v6.read().await, }; - for addrs in map_guard.values() { + for (dc, addrs) in map_guard.iter() { for (ip, port) in addrs { let addr = SocketAddr::new(*ip, *port); - if seen.insert(addr) { - self.trigger_immediate_refill(addr); + if seen.insert((*dc, addr)) { + self.trigger_immediate_refill_for_dc(addr, *dc); } if seen.len() >= 8 { return; @@ -460,11 +479,9 @@ impl MePool { } } - async fn endpoint_candidates_for_target_dc(&self, target_dc: i16) -> Vec { - let key = target_dc as i32; + async fn endpoint_candidates_for_target_dc(&self, routed_dc: i32) -> Vec { let mut preferred = Vec::::new(); let mut seen = HashSet::::new(); - let lookup_keys = self.dc_lookup_chain_for_target(key); for family in self.family_order() { let map_guard = match family { @@ -472,14 +489,9 @@ impl MePool { IpFamily::V6 => self.proxy_map_v6.read().await, }; let mut family_selected = Vec::::new(); - for lookup in lookup_keys.iter().copied() { - if let Some(addrs) = map_guard.get(&lookup) { - for (ip, port) in addrs { - family_selected.push(SocketAddr::new(*ip, *port)); - } - } - if !family_selected.is_empty() { - break; + if let Some(addrs) = map_guard.get(&routed_dc) { + for (ip, port) in addrs { + family_selected.push(SocketAddr::new(*ip, *port)); } } for addr in family_selected { @@ -497,7 +509,7 @@ impl MePool { async fn maybe_trigger_hybrid_recovery( self: &Arc, - target_dc: i16, + routed_dc: i32, hybrid_recovery_round: &mut u32, hybrid_last_recovery_at: &mut Option, hybrid_wait_step: Duration, @@ -509,7 +521,7 @@ impl MePool { } let round = *hybrid_recovery_round; - let target_triggered = self.trigger_async_recovery_for_target_dc(target_dc).await; + let target_triggered = self.trigger_async_recovery_for_target_dc(routed_dc).await; if !target_triggered || round % HYBRID_GLOBAL_BURST_PERIOD_ROUNDS == 0 { self.trigger_async_recovery_global().await; } @@ -576,12 +588,10 @@ impl MePool { pub(super) async fn candidate_indices_for_dc( &self, writers: &[super::pool::MeWriter], - target_dc: i16, + routed_dc: i32, include_warm: bool, ) -> Vec { - let key = target_dc as i32; let mut preferred = HashSet::::new(); - let lookup_keys = self.dc_lookup_chain_for_target(key); for family in self.family_order() { let map_guard = match family { @@ -589,13 +599,8 @@ impl MePool { IpFamily::V6 => self.proxy_map_v6.read().await, }; let mut family_selected = Vec::::new(); - for lookup in lookup_keys.iter().copied() { - if let Some(v) = map_guard.get(&lookup) { - family_selected.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port))); - } - if !family_selected.is_empty() { - break; - } + if let Some(v) = map_guard.get(&routed_dc) { + family_selected.extend(v.iter().map(|(ip, port)| SocketAddr::new(*ip, *port))); } for endpoint in family_selected { preferred.insert(endpoint); @@ -617,7 +622,7 @@ impl MePool { if !self.writer_eligible_for_selection(w, include_warm) { continue; } - if preferred.contains(&w.addr) { + if w.writer_dc == routed_dc && preferred.contains(&w.addr) { out.push(idx); } }