Secret Atomic Snapshot + KDF Fingerprint on RwLock

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
Alexey 2026-03-05 23:18:26 +03:00
parent e7cb9238dc
commit 9b84fc7a5b
No known key found for this signature in database
6 changed files with 184 additions and 82 deletions

View File

@ -548,6 +548,12 @@ impl ProxyConfig {
config.general.middle_proxy_nat_probe = true; config.general.middle_proxy_nat_probe = true;
warn!("Auto-enabled middle_proxy_nat_probe for middle proxy mode"); warn!("Auto-enabled middle_proxy_nat_probe for middle proxy mode");
} }
if config.general.use_middle_proxy && !config.general.me_secret_atomic_snapshot {
config.general.me_secret_atomic_snapshot = true;
warn!(
"Auto-enabled me_secret_atomic_snapshot for middle proxy mode to keep KDF key_selector/secret coherent"
);
}
validate_network_cfg(&mut config.network)?; validate_network_cfg(&mut config.network)?;
crate::network::dns_overrides::validate_entries(&config.network.dns_overrides)?; crate::network::dns_overrides::validate_entries(&config.network.dns_overrides)?;

View File

@ -1339,7 +1339,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let (admission_tx, admission_rx) = watch::channel(true); let (admission_tx, admission_rx) = watch::channel(true);
if config.general.use_middle_proxy { if config.general.use_middle_proxy {
if let Some(pool) = me_pool.as_ref() { if let Some(pool) = me_pool.as_ref() {
let initial_open = pool.admission_ready_full_floor().await; let initial_open = pool.admission_ready_conditional_cast().await;
admission_tx.send_replace(initial_open); admission_tx.send_replace(initial_open);
if initial_open { if initial_open {
info!("Conditional-admission gate: open (ME pool ready)"); info!("Conditional-admission gate: open (ME pool ready)");
@ -1354,7 +1354,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let mut open_streak = if initial_open { 1u32 } else { 0u32 }; let mut open_streak = if initial_open { 1u32 } else { 0u32 };
let mut close_streak = if initial_open { 0u32 } else { 1u32 }; let mut close_streak = if initial_open { 0u32 } else { 1u32 };
loop { loop {
let ready = pool_for_gate.admission_ready_full_floor().await; let ready = pool_for_gate.admission_ready_conditional_cast().await;
if ready { if ready {
open_streak = open_streak.saturating_add(1); open_streak = open_streak.saturating_add(1);
close_streak = 0; close_streak = 0;
@ -1374,7 +1374,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
admission_tx_gate.send_replace(false); admission_tx_gate.send_replace(false);
warn!( warn!(
close_streak, close_streak,
"Conditional-admission gate closed (ME pool below required floor)" "Conditional-admission gate closed (ME pool has uncovered DC groups)"
); );
} }
} }

View File

@ -387,9 +387,11 @@ impl MePool {
socks_bound_addr.map(|value| value.ip()), socks_bound_addr.map(|value| value.ip()),
client_port_source, client_port_source,
); );
let mut kdf_fingerprint_guard = self.kdf_material_fingerprint.lock().await; let previous_kdf_fingerprint = {
if let Some((prev_fingerprint, prev_client_port)) = let kdf_fingerprint_guard = self.kdf_material_fingerprint.read().await;
kdf_fingerprint_guard.get(&peer_addr_nat).copied() kdf_fingerprint_guard.get(&peer_addr_nat).copied()
};
if let Some((prev_fingerprint, prev_client_port)) = previous_kdf_fingerprint
{ {
if prev_fingerprint != kdf_fingerprint { if prev_fingerprint != kdf_fingerprint {
self.stats.increment_me_kdf_drift_total(); self.stats.increment_me_kdf_drift_total();
@ -416,6 +418,9 @@ impl MePool {
); );
} }
} }
// Keep fingerprint updates eventually consistent for diagnostics while avoiding
// serializing all concurrent handshakes on a single async mutex.
let mut kdf_fingerprint_guard = self.kdf_material_fingerprint.write().await;
kdf_fingerprint_guard.insert(peer_addr_nat, (kdf_fingerprint, client_port_for_kdf)); kdf_fingerprint_guard.insert(peer_addr_nat, (kdf_fingerprint, client_port_for_kdf));
drop(kdf_fingerprint_guard); drop(kdf_fingerprint_guard);

View File

@ -132,7 +132,7 @@ pub struct MePool {
pub(super) pending_hardswap_map_hash: AtomicU64, pub(super) pending_hardswap_map_hash: AtomicU64,
pub(super) hardswap: AtomicBool, pub(super) hardswap: AtomicBool,
pub(super) endpoint_quarantine: Arc<Mutex<HashMap<SocketAddr, Instant>>>, pub(super) endpoint_quarantine: Arc<Mutex<HashMap<SocketAddr, Instant>>>,
pub(super) kdf_material_fingerprint: Arc<Mutex<HashMap<SocketAddr, (u64, u16)>>>, pub(super) kdf_material_fingerprint: Arc<RwLock<HashMap<SocketAddr, (u64, u16)>>>,
pub(super) me_pool_drain_ttl_secs: AtomicU64, pub(super) me_pool_drain_ttl_secs: AtomicU64,
pub(super) me_pool_force_close_secs: AtomicU64, pub(super) me_pool_force_close_secs: AtomicU64,
pub(super) me_pool_min_fresh_ratio_permille: AtomicU32, pub(super) me_pool_min_fresh_ratio_permille: AtomicU32,
@ -335,7 +335,7 @@ impl MePool {
pending_hardswap_map_hash: AtomicU64::new(0), pending_hardswap_map_hash: AtomicU64::new(0),
hardswap: AtomicBool::new(hardswap), hardswap: AtomicBool::new(hardswap),
endpoint_quarantine: Arc::new(Mutex::new(HashMap::new())), endpoint_quarantine: Arc::new(Mutex::new(HashMap::new())),
kdf_material_fingerprint: Arc::new(Mutex::new(HashMap::new())), kdf_material_fingerprint: Arc::new(RwLock::new(HashMap::new())),
me_pool_drain_ttl_secs: AtomicU64::new(me_pool_drain_ttl_secs), me_pool_drain_ttl_secs: AtomicU64::new(me_pool_drain_ttl_secs),
me_pool_force_close_secs: AtomicU64::new(me_pool_force_close_secs), me_pool_force_close_secs: AtomicU64::new(me_pool_force_close_secs),
me_pool_min_fresh_ratio_permille: AtomicU32::new(Self::ratio_to_permille( me_pool_min_fresh_ratio_permille: AtomicU32::new(Self::ratio_to_permille(

View File

@ -14,10 +14,12 @@ use super::pool::MePool;
impl MePool { impl MePool {
pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &Arc<SecureRandom>) -> Result<()> { pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &Arc<SecureRandom>) -> Result<()> {
let family_order = self.family_order(); let family_order = self.family_order();
let connect_concurrency = self.me_reconnect_max_concurrent_per_dc.max(1) as usize;
let ks = self.key_selector().await; let ks = self.key_selector().await;
info!( info!(
me_servers = self.proxy_map_v4.read().await.len(), me_servers = self.proxy_map_v4.read().await.len(),
pool_size, pool_size,
connect_concurrency,
key_selector = format_args!("0x{ks:08x}"), key_selector = format_args!("0x{ks:08x}"),
secret_len = self.proxy_secret.read().await.secret.len(), secret_len = self.proxy_secret.read().await.secret.len(),
"Initializing ME pool" "Initializing ME pool"
@ -41,23 +43,39 @@ impl MePool {
}) })
.collect(); .collect();
dc_addrs.sort_unstable_by_key(|(dc, _)| *dc); dc_addrs.sort_unstable_by_key(|(dc, _)| *dc);
dc_addrs.sort_by_key(|(_, addrs)| (addrs.len() != 1, addrs.len()));
// Ensure at least one live writer per DC group; run missing DCs in parallel. // Stage 1: build base coverage for conditional-cast.
// Single-endpoint DCs are prefilled first; multi-endpoint DCs require one live writer.
let mut join = tokio::task::JoinSet::new(); let mut join = tokio::task::JoinSet::new();
for (dc, addrs) in dc_addrs.iter().cloned() { for (dc, addrs) in dc_addrs.iter().cloned() {
if addrs.is_empty() { if addrs.is_empty() {
continue; continue;
} }
let target_writers = if addrs.len() == 1 {
self.required_writers_for_dc_with_floor_mode(addrs.len(), false)
} else {
1usize
};
let endpoints: HashSet<SocketAddr> = addrs let endpoints: HashSet<SocketAddr> = addrs
.iter() .iter()
.map(|(ip, port)| SocketAddr::new(*ip, *port)) .map(|(ip, port)| SocketAddr::new(*ip, *port))
.collect(); .collect();
if self.active_writer_count_for_endpoints(&endpoints).await > 0 { if self.active_writer_count_for_endpoints(&endpoints).await >= target_writers {
continue; continue;
} }
let pool = Arc::clone(self); let pool = Arc::clone(self);
let rng_clone = Arc::clone(rng); let rng_clone = Arc::clone(rng);
join.spawn(async move { pool.connect_primary_for_dc(dc, addrs, rng_clone).await }); join.spawn(async move {
pool.connect_primary_for_dc(
dc,
addrs,
target_writers,
rng_clone,
connect_concurrency,
)
.await
});
} }
while join.join_next().await.is_some() {} while join.join_next().await.is_some() {}
@ -77,47 +95,35 @@ impl MePool {
))); )));
} }
// Warm reserve writers asynchronously so startup does not block after first working pool is ready. // Stage 2: continue saturating multi-endpoint DC groups in background.
let pool = Arc::clone(self); let pool = Arc::clone(self);
let rng_clone = Arc::clone(rng); let rng_clone = Arc::clone(rng);
let dc_addrs_bg = dc_addrs.clone(); let dc_addrs_bg = dc_addrs.clone();
tokio::spawn(async move { tokio::spawn(async move {
if pool.me_warmup_stagger_enabled { let mut join_bg = tokio::task::JoinSet::new();
for (dc, addrs) in &dc_addrs_bg { for (dc, addrs) in dc_addrs_bg {
for (ip, port) in addrs { if addrs.len() <= 1 {
if pool.connection_count() >= pool_size { continue;
break;
}
let addr = SocketAddr::new(*ip, *port);
let jitter = rand::rng()
.random_range(0..=pool.me_warmup_step_jitter.as_millis() as u64);
let delay_ms = pool.me_warmup_step_delay.as_millis() as u64 + jitter;
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
if let Err(e) = pool.connect_one(addr, rng_clone.as_ref()).await {
debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed (staggered)");
}
}
}
} else {
for (dc, addrs) in &dc_addrs_bg {
for (ip, port) in addrs {
if pool.connection_count() >= pool_size {
break;
}
let addr = SocketAddr::new(*ip, *port);
if let Err(e) = pool.connect_one(addr, rng_clone.as_ref()).await {
debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed");
}
}
if pool.connection_count() >= pool_size {
break;
}
} }
let target_writers = pool.required_writers_for_dc_with_floor_mode(addrs.len(), false);
let pool_clone = Arc::clone(&pool);
let rng_clone_local = Arc::clone(&rng_clone);
join_bg.spawn(async move {
pool_clone
.connect_primary_for_dc(
dc,
addrs,
target_writers,
rng_clone_local,
connect_concurrency,
)
.await
});
} }
while join_bg.join_next().await.is_some() {}
debug!( debug!(
target_pool_size = pool_size,
current_pool_size = pool.connection_count(), current_pool_size = pool.connection_count(),
"Background ME reserve warmup finished" "Background ME saturation warmup finished"
); );
}); });
@ -140,62 +146,85 @@ impl MePool {
self: Arc<Self>, self: Arc<Self>,
dc: i32, dc: i32,
mut addrs: Vec<(IpAddr, u16)>, mut addrs: Vec<(IpAddr, u16)>,
target_writers: usize,
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
connect_concurrency: usize,
) -> bool { ) -> bool {
if addrs.is_empty() { if addrs.is_empty() {
return false; return false;
} }
let target_writers = target_writers.max(1);
addrs.shuffle(&mut rand::rng()); addrs.shuffle(&mut rand::rng());
if addrs.len() > 1 { let endpoints: Vec<SocketAddr> = addrs
let concurrency = 2usize; .iter()
let mut join = tokio::task::JoinSet::new(); .map(|(ip, port)| SocketAddr::new(*ip, *port))
let mut next_idx = 0usize; .collect();
let endpoint_set: HashSet<SocketAddr> = endpoints.iter().copied().collect();
while next_idx < addrs.len() || !join.is_empty() { loop {
while next_idx < addrs.len() && join.len() < concurrency { let alive = self.active_writer_count_for_endpoints(&endpoint_set).await;
let (ip, port) = addrs[next_idx]; if alive >= target_writers {
next_idx += 1; info!(
let addr = SocketAddr::new(ip, port); dc = %dc,
alive,
target_writers,
"ME connected"
);
return true;
}
let missing = target_writers.saturating_sub(alive).max(1);
let concurrency = connect_concurrency.max(1).min(missing);
let mut join = tokio::task::JoinSet::new();
for _ in 0..concurrency {
let pool = Arc::clone(&self); let pool = Arc::clone(&self);
let rng_clone = Arc::clone(&rng); let rng_clone = Arc::clone(&rng);
let endpoints_clone = endpoints.clone();
join.spawn(async move { join.spawn(async move {
(addr, pool.connect_one(addr, rng_clone.as_ref()).await) pool.connect_endpoints_round_robin(&endpoints_clone, rng_clone.as_ref())
.await
}); });
} }
let Some(res) = join.join_next().await else { let mut progress = false;
break; while let Some(res) = join.join_next().await {
};
match res { match res {
Ok((addr, Ok(()))) => { Ok(true) => {
info!(%addr, dc = %dc, "ME connected"); progress = true;
join.abort_all();
while join.join_next().await.is_some() {}
return true;
}
Ok((addr, Err(e))) => {
warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next");
} }
Ok(false) => {}
Err(e) => { Err(e) => {
warn!(dc = %dc, error = %e, "ME connect task failed"); warn!(dc = %dc, error = %e, "ME connect task failed");
} }
} }
} }
warn!(dc = %dc, "All ME servers for DC failed at init");
let alive_after = self.active_writer_count_for_endpoints(&endpoint_set).await;
if alive_after >= target_writers {
info!(
dc = %dc,
alive = alive_after,
target_writers,
"ME connected"
);
return true;
}
if !progress {
warn!(
dc = %dc,
alive = alive_after,
target_writers,
"All ME servers for DC failed at init"
);
return false; return false;
} }
for (ip, port) in addrs { if self.me_warmup_stagger_enabled {
let addr = SocketAddr::new(ip, port); let jitter = rand::rng()
match self.connect_one(addr, rng.as_ref()).await { .random_range(0..=self.me_warmup_step_jitter.as_millis() as u64);
Ok(()) => { let delay_ms = self.me_warmup_step_delay.as_millis() as u64 + jitter;
info!(%addr, dc = %dc, "ME connected"); tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
return true;
}
Err(e) => warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"),
} }
} }
warn!(dc = %dc, "All ME servers for DC failed at init");
false
} }
} }

View File

@ -100,6 +100,68 @@ pub(crate) struct MeApiRuntimeSnapshot {
} }
impl MePool { impl MePool {
/// Conditional-cast admission check: returns `true` only when every
/// configured DC group has at least one live (non-draining) ME writer.
///
/// A DC "group" is keyed by the absolute DC index; v4 and v6 endpoints of
/// the same DC share one group. With no configured endpoints at all the
/// gate stays closed (`false`).
pub(crate) async fn admission_ready_conditional_cast(&self) -> bool {
    // Group every configured ME endpoint by its absolute DC index,
    // honoring the per-family decision flags. Iterate under the read
    // guards directly; there are no awaits inside the loops, so no need
    // to clone the maps out first.
    let mut dc_groups = BTreeMap::<i16, BTreeSet<SocketAddr>>::new();
    if self.decision.ipv4_me {
        let guard = self.proxy_map_v4.read().await;
        for (dc, addrs) in guard.iter() {
            let abs = dc.abs();
            if abs == 0 {
                continue;
            }
            // DC indices outside i16 range are silently ignored,
            // matching the map's key type.
            let Ok(idx) = i16::try_from(abs) else {
                continue;
            };
            dc_groups
                .entry(idx)
                .or_default()
                .extend(addrs.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
        }
    }
    if self.decision.ipv6_me {
        let guard = self.proxy_map_v6.read().await;
        for (dc, addrs) in guard.iter() {
            let abs = dc.abs();
            if abs == 0 {
                continue;
            }
            let Ok(idx) = i16::try_from(abs) else {
                continue;
            };
            dc_groups
                .entry(idx)
                .or_default()
                .extend(addrs.iter().map(|(ip, port)| SocketAddr::new(*ip, *port)));
        }
    }
    if dc_groups.is_empty() {
        // Nothing configured: never open the gate.
        return false;
    }
    // Count live writers per endpoint; draining writers do not count
    // toward coverage.
    let mut live_by_endpoint = HashMap::<SocketAddr, usize>::new();
    {
        let writers = self.writers.read().await;
        for writer in writers.iter() {
            if !writer.draining.load(Ordering::Relaxed) {
                *live_by_endpoint.entry(writer.addr).or_insert(0) += 1;
            }
        }
    }
    // Open only when every DC group has at least one covered endpoint.
    dc_groups.values().all(|group| {
        group
            .iter()
            .any(|endpoint| live_by_endpoint.get(endpoint).copied().unwrap_or(0) > 0)
    })
}
#[allow(dead_code)]
pub(crate) async fn admission_ready_full_floor(&self) -> bool { pub(crate) async fn admission_ready_full_floor(&self) -> bool {
let mut endpoints_by_dc = BTreeMap::<i16, BTreeSet<SocketAddr>>::new(); let mut endpoints_by_dc = BTreeMap::<i16, BTreeSet<SocketAddr>>::new();
if self.decision.ipv4_me { if self.decision.ipv4_me {