mirror of https://github.com/telemt/telemt.git
231 lines
8.5 KiB
Rust
231 lines
8.5 KiB
Rust
use std::collections::{HashMap, HashSet};
|
|
use std::net::{IpAddr, SocketAddr};
|
|
use std::sync::Arc;
|
|
|
|
use rand::Rng;
|
|
use rand::seq::SliceRandom;
|
|
use tracing::{debug, info, warn};
|
|
|
|
use crate::crypto::SecureRandom;
|
|
use crate::error::{ProxyError, Result};
|
|
|
|
use super::pool::MePool;
|
|
|
|
impl MePool {
|
|
pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &Arc<SecureRandom>) -> Result<()> {
|
|
let family_order = self.family_order();
|
|
let connect_concurrency = self.me_reconnect_max_concurrent_per_dc.max(1) as usize;
|
|
let ks = self.key_selector().await;
|
|
info!(
|
|
me_servers = self.proxy_map_v4.read().await.len(),
|
|
pool_size,
|
|
connect_concurrency,
|
|
key_selector = format_args!("0x{ks:08x}"),
|
|
secret_len = self.proxy_secret.read().await.secret.len(),
|
|
"Initializing ME pool"
|
|
);
|
|
|
|
for family in family_order {
|
|
let map = self.proxy_map_for_family(family).await;
|
|
let mut grouped_dc_addrs: HashMap<i32, Vec<(IpAddr, u16)>> = HashMap::new();
|
|
for (dc, addrs) in map {
|
|
if addrs.is_empty() {
|
|
continue;
|
|
}
|
|
grouped_dc_addrs.entry(dc.abs()).or_default().extend(addrs);
|
|
}
|
|
let mut dc_addrs: Vec<(i32, Vec<(IpAddr, u16)>)> = grouped_dc_addrs
|
|
.into_iter()
|
|
.map(|(dc, mut addrs)| {
|
|
addrs.sort_unstable();
|
|
addrs.dedup();
|
|
(dc, addrs)
|
|
})
|
|
.collect();
|
|
dc_addrs.sort_unstable_by_key(|(dc, _)| *dc);
|
|
dc_addrs.sort_by_key(|(_, addrs)| (addrs.len() != 1, addrs.len()));
|
|
|
|
// Stage 1: build base coverage for conditional-cast.
|
|
// Single-endpoint DCs are prefilled first; multi-endpoint DCs require one live writer.
|
|
let mut join = tokio::task::JoinSet::new();
|
|
for (dc, addrs) in dc_addrs.iter().cloned() {
|
|
if addrs.is_empty() {
|
|
continue;
|
|
}
|
|
let target_writers = if addrs.len() == 1 {
|
|
self.required_writers_for_dc_with_floor_mode(addrs.len(), false)
|
|
} else {
|
|
1usize
|
|
};
|
|
let endpoints: HashSet<SocketAddr> = addrs
|
|
.iter()
|
|
.map(|(ip, port)| SocketAddr::new(*ip, *port))
|
|
.collect();
|
|
if self.active_writer_count_for_endpoints(&endpoints).await >= target_writers {
|
|
continue;
|
|
}
|
|
let pool = Arc::clone(self);
|
|
let rng_clone = Arc::clone(rng);
|
|
join.spawn(async move {
|
|
pool.connect_primary_for_dc(
|
|
dc,
|
|
addrs,
|
|
target_writers,
|
|
rng_clone,
|
|
connect_concurrency,
|
|
)
|
|
.await
|
|
});
|
|
}
|
|
while join.join_next().await.is_some() {}
|
|
|
|
let mut missing_dcs = Vec::new();
|
|
for (dc, addrs) in &dc_addrs {
|
|
let endpoints: HashSet<SocketAddr> = addrs
|
|
.iter()
|
|
.map(|(ip, port)| SocketAddr::new(*ip, *port))
|
|
.collect();
|
|
if self.active_writer_count_for_endpoints(&endpoints).await == 0 {
|
|
missing_dcs.push(*dc);
|
|
}
|
|
}
|
|
if !missing_dcs.is_empty() {
|
|
return Err(ProxyError::Proxy(format!(
|
|
"ME init incomplete: no live writers for DC groups {missing_dcs:?}"
|
|
)));
|
|
}
|
|
|
|
// Stage 2: continue saturating multi-endpoint DC groups in background.
|
|
let pool = Arc::clone(self);
|
|
let rng_clone = Arc::clone(rng);
|
|
let dc_addrs_bg = dc_addrs.clone();
|
|
tokio::spawn(async move {
|
|
let mut join_bg = tokio::task::JoinSet::new();
|
|
for (dc, addrs) in dc_addrs_bg {
|
|
if addrs.len() <= 1 {
|
|
continue;
|
|
}
|
|
let target_writers = pool.required_writers_for_dc_with_floor_mode(addrs.len(), false);
|
|
let pool_clone = Arc::clone(&pool);
|
|
let rng_clone_local = Arc::clone(&rng_clone);
|
|
join_bg.spawn(async move {
|
|
pool_clone
|
|
.connect_primary_for_dc(
|
|
dc,
|
|
addrs,
|
|
target_writers,
|
|
rng_clone_local,
|
|
connect_concurrency,
|
|
)
|
|
.await
|
|
});
|
|
}
|
|
while join_bg.join_next().await.is_some() {}
|
|
debug!(
|
|
current_pool_size = pool.connection_count(),
|
|
"Background ME saturation warmup finished"
|
|
);
|
|
});
|
|
|
|
if !self.decision.effective_multipath && self.connection_count() > 0 {
|
|
break;
|
|
}
|
|
}
|
|
|
|
if self.writers.read().await.is_empty() {
|
|
return Err(ProxyError::Proxy("No ME connections".into()));
|
|
}
|
|
info!(
|
|
active_writers = self.connection_count(),
|
|
"ME primary pool ready; reserve warmup continues in background"
|
|
);
|
|
Ok(())
|
|
}
|
|
|
|
async fn connect_primary_for_dc(
|
|
self: Arc<Self>,
|
|
dc: i32,
|
|
mut addrs: Vec<(IpAddr, u16)>,
|
|
target_writers: usize,
|
|
rng: Arc<SecureRandom>,
|
|
connect_concurrency: usize,
|
|
) -> bool {
|
|
if addrs.is_empty() {
|
|
return false;
|
|
}
|
|
let target_writers = target_writers.max(1);
|
|
addrs.shuffle(&mut rand::rng());
|
|
let endpoints: Vec<SocketAddr> = addrs
|
|
.iter()
|
|
.map(|(ip, port)| SocketAddr::new(*ip, *port))
|
|
.collect();
|
|
let endpoint_set: HashSet<SocketAddr> = endpoints.iter().copied().collect();
|
|
|
|
loop {
|
|
let alive = self.active_writer_count_for_endpoints(&endpoint_set).await;
|
|
if alive >= target_writers {
|
|
info!(
|
|
dc = %dc,
|
|
alive,
|
|
target_writers,
|
|
"ME connected"
|
|
);
|
|
return true;
|
|
}
|
|
|
|
let missing = target_writers.saturating_sub(alive).max(1);
|
|
let concurrency = connect_concurrency.max(1).min(missing);
|
|
let mut join = tokio::task::JoinSet::new();
|
|
for _ in 0..concurrency {
|
|
let pool = Arc::clone(&self);
|
|
let rng_clone = Arc::clone(&rng);
|
|
let endpoints_clone = endpoints.clone();
|
|
join.spawn(async move {
|
|
pool.connect_endpoints_round_robin(&endpoints_clone, rng_clone.as_ref())
|
|
.await
|
|
});
|
|
}
|
|
|
|
let mut progress = false;
|
|
while let Some(res) = join.join_next().await {
|
|
match res {
|
|
Ok(true) => {
|
|
progress = true;
|
|
}
|
|
Ok(false) => {}
|
|
Err(e) => {
|
|
warn!(dc = %dc, error = %e, "ME connect task failed");
|
|
}
|
|
}
|
|
}
|
|
|
|
let alive_after = self.active_writer_count_for_endpoints(&endpoint_set).await;
|
|
if alive_after >= target_writers {
|
|
info!(
|
|
dc = %dc,
|
|
alive = alive_after,
|
|
target_writers,
|
|
"ME connected"
|
|
);
|
|
return true;
|
|
}
|
|
if !progress {
|
|
warn!(
|
|
dc = %dc,
|
|
alive = alive_after,
|
|
target_writers,
|
|
"All ME servers for DC failed at init"
|
|
);
|
|
return false;
|
|
}
|
|
|
|
if self.me_warmup_stagger_enabled {
|
|
let jitter = rand::rng()
|
|
.random_range(0..=self.me_warmup_step_jitter.as_millis() as u64);
|
|
let delay_ms = self.me_warmup_step_delay.as_millis() as u64 + jitter;
|
|
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
|
|
}
|
|
}
|
|
}
|
|
}
|