diff --git a/src/main.rs b/src/main.rs
index a87dd99..2675509 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -770,12 +770,14 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
     // Background tasks
     let um_clone = upstream_manager.clone();
     let decision_clone = decision.clone();
+    let dc_overrides_for_health = config.dc_overrides.clone();
     tokio::spawn(async move {
         um_clone
             .run_health_checks(
                 prefer_ipv6,
                 decision_clone.ipv4_dc,
                 decision_clone.ipv6_dc,
+                dc_overrides_for_health,
             )
             .await;
     });
diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs
index 8411f5a..fa7b0a6 100644
--- a/src/transport/upstream.rs
+++ b/src/transport/upstream.rs
@@ -4,7 +4,7 @@
 
 #![allow(deprecated)]
 
-use std::collections::HashMap;
+use std::collections::{BTreeSet, HashMap};
 use std::net::{SocketAddr, IpAddr};
 use std::sync::Arc;
 use std::sync::atomic::{AtomicUsize, Ordering};
@@ -29,6 +29,12 @@ const NUM_DCS: usize = 5;
 const DC_PING_TIMEOUT_SECS: u64 = 5;
 /// Timeout for direct TG DC TCP connect readiness.
 const DIRECT_CONNECT_TIMEOUT_SECS: u64 = 10;
+/// Interval between upstream health-check cycles.
+const HEALTH_CHECK_INTERVAL_SECS: u64 = 30;
+/// Timeout for a single health-check connect attempt.
+const HEALTH_CHECK_CONNECT_TIMEOUT_SECS: u64 = 10;
+/// Upstream is considered healthy when at least this many DC groups are reachable.
+const MIN_HEALTHY_DC_GROUPS: usize = 3;
 
 // ============= RTT Tracking =============
 
@@ -167,6 +173,13 @@ pub struct UpstreamEgressInfo {
     pub socks_proxy_addr: Option<SocketAddr>,
 }
 
+#[derive(Debug, Clone)]
+struct HealthCheckGroup {
+    dc_idx: i16,
+    primary: Vec<SocketAddr>,
+    fallback: Vec<SocketAddr>,
+}
+
 // ============= Upstream Manager =============
 
 #[derive(Clone)]
@@ -987,41 +1000,144 @@ impl UpstreamManager {
         Ok(start.elapsed().as_secs_f64() * 1000.0)
     }
 
+    fn required_healthy_group_count(total_groups: usize) -> usize {
+        if total_groups == 0 {
+            0
+        } else {
+            total_groups.min(MIN_HEALTHY_DC_GROUPS)
+        }
+    }
+
+    fn build_health_check_groups(
+        prefer_ipv6: bool,
+        ipv4_enabled: bool,
+        ipv6_enabled: bool,
+        dc_overrides: &HashMap<String, Vec<String>>,
+    ) -> Vec<HealthCheckGroup> {
+        let mut v4_by_dc: HashMap<i16, Vec<SocketAddr>> = HashMap::new();
+        let mut v6_by_dc: HashMap<i16, Vec<SocketAddr>> = HashMap::new();
+
+        if ipv4_enabled {
+            for (idx, dc_ip) in TG_DATACENTERS_V4.iter().enumerate() {
+                let dc_idx = (idx + 1) as i16;
+                v4_by_dc
+                    .entry(dc_idx)
+                    .or_default()
+                    .push(SocketAddr::new(*dc_ip, TG_DATACENTER_PORT));
+            }
+        }
+
+        if ipv6_enabled {
+            for (idx, dc_ip) in TG_DATACENTERS_V6.iter().enumerate() {
+                let dc_idx = (idx + 1) as i16;
+                v6_by_dc
+                    .entry(dc_idx)
+                    .or_default()
+                    .push(SocketAddr::new(*dc_ip, TG_DATACENTER_PORT));
+            }
+        }
+
+        for (dc_key, addrs) in dc_overrides {
+            let dc_idx = match dc_key.parse::<i16>() {
+                Ok(v) if v > 0 => v,
+                _ => {
+                    warn!(dc = %dc_key, "Invalid dc_overrides key for health-check, skipping");
+                    continue;
+                }
+            };
+
+            for addr_str in addrs {
+                match addr_str.parse::<SocketAddr>() {
+                    Ok(addr) if addr.is_ipv6() => {
+                        if ipv6_enabled {
+                            v6_by_dc.entry(dc_idx).or_default().push(addr);
+                        }
+                    }
+                    Ok(addr) => {
+                        if ipv4_enabled {
+                            v4_by_dc.entry(dc_idx).or_default().push(addr);
+                        }
+                    }
+                    Err(_) => {
+                        warn!(
+                            dc = %dc_idx,
+                            addr = %addr_str,
+                            "Invalid dc_overrides address for health-check, skipping"
+                        );
+                    }
+                }
+            }
+        }
+
+        for addrs in v4_by_dc.values_mut() {
+            addrs.sort_unstable();
+            addrs.dedup();
+        }
+        for addrs in v6_by_dc.values_mut() {
+            addrs.sort_unstable();
+            addrs.dedup();
+        }
+
+        let mut all_dcs = BTreeSet::new();
+        all_dcs.extend(v4_by_dc.keys().copied());
+        all_dcs.extend(v6_by_dc.keys().copied());
+
+        let mut groups = Vec::with_capacity(all_dcs.len());
+        for dc_idx in all_dcs {
+            let v4_endpoints = v4_by_dc.remove(&dc_idx).unwrap_or_default();
+            let v6_endpoints = v6_by_dc.remove(&dc_idx).unwrap_or_default();
+            let (primary, fallback) = if prefer_ipv6 {
+                (v6_endpoints, v4_endpoints)
+            } else {
+                (v4_endpoints, v6_endpoints)
+            };
+
+            if primary.is_empty() && fallback.is_empty() {
+                continue;
+            }
+
+            groups.push(HealthCheckGroup {
+                dc_idx,
+                primary,
+                fallback,
+            });
+        }
+
+        groups
+    }
+
     // ============= Health Checks =============
 
-    /// Background health check: rotates through DCs, 30s interval.
-    /// Uses preferred IP version based on config.
-    pub async fn run_health_checks(&self, prefer_ipv6: bool, ipv4_enabled: bool, ipv6_enabled: bool) {
-        let mut dc_rotation = 0usize;
+    /// Background health check based on reachable DC groups through each upstream.
+    /// Upstream stays healthy while at least `MIN_HEALTHY_DC_GROUPS` groups are reachable.
+    pub async fn run_health_checks(
+        &self,
+        prefer_ipv6: bool,
+        ipv4_enabled: bool,
+        ipv6_enabled: bool,
+        dc_overrides: HashMap<String, Vec<String>>,
+    ) {
+        let groups = Self::build_health_check_groups(
+            prefer_ipv6,
+            ipv4_enabled,
+            ipv6_enabled,
+            &dc_overrides,
+        );
+        let required_healthy_groups = Self::required_healthy_group_count(groups.len());
+        let mut endpoint_rotation: HashMap<(usize, i16, bool), usize> = HashMap::new();
+
+        if groups.is_empty() {
+            warn!("No DC groups available for upstream health-checks");
+        }
 
         loop {
-            tokio::time::sleep(Duration::from_secs(30)).await;
+            tokio::time::sleep(Duration::from_secs(HEALTH_CHECK_INTERVAL_SECS)).await;
 
-            let dc_zero_idx = dc_rotation % NUM_DCS;
-            dc_rotation += 1;
-
-            let primary_v6 = SocketAddr::new(TG_DATACENTERS_V6[dc_zero_idx], TG_DATACENTER_PORT);
-            let primary_v4 = SocketAddr::new(TG_DATACENTERS_V4[dc_zero_idx], TG_DATACENTER_PORT);
-            let dc_addr = if prefer_ipv6 && ipv6_enabled {
-                primary_v6
-            } else if ipv4_enabled {
-                primary_v4
-            } else if ipv6_enabled {
-                primary_v6
-            } else {
+            if groups.is_empty() || required_healthy_groups == 0 {
                 continue;
-            };
-
-            let fallback_addr = if dc_addr.is_ipv6() && ipv4_enabled {
-                Some(primary_v4)
-            } else if dc_addr.is_ipv4() && ipv6_enabled {
-                Some(primary_v6)
-            } else {
-                None
-            };
+            }
 
             let count = self.upstreams.read().await.len();
-
             for i in 0..count {
                 let (config, bind_rr) = {
                     let guard = self.upstreams.read().await;
@@ -1029,104 +1145,123 @@ impl UpstreamManager {
                     (u.config.clone(), u.bind_rr.clone())
                 };
 
-                let start = Instant::now();
-                let result = tokio::time::timeout(
-                    Duration::from_secs(10),
-                    self.connect_via_upstream(&config, dc_addr, Some(bind_rr.clone()))
-                ).await;
+                let mut healthy_groups = 0usize;
+                let mut latency_updates: Vec<(usize, f64)> = Vec::new();
 
-                match result {
-                    Ok(Ok(_stream)) => {
-                        let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
-                        let mut guard = self.upstreams.write().await;
-                        let u = &mut guard[i];
-                        u.dc_latency[dc_zero_idx].update(rtt_ms);
+                for group in &groups {
+                    let mut group_ok = false;
+                    let mut group_rtt_ms = None;
 
-                        if !u.healthy {
-                            info!(
-                                rtt = format!("{:.0} ms", rtt_ms),
-                                dc = dc_zero_idx + 1,
-                                "Upstream recovered"
-                            );
-                        }
-                        u.healthy = true;
-                        u.fails = 0;
-                        u.last_check = std::time::Instant::now();
-                    }
-                    Ok(Err(_)) | Err(_) => {
-                        // Try fallback
-                        debug!(dc = dc_zero_idx + 1, "Health check failed, trying fallback");
-
-                        if let Some(fallback_addr) = fallback_addr {
-                            let start2 = Instant::now();
-                            let result2 = tokio::time::timeout(
-                                Duration::from_secs(10),
-                                self.connect_via_upstream(&config, fallback_addr, Some(bind_rr.clone()))
-                            ).await;
-
-                            let mut guard = self.upstreams.write().await;
-                            let u = &mut guard[i];
-
-                            match result2 {
-                                Ok(Ok(_stream)) => {
-                                    let rtt_ms = start2.elapsed().as_secs_f64() * 1000.0;
-                                    u.dc_latency[dc_zero_idx].update(rtt_ms);
-
-                                    if !u.healthy {
-                                        info!(
-                                            rtt = format!("{:.0} ms", rtt_ms),
-                                            dc = dc_zero_idx + 1,
-                                            "Upstream recovered (fallback)"
-                                        );
-                                    }
-                                    u.healthy = true;
-                                    u.fails = 0;
-                                }
-                                Ok(Err(e)) => {
-                                    u.fails += 1;
-                                    debug!(dc = dc_zero_idx + 1, fails = u.fails,
-                                        "Health check failed (both): {}", e);
-                                    if u.fails >= self.unhealthy_fail_threshold {
-                                        u.healthy = false;
-                                        warn!(
-                                            fails = u.fails,
-                                            threshold = self.unhealthy_fail_threshold,
-                                            "Upstream unhealthy (fails)"
-                                        );
-                                    }
-                                }
-                                Err(_) => {
-                                    u.fails += 1;
-                                    debug!(dc = dc_zero_idx + 1, fails = u.fails,
-                                        "Health check timeout (both)");
-                                    if u.fails >= self.unhealthy_fail_threshold {
-                                        u.healthy = false;
-                                        warn!(
-                                            fails = u.fails,
-                                            threshold = self.unhealthy_fail_threshold,
-                                            "Upstream unhealthy (timeout)"
-                                        );
-                                    }
-                                }
-                            }
-                            u.last_check = std::time::Instant::now();
+                    for (is_primary, endpoints) in [(true, &group.primary), (false, &group.fallback)] {
+                        if endpoints.is_empty() {
                             continue;
                         }
 
-                        let mut guard = self.upstreams.write().await;
-                        let u = &mut guard[i];
-                        u.fails += 1;
-                        if u.fails >= self.unhealthy_fail_threshold {
-                            u.healthy = false;
-                            warn!(
-                                fails = u.fails,
-                                threshold = self.unhealthy_fail_threshold,
-                                "Upstream unhealthy (no fallback family)"
-                            );
-                        }
-                        u.last_check = std::time::Instant::now();
-                    }
-                }
+                        let rotation_key = (i, group.dc_idx, is_primary);
+                        let start_idx = *endpoint_rotation.entry(rotation_key).or_insert(0) % endpoints.len();
+                        let mut next_idx = (start_idx + 1) % endpoints.len();
+
+                        for step in 0..endpoints.len() {
+                            let endpoint_idx = (start_idx + step) % endpoints.len();
+                            let endpoint = endpoints[endpoint_idx];
+
+                            let start = Instant::now();
+                            let result = tokio::time::timeout(
+                                Duration::from_secs(HEALTH_CHECK_CONNECT_TIMEOUT_SECS),
+                                self.connect_via_upstream(&config, endpoint, Some(bind_rr.clone())),
+                            )
+                            .await;
+
+                            match result {
+                                Ok(Ok(_stream)) => {
+                                    group_ok = true;
+                                    group_rtt_ms = Some(start.elapsed().as_secs_f64() * 1000.0);
+                                    next_idx = (endpoint_idx + 1) % endpoints.len();
+                                    break;
+                                }
+                                Ok(Err(e)) => {
+                                    debug!(
+                                        upstream = i,
+                                        dc = group.dc_idx,
+                                        endpoint = %endpoint,
+                                        primary = is_primary,
+                                        error = %e,
+                                        "Health-check endpoint failed"
+                                    );
+                                }
+                                Err(_) => {
+                                    debug!(
+                                        upstream = i,
+                                        dc = group.dc_idx,
+                                        endpoint = %endpoint,
+                                        primary = is_primary,
+                                        "Health-check endpoint timed out"
+                                    );
+                                }
+                            }
+                        }
+
+                        endpoint_rotation.insert(rotation_key, next_idx);
+
+                        if group_ok {
+                            break;
+                        }
+                    }
+
+                    if group_ok {
+                        healthy_groups += 1;
+                        if let (Some(dc_array_idx), Some(rtt_ms)) =
+                            (UpstreamState::dc_array_idx(group.dc_idx), group_rtt_ms)
+                        {
+                            latency_updates.push((dc_array_idx, rtt_ms));
+                        }
+                    }
+                }
+
+                let mut guard = self.upstreams.write().await;
+                let u = &mut guard[i];
+
+                for (dc_array_idx, rtt_ms) in latency_updates {
+                    u.dc_latency[dc_array_idx].update(rtt_ms);
+                }
+
+                if healthy_groups >= required_healthy_groups {
+                    if !u.healthy {
+                        info!(
+                            upstream = i,
+                            healthy_groups,
+                            total_groups = groups.len(),
+                            required_groups = required_healthy_groups,
+                            "Upstream recovered by DC-group health threshold"
+                        );
+                    }
+                    u.healthy = true;
+                    u.fails = 0;
+                } else {
+                    u.fails += 1;
+                    debug!(
+                        upstream = i,
+                        healthy_groups,
+                        total_groups = groups.len(),
+                        required_groups = required_healthy_groups,
+                        fails = u.fails,
+                        "Upstream health-check below DC-group threshold"
+                    );
+                    if u.fails >= self.unhealthy_fail_threshold {
+                        u.healthy = false;
+                        warn!(
+                            upstream = i,
+                            healthy_groups,
+                            total_groups = groups.len(),
+                            required_groups = required_healthy_groups,
+                            fails = u.fails,
+                            threshold = self.unhealthy_fail_threshold,
+                            "Upstream unhealthy (insufficient reachable DC groups)"
+                        );
+                    }
+                }
+
+                u.last_check = std::time::Instant::now();
             }
         }
     }
@@ -1157,3 +1292,76 @@ impl UpstreamManager {
         Some(SocketAddr::new(ip, TG_DATACENTER_PORT))
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn required_healthy_group_count_applies_three_group_threshold() {
+        assert_eq!(UpstreamManager::required_healthy_group_count(0), 0);
+        assert_eq!(UpstreamManager::required_healthy_group_count(1), 1);
+        assert_eq!(UpstreamManager::required_healthy_group_count(2), 2);
+        assert_eq!(UpstreamManager::required_healthy_group_count(3), 3);
+        assert_eq!(UpstreamManager::required_healthy_group_count(5), 3);
+    }
+
+    #[test]
+    fn build_health_check_groups_merges_family_endpoints_with_preference() {
+        let mut overrides = HashMap::new();
+        overrides.insert(
+            "2".to_string(),
+            vec![
+                "203.0.113.10:443".to_string(),
+                "203.0.113.11:443".to_string(),
+                "[2001:db8::10]:443".to_string(),
+            ],
+        );
+
+        let groups = UpstreamManager::build_health_check_groups(true, true, true, &overrides);
+        let dc2 = groups
+            .iter()
+            .find(|g| g.dc_idx == 2)
+            .expect("dc2 must be present");
+
+        assert!(dc2.primary.iter().all(|addr| addr.is_ipv6()));
+        assert!(dc2.fallback.iter().all(|addr| addr.is_ipv4()));
+        assert!(dc2
+            .primary
+            .contains(&"[2001:db8::10]:443".parse::<SocketAddr>().unwrap()));
+        assert!(dc2
+            .fallback
+            .contains(&"203.0.113.10:443".parse::<SocketAddr>().unwrap()));
+        assert!(dc2
+            .fallback
+            .contains(&"203.0.113.11:443".parse::<SocketAddr>().unwrap()));
+    }
+
+    #[test]
+    fn build_health_check_groups_keeps_multiple_endpoints_per_group() {
+        let mut overrides = HashMap::new();
+        overrides.insert(
+            "9".to_string(),
+            vec![
+                "198.51.100.1:443".to_string(),
+                "198.51.100.2:443".to_string(),
+                "198.51.100.1:443".to_string(),
+            ],
+        );
+
+        let groups = UpstreamManager::build_health_check_groups(false, true, false, &overrides);
+        let dc9 = groups
+            .iter()
+            .find(|g| g.dc_idx == 9)
+            .expect("override-only dc group must be present");
+
+        assert_eq!(dc9.primary.len(), 2);
+        assert!(dc9
+            .primary
+            .contains(&"198.51.100.1:443".parse::<SocketAddr>().unwrap()));
+        assert!(dc9
+            .primary
+            .contains(&"198.51.100.2:443".parse::<SocketAddr>().unwrap()));
+        assert!(dc9.fallback.is_empty());
+    }
+}