UpstreamManager Health-check for ME Pool over SOCKS

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
Author: Alexey
Date: 2026-03-01 04:02:32 +03:00
Parent: 44cdfd4b23
Commit: 47b12f9489
2 changed files with 328 additions and 118 deletions


@@ -770,12 +770,14 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
// Background tasks
let um_clone = upstream_manager.clone();
let decision_clone = decision.clone();
let dc_overrides_for_health = config.dc_overrides.clone();
tokio::spawn(async move {
um_clone
.run_health_checks(
prefer_ipv6,
decision_clone.ipv4_dc,
decision_clone.ipv6_dc,
dc_overrides_for_health,
)
.await;
});
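For reference, the `config.dc_overrides` map forwarded above is a plain `HashMap<String, Vec<String>>` (see the `run_health_checks` signature below): keys are DC numbers as strings, values are `ip:port` endpoint strings of either family. A minimal hand-built sketch, using documentation addresses rather than real Telegram endpoints (the helper name is illustrative only):

use std::collections::HashMap;

// Sketch: shape of the dc_overrides map passed into run_health_checks.
// 203.0.113.0/24 and 2001:db8::/32 are reserved documentation ranges.
fn example_dc_overrides() -> HashMap<String, Vec<String>> {
    let mut overrides = HashMap::new();
    overrides.insert(
        "2".to_string(),
        vec![
            "203.0.113.10:443".to_string(),   // extra IPv4 endpoint for DC 2
            "[2001:db8::10]:443".to_string(), // extra IPv6 endpoint for DC 2
        ],
    );
    overrides
}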


@@ -4,7 +4,7 @@
#![allow(deprecated)]
use std::collections::{BTreeSet, HashMap};
use std::net::{SocketAddr, IpAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
@@ -29,6 +29,12 @@ const NUM_DCS: usize = 5;
const DC_PING_TIMEOUT_SECS: u64 = 5;
/// Timeout for direct TG DC TCP connect readiness.
const DIRECT_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Interval between upstream health-check cycles.
const HEALTH_CHECK_INTERVAL_SECS: u64 = 30;
/// Timeout for a single health-check connect attempt.
const HEALTH_CHECK_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Upstream is considered healthy when at least this many DC groups are reachable.
const MIN_HEALTHY_DC_GROUPS: usize = 3;
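// Taken together: every HEALTH_CHECK_INTERVAL_SECS each upstream probes its DC
// groups, trying endpoints in rotation until one answers (with
// HEALTH_CHECK_CONNECT_TIMEOUT_SECS per connect attempt), and it stays healthy
// while at least MIN_HEALTHY_DC_GROUPS groups, or all of them when fewer are
// configured, remain reachable.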
// ============= RTT Tracking =============
@@ -167,6 +173,13 @@ pub struct UpstreamEgressInfo {
pub socks_proxy_addr: Option<SocketAddr>,
}
#[derive(Debug, Clone)]
struct HealthCheckGroup {
dc_idx: i16,
primary: Vec<SocketAddr>,
fallback: Vec<SocketAddr>,
}
// ============= Upstream Manager =============
#[derive(Clone)]
@@ -987,41 +1000,144 @@ impl UpstreamManager {
Ok(start.elapsed().as_secs_f64() * 1000.0)
}
fn required_healthy_group_count(total_groups: usize) -> usize {
if total_groups == 0 {
0
} else {
total_groups.min(MIN_HEALTHY_DC_GROUPS)
}
}
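    /// Builds the per-DC endpoint groups to probe, merging the built-in
    /// `TG_DATACENTERS_V4`/`TG_DATACENTERS_V6` lists with `dc_overrides`,
    /// honoring the enabled IP families and using `prefer_ipv6` to decide which
    /// family is primary and which is fallback.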
fn build_health_check_groups(
prefer_ipv6: bool,
ipv4_enabled: bool,
ipv6_enabled: bool,
dc_overrides: &HashMap<String, Vec<String>>,
) -> Vec<HealthCheckGroup> {
let mut v4_by_dc: HashMap<i16, Vec<SocketAddr>> = HashMap::new();
let mut v6_by_dc: HashMap<i16, Vec<SocketAddr>> = HashMap::new();
if ipv4_enabled {
for (idx, dc_ip) in TG_DATACENTERS_V4.iter().enumerate() {
let dc_idx = (idx + 1) as i16;
v4_by_dc
.entry(dc_idx)
.or_default()
.push(SocketAddr::new(*dc_ip, TG_DATACENTER_PORT));
}
}
if ipv6_enabled {
for (idx, dc_ip) in TG_DATACENTERS_V6.iter().enumerate() {
let dc_idx = (idx + 1) as i16;
v6_by_dc
.entry(dc_idx)
.or_default()
.push(SocketAddr::new(*dc_ip, TG_DATACENTER_PORT));
}
}
for (dc_key, addrs) in dc_overrides {
let dc_idx = match dc_key.parse::<i16>() {
Ok(v) if v > 0 => v,
_ => {
warn!(dc = %dc_key, "Invalid dc_overrides key for health-check, skipping");
continue;
}
};
for addr_str in addrs {
match addr_str.parse::<SocketAddr>() {
Ok(addr) if addr.is_ipv6() => {
if ipv6_enabled {
v6_by_dc.entry(dc_idx).or_default().push(addr);
}
}
Ok(addr) => {
if ipv4_enabled {
v4_by_dc.entry(dc_idx).or_default().push(addr);
}
}
Err(_) => {
warn!(
dc = %dc_idx,
addr = %addr_str,
"Invalid dc_overrides address for health-check, skipping"
);
}
}
}
}
for addrs in v4_by_dc.values_mut() {
addrs.sort_unstable();
addrs.dedup();
}
for addrs in v6_by_dc.values_mut() {
addrs.sort_unstable();
addrs.dedup();
}
let mut all_dcs = BTreeSet::new();
all_dcs.extend(v4_by_dc.keys().copied());
all_dcs.extend(v6_by_dc.keys().copied());
let mut groups = Vec::with_capacity(all_dcs.len());
for dc_idx in all_dcs {
let v4_endpoints = v4_by_dc.remove(&dc_idx).unwrap_or_default();
let v6_endpoints = v6_by_dc.remove(&dc_idx).unwrap_or_default();
let (primary, fallback) = if prefer_ipv6 {
(v6_endpoints, v4_endpoints)
} else {
(v4_endpoints, v6_endpoints)
};
if primary.is_empty() && fallback.is_empty() {
continue;
}
groups.push(HealthCheckGroup {
dc_idx,
primary,
fallback,
});
}
groups
}
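    // Example (mirrors the unit tests below): with prefer_ipv6 = true and an
    // override "2" => ["203.0.113.10:443", "[2001:db8::10]:443"], DC 2 gets the
    // IPv6 endpoints (built-in plus override) as primary and the IPv4 ones as
    // fallback, while an override-only key such as "9" forms its own group with
    // an empty fallback list.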
// ============= Health Checks =============
    /// Background health check based on reachable DC groups through each upstream.
    /// Upstream stays healthy while at least `MIN_HEALTHY_DC_GROUPS` groups are reachable.
    pub async fn run_health_checks(
        &self,
        prefer_ipv6: bool,
        ipv4_enabled: bool,
        ipv6_enabled: bool,
        dc_overrides: HashMap<String, Vec<String>>,
    ) {
        let groups = Self::build_health_check_groups(
            prefer_ipv6,
            ipv4_enabled,
            ipv6_enabled,
            &dc_overrides,
        );
        let required_healthy_groups = Self::required_healthy_group_count(groups.len());
        let mut endpoint_rotation: HashMap<(usize, i16, bool), usize> = HashMap::new();
        if groups.is_empty() {
            warn!("No DC groups available for upstream health-checks");
        }
        loop {
            tokio::time::sleep(Duration::from_secs(HEALTH_CHECK_INTERVAL_SECS)).await;
            if groups.is_empty() || required_healthy_groups == 0 {
                continue;
            }
let count = self.upstreams.read().await.len();
for i in 0..count {
let (config, bind_rr) = {
let guard = self.upstreams.read().await;
@@ -1029,104 +1145,123 @@ impl UpstreamManager {
(u.config.clone(), u.bind_rr.clone())
};
                let mut healthy_groups = 0usize;
                let mut latency_updates: Vec<(usize, f64)> = Vec::new();
for group in &groups {
let mut group_ok = false;
let mut group_rtt_ms = None;
                    for (is_primary, endpoints) in [(true, &group.primary), (false, &group.fallback)] {
                        if endpoints.is_empty() {
                            continue;
                        }
                        let rotation_key = (i, group.dc_idx, is_primary);
                        let start_idx = *endpoint_rotation.entry(rotation_key).or_insert(0) % endpoints.len();
let mut next_idx = (start_idx + 1) % endpoints.len();
for step in 0..endpoints.len() {
let endpoint_idx = (start_idx + step) % endpoints.len();
let endpoint = endpoints[endpoint_idx];
let start = Instant::now();
let result = tokio::time::timeout(
Duration::from_secs(HEALTH_CHECK_CONNECT_TIMEOUT_SECS),
self.connect_via_upstream(&config, endpoint, Some(bind_rr.clone())),
)
.await;
match result {
Ok(Ok(_stream)) => {
group_ok = true;
group_rtt_ms = Some(start.elapsed().as_secs_f64() * 1000.0);
next_idx = (endpoint_idx + 1) % endpoints.len();
break;
}
Ok(Err(e)) => {
debug!(
upstream = i,
dc = group.dc_idx,
endpoint = %endpoint,
primary = is_primary,
error = %e,
"Health-check endpoint failed"
);
}
Err(_) => {
debug!(
upstream = i,
dc = group.dc_idx,
endpoint = %endpoint,
primary = is_primary,
"Health-check endpoint timed out"
);
}
}
}
endpoint_rotation.insert(rotation_key, next_idx);
if group_ok {
break;
}
}
if group_ok {
healthy_groups += 1;
if let (Some(dc_array_idx), Some(rtt_ms)) =
(UpstreamState::dc_array_idx(group.dc_idx), group_rtt_ms)
{
latency_updates.push((dc_array_idx, rtt_ms));
}
                    }
                }
let mut guard = self.upstreams.write().await;
let u = &mut guard[i];
for (dc_array_idx, rtt_ms) in latency_updates {
u.dc_latency[dc_array_idx].update(rtt_ms);
}
if healthy_groups >= required_healthy_groups {
if !u.healthy {
info!(
upstream = i,
healthy_groups,
total_groups = groups.len(),
required_groups = required_healthy_groups,
"Upstream recovered by DC-group health threshold"
);
}
u.healthy = true;
u.fails = 0;
} else {
u.fails += 1;
debug!(
upstream = i,
healthy_groups,
total_groups = groups.len(),
required_groups = required_healthy_groups,
fails = u.fails,
"Upstream health-check below DC-group threshold"
);
if u.fails >= self.unhealthy_fail_threshold {
u.healthy = false;
warn!(
upstream = i,
healthy_groups,
total_groups = groups.len(),
required_groups = required_healthy_groups,
fails = u.fails,
threshold = self.unhealthy_fail_threshold,
"Upstream unhealthy (insufficient reachable DC groups)"
);
}
}
u.last_check = std::time::Instant::now();
}
}
}
@@ -1157,3 +1292,76 @@ impl UpstreamManager {
Some(SocketAddr::new(ip, TG_DATACENTER_PORT))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn required_healthy_group_count_applies_three_group_threshold() {
assert_eq!(UpstreamManager::required_healthy_group_count(0), 0);
assert_eq!(UpstreamManager::required_healthy_group_count(1), 1);
assert_eq!(UpstreamManager::required_healthy_group_count(2), 2);
assert_eq!(UpstreamManager::required_healthy_group_count(3), 3);
assert_eq!(UpstreamManager::required_healthy_group_count(5), 3);
}
#[test]
fn build_health_check_groups_merges_family_endpoints_with_preference() {
let mut overrides = HashMap::new();
overrides.insert(
"2".to_string(),
vec![
"203.0.113.10:443".to_string(),
"203.0.113.11:443".to_string(),
"[2001:db8::10]:443".to_string(),
],
);
let groups = UpstreamManager::build_health_check_groups(true, true, true, &overrides);
let dc2 = groups
.iter()
.find(|g| g.dc_idx == 2)
.expect("dc2 must be present");
assert!(dc2.primary.iter().all(|addr| addr.is_ipv6()));
assert!(dc2.fallback.iter().all(|addr| addr.is_ipv4()));
assert!(dc2
.primary
.contains(&"[2001:db8::10]:443".parse::<SocketAddr>().unwrap()));
assert!(dc2
.fallback
.contains(&"203.0.113.10:443".parse::<SocketAddr>().unwrap()));
assert!(dc2
.fallback
.contains(&"203.0.113.11:443".parse::<SocketAddr>().unwrap()));
}
#[test]
fn build_health_check_groups_keeps_multiple_endpoints_per_group() {
let mut overrides = HashMap::new();
overrides.insert(
"9".to_string(),
vec![
"198.51.100.1:443".to_string(),
"198.51.100.2:443".to_string(),
"198.51.100.1:443".to_string(),
],
);
let groups = UpstreamManager::build_health_check_groups(false, true, false, &overrides);
let dc9 = groups
.iter()
.find(|g| g.dc_idx == 9)
.expect("override-only dc group must be present");
assert_eq!(dc9.primary.len(), 2);
assert!(dc9
.primary
.contains(&"198.51.100.1:443".parse::<SocketAddr>().unwrap()));
assert!(dc9
.primary
.contains(&"198.51.100.2:443".parse::<SocketAddr>().unwrap()));
assert!(dc9.fallback.is_empty());
}
}
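Both helpers under test are pure and open no sockets, so the tests run without network access or a tokio runtime, e.g. `cargo test build_health_check_groups` or `cargo test required_healthy_group_count` to filter by name.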