diff --git a/src/transport/middle_proxy/pool_init.rs b/src/transport/middle_proxy/pool_init.rs index fbb5c64..668cfda 100644 --- a/src/transport/middle_proxy/pool_init.rs +++ b/src/transport/middle_proxy/pool_init.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; @@ -27,20 +27,14 @@ impl MePool { for family in family_order { let map = self.proxy_map_for_family(family).await; - let mut grouped_dc_addrs: HashMap> = HashMap::new(); - for (dc, addrs) in map { - if addrs.is_empty() { - continue; - } - grouped_dc_addrs.entry(dc.abs()).or_default().extend(addrs); - } - let mut dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = grouped_dc_addrs + let mut dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = map .into_iter() .map(|(dc, mut addrs)| { addrs.sort_unstable(); addrs.dedup(); (dc, addrs) }) + .filter(|(_, addrs)| !addrs.is_empty()) .collect(); dc_addrs.sort_unstable_by_key(|(dc, _)| *dc); dc_addrs.sort_by_key(|(_, addrs)| (addrs.len() != 1, addrs.len())); diff --git a/src/transport/middle_proxy/pool_refill.rs b/src/transport/middle_proxy/pool_refill.rs index 6e14617..87b87d5 100644 --- a/src/transport/middle_proxy/pool_refill.rs +++ b/src/transport/middle_proxy/pool_refill.rs @@ -108,19 +108,10 @@ impl MePool { } else { IpFamily::V6 }; - let map = self.proxy_map_for_family(family).await; - for (dc, endpoints) in map { - if endpoints - .into_iter() - .any(|(ip, port)| SocketAddr::new(ip, port) == addr) - { - return Some(RefillDcKey { - dc: dc.abs(), - family, - }); - } - } - None + Some(RefillDcKey { + dc: self.resolve_dc_for_endpoint(addr).await, + family, + }) } async fn resolve_refill_dc_keys_for_endpoints( @@ -177,47 +168,23 @@ impl MePool { } async fn endpoints_for_same_dc(&self, addr: SocketAddr) -> Vec { - let mut target_dc = HashSet::::new(); let mut endpoints = HashSet::::new(); + let target_dc = self.resolve_dc_for_endpoint(addr).await; if self.decision.ipv4_me { let map = self.proxy_map_v4.read().await.clone(); - for (dc, addrs) in &map { - if addrs - .iter() - .any(|(ip, port)| SocketAddr::new(*ip, *port) == addr) - { - target_dc.insert(dc.abs()); - } - } - for dc in &target_dc { - for key in [*dc, -*dc] { - if let Some(addrs) = map.get(&key) { - for (ip, port) in addrs { - endpoints.insert(SocketAddr::new(*ip, *port)); - } - } + if let Some(addrs) = map.get(&target_dc) { + for (ip, port) in addrs { + endpoints.insert(SocketAddr::new(*ip, *port)); } } } if self.decision.ipv6_me { let map = self.proxy_map_v6.read().await.clone(); - for (dc, addrs) in &map { - if addrs - .iter() - .any(|(ip, port)| SocketAddr::new(*ip, *port) == addr) - { - target_dc.insert(dc.abs()); - } - } - for dc in &target_dc { - for key in [*dc, -*dc] { - if let Some(addrs) = map.get(&key) { - for (ip, port) in addrs { - endpoints.insert(SocketAddr::new(*ip, *port)); - } - } + if let Some(addrs) = map.get(&target_dc) { + for (ip, port) in addrs { + endpoints.insert(SocketAddr::new(*ip, *port)); } } } diff --git a/src/transport/middle_proxy/pool_reinit.rs b/src/transport/middle_proxy/pool_reinit.rs index d5242b7..39944ba 100644 --- a/src/transport/middle_proxy/pool_reinit.rs +++ b/src/transport/middle_proxy/pool_reinit.rs @@ -128,7 +128,7 @@ impl MePool { if self.decision.ipv4_me { let map_v4 = self.proxy_map_v4.read().await.clone(); for (dc, addrs) in map_v4 { - let entry = out.entry(dc.abs()).or_default(); + let entry = out.entry(dc).or_default(); for (ip, port) in addrs { entry.insert(SocketAddr::new(ip, port)); } @@ -138,7 +138,7 @@ impl MePool { if self.decision.ipv6_me { let map_v6 = self.proxy_map_v6.read().await.clone(); for (dc, addrs) in map_v6 { - let entry = out.entry(dc.abs()).or_default(); + let entry = out.entry(dc).or_default(); for (ip, port) in addrs { entry.insert(SocketAddr::new(ip, port)); } diff --git a/src/transport/middle_proxy/pool_status.rs b/src/transport/middle_proxy/pool_status.rs index 17a418c..d9898b1 100644 --- a/src/transport/middle_proxy/pool_status.rs +++ b/src/transport/middle_proxy/pool_status.rs @@ -1,5 +1,5 @@ use std::collections::{BTreeMap, BTreeSet, HashMap}; -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::Ordering; use std::time::Instant; @@ -104,35 +104,11 @@ impl MePool { let mut endpoints_by_dc = BTreeMap::>::new(); if self.decision.ipv4_me { let map = self.proxy_map_v4.read().await.clone(); - for (dc, addrs) in map { - let abs_dc = dc.abs(); - if abs_dc == 0 { - continue; - } - let Ok(dc_idx) = i16::try_from(abs_dc) else { - continue; - }; - let entry = endpoints_by_dc.entry(dc_idx).or_default(); - for (ip, port) in addrs { - entry.insert(SocketAddr::new(ip, port)); - } - } + extend_signed_endpoints(&mut endpoints_by_dc, map); } if self.decision.ipv6_me { let map = self.proxy_map_v6.read().await.clone(); - for (dc, addrs) in map { - let abs_dc = dc.abs(); - if abs_dc == 0 { - continue; - } - let Ok(dc_idx) = i16::try_from(abs_dc) else { - continue; - }; - let entry = endpoints_by_dc.entry(dc_idx).or_default(); - for (ip, port) in addrs { - entry.insert(SocketAddr::new(ip, port)); - } - } + extend_signed_endpoints(&mut endpoints_by_dc, map); } if endpoints_by_dc.is_empty() { @@ -166,35 +142,11 @@ impl MePool { let mut endpoints_by_dc = BTreeMap::>::new(); if self.decision.ipv4_me { let map = self.proxy_map_v4.read().await.clone(); - for (dc, addrs) in map { - let abs_dc = dc.abs(); - if abs_dc == 0 { - continue; - } - let Ok(dc_idx) = i16::try_from(abs_dc) else { - continue; - }; - let entry = endpoints_by_dc.entry(dc_idx).or_default(); - for (ip, port) in addrs { - entry.insert(SocketAddr::new(ip, port)); - } - } + extend_signed_endpoints(&mut endpoints_by_dc, map); } if self.decision.ipv6_me { let map = self.proxy_map_v6.read().await.clone(); - for (dc, addrs) in map { - let abs_dc = dc.abs(); - if abs_dc == 0 { - continue; - } - let Ok(dc_idx) = i16::try_from(abs_dc) else { - continue; - }; - let entry = endpoints_by_dc.entry(dc_idx).or_default(); - for (ip, port) in addrs { - entry.insert(SocketAddr::new(ip, port)); - } - } + extend_signed_endpoints(&mut endpoints_by_dc, map); } if endpoints_by_dc.is_empty() { @@ -234,41 +186,17 @@ impl MePool { let mut endpoints_by_dc = BTreeMap::>::new(); if self.decision.ipv4_me { let map = self.proxy_map_v4.read().await.clone(); - for (dc, addrs) in map { - let abs_dc = dc.abs(); - if abs_dc == 0 { - continue; - } - let Ok(dc_idx) = i16::try_from(abs_dc) else { - continue; - }; - let entry = endpoints_by_dc.entry(dc_idx).or_default(); - for (ip, port) in addrs { - entry.insert(SocketAddr::new(ip, port)); - } - } + extend_signed_endpoints(&mut endpoints_by_dc, map); } if self.decision.ipv6_me { let map = self.proxy_map_v6.read().await.clone(); - for (dc, addrs) in map { - let abs_dc = dc.abs(); - if abs_dc == 0 { - continue; - } - let Ok(dc_idx) = i16::try_from(abs_dc) else { - continue; - }; - let entry = endpoints_by_dc.entry(dc_idx).or_default(); - for (ip, port) in addrs { - entry.insert(SocketAddr::new(ip, port)); - } - } + extend_signed_endpoints(&mut endpoints_by_dc, map); } - let mut endpoint_to_dc = HashMap::::new(); + let mut endpoint_to_dc = HashMap::>::new(); for (dc, endpoints) in &endpoints_by_dc { for endpoint in endpoints { - endpoint_to_dc.entry(*endpoint).or_insert(*dc); + endpoint_to_dc.entry(*endpoint).or_default().insert(*dc); } } @@ -292,7 +220,13 @@ impl MePool { for writer in writers { let endpoint = writer.addr; - let dc = endpoint_to_dc.get(&endpoint).copied(); + let dc = endpoint_to_dc.get(&endpoint).and_then(|dcs| { + if dcs.len() == 1 { + dcs.iter().next().copied() + } else { + None + } + }); let draining = writer.draining.load(Ordering::Relaxed); let degraded = writer.degraded.load(Ordering::Relaxed); let bound_clients = activity @@ -499,6 +433,24 @@ fn ratio_pct(part: usize, total: usize) -> f64 { pct.clamp(0.0, 100.0) } +fn extend_signed_endpoints( + endpoints_by_dc: &mut BTreeMap>, + map: HashMap>, +) { + for (dc, addrs) in map { + if dc == 0 { + continue; + } + let Ok(dc_idx) = i16::try_from(dc) else { + continue; + }; + let entry = endpoints_by_dc.entry(dc_idx).or_default(); + for (ip, port) in addrs { + entry.insert(SocketAddr::new(ip, port)); + } + } +} + fn floor_mode_label(mode: MeFloorMode) -> &'static str { match mode { MeFloorMode::Static => "static", diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index 66a7f81..b437885 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -273,13 +273,12 @@ impl ConnRegistry { bound_clients_by_writer.insert(*writer_id, conn_ids.len()); } for conn_meta in inner.meta.values() { - let dc_u16 = conn_meta.target_dc.unsigned_abs(); - if dc_u16 == 0 { + if conn_meta.target_dc == 0 { continue; } - if let Ok(dc) = i16::try_from(dc_u16) { - *active_sessions_by_target_dc.entry(dc).or_insert(0) += 1; - } + *active_sessions_by_target_dc + .entry(conn_meta.target_dc) + .or_insert(0) += 1; } WriterActivitySnapshot { @@ -402,7 +401,8 @@ mod tests { let snapshot = registry.writer_activity_snapshot().await; assert_eq!(snapshot.bound_clients_by_writer.get(&10), Some(&2)); assert_eq!(snapshot.bound_clients_by_writer.get(&20), Some(&1)); - assert_eq!(snapshot.active_sessions_by_target_dc.get(&2), Some(&2)); + assert_eq!(snapshot.active_sessions_by_target_dc.get(&2), Some(&1)); + assert_eq!(snapshot.active_sessions_by_target_dc.get(&-2), Some(&1)); assert_eq!(snapshot.active_sessions_by_target_dc.get(&4), Some(&1)); } }