Compare commits

..

8 Commits
3.3.1 ... 3.3.3

Author SHA1 Message Date
Alexey
ef7dc2b80f Merge pull request #332 from telemt/bump
Update Cargo.toml
2026-03-06 04:05:46 +03:00
Alexey
691607f269 Update Cargo.toml 2026-03-06 04:05:35 +03:00
Alexey
55561a23bc ME NoWait Routing + Upstream Connbudget + another fixes: merge pull request #331 from telemt/flow-hp
ME NoWait Routing + Upstream Connbudget + another fixes
2026-03-06 04:05:04 +03:00
Alexey
f32c34f126 ME NoWait Routing + Upstream Connbudget + PROXY Header t/o + allocation cuts 2026-03-06 03:58:08 +03:00
Alexey
8f3bdaec2c Merge pull request #329 from telemt/bump
Update Cargo.toml
2026-03-05 23:23:40 +03:00
Alexey
69b02caf77 Update Cargo.toml 2026-03-05 23:23:24 +03:00
Alexey
3854955069 Merge pull request #328 from telemt/flow-mep
Secret Atomic Snapshot + KDF Fingerprint on RwLock
2026-03-05 23:23:01 +03:00
Alexey
9b84fc7a5b Secret Atomic Snapshot + KDF Fingerprint on RwLock
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-05 23:18:26 +03:00
19 changed files with 464 additions and 132 deletions

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "telemt" name = "telemt"
version = "3.3.1" version = "3.3.3"
edition = "2024" edition = "2024"
[dependencies] [dependencies]

View File

@@ -15,6 +15,7 @@ const DEFAULT_ME_ADAPTIVE_FLOOR_RECOVER_GRACE_SECS: u64 = 180;
const DEFAULT_USER_MAX_UNIQUE_IPS_WINDOW_SECS: u64 = 30; const DEFAULT_USER_MAX_UNIQUE_IPS_WINDOW_SECS: u64 = 30;
const DEFAULT_UPSTREAM_CONNECT_RETRY_ATTEMPTS: u32 = 2; const DEFAULT_UPSTREAM_CONNECT_RETRY_ATTEMPTS: u32 = 2;
const DEFAULT_UPSTREAM_UNHEALTHY_FAIL_THRESHOLD: u32 = 5; const DEFAULT_UPSTREAM_UNHEALTHY_FAIL_THRESHOLD: u32 = 5;
// Default for `general.upstream_connect_budget_ms`: total wall-clock budget, in
// milliseconds, for one upstream connect request across retries.
const DEFAULT_UPSTREAM_CONNECT_BUDGET_MS: u64 = 3000;
const DEFAULT_LISTEN_ADDR_IPV6: &str = "::"; const DEFAULT_LISTEN_ADDR_IPV6: &str = "::";
const DEFAULT_ACCESS_USER: &str = "default"; const DEFAULT_ACCESS_USER: &str = "default";
const DEFAULT_ACCESS_SECRET: &str = "00000000000000000000000000000000"; const DEFAULT_ACCESS_SECRET: &str = "00000000000000000000000000000000";
@@ -113,6 +114,10 @@ pub(crate) fn default_api_minimal_runtime_cache_ttl_ms() -> u64 {
1000 1000
} }
/// Default timeout, in milliseconds, for reading and parsing a PROXY protocol header.
///
/// Serves as the serde default for `server.proxy_protocol_header_timeout_ms`.
/// Config validation rejects a value of `0`, so this default must stay positive.
pub(crate) fn default_proxy_protocol_header_timeout_ms() -> u64 {
    // Named locally so the unit of the bare number is visible at the return site.
    const HEADER_TIMEOUT_MS: u64 = 500;
    HEADER_TIMEOUT_MS
}
pub(crate) fn default_prefer_4() -> u8 { pub(crate) fn default_prefer_4() -> u8 {
4 4
} }
@@ -253,6 +258,10 @@ pub(crate) fn default_upstream_unhealthy_fail_threshold() -> u32 {
DEFAULT_UPSTREAM_UNHEALTHY_FAIL_THRESHOLD DEFAULT_UPSTREAM_UNHEALTHY_FAIL_THRESHOLD
} }
/// Default total wall-clock budget, in milliseconds, for one upstream connect
/// request across retries.
///
/// Serves as the serde default for `general.upstream_connect_budget_ms` and simply
/// returns `DEFAULT_UPSTREAM_CONNECT_BUDGET_MS`.
pub(crate) fn default_upstream_connect_budget_ms() -> u64 {
DEFAULT_UPSTREAM_CONNECT_BUDGET_MS
}
pub(crate) fn default_upstream_connect_failfast_hard_errors() -> bool { pub(crate) fn default_upstream_connect_failfast_hard_errors() -> bool {
false false
} }

View File

@@ -265,6 +265,12 @@ impl ProxyConfig {
)); ));
} }
if config.general.upstream_connect_budget_ms == 0 {
return Err(ProxyError::Config(
"general.upstream_connect_budget_ms must be > 0".to_string(),
));
}
if config.general.upstream_unhealthy_fail_threshold == 0 { if config.general.upstream_unhealthy_fail_threshold == 0 {
return Err(ProxyError::Config( return Err(ProxyError::Config(
"general.upstream_unhealthy_fail_threshold must be > 0".to_string(), "general.upstream_unhealthy_fail_threshold must be > 0".to_string(),
@@ -462,6 +468,12 @@ impl ProxyConfig {
)); ));
} }
if config.server.proxy_protocol_header_timeout_ms == 0 {
return Err(ProxyError::Config(
"server.proxy_protocol_header_timeout_ms must be > 0".to_string(),
));
}
if config.general.effective_me_pool_force_close_secs() > 0 if config.general.effective_me_pool_force_close_secs() > 0
&& config.general.effective_me_pool_force_close_secs() && config.general.effective_me_pool_force_close_secs()
< config.general.me_pool_drain_ttl_secs < config.general.me_pool_drain_ttl_secs
@@ -548,6 +560,12 @@ impl ProxyConfig {
config.general.middle_proxy_nat_probe = true; config.general.middle_proxy_nat_probe = true;
warn!("Auto-enabled middle_proxy_nat_probe for middle proxy mode"); warn!("Auto-enabled middle_proxy_nat_probe for middle proxy mode");
} }
if config.general.use_middle_proxy && !config.general.me_secret_atomic_snapshot {
config.general.me_secret_atomic_snapshot = true;
warn!(
"Auto-enabled me_secret_atomic_snapshot for middle proxy mode to keep KDF key_selector/secret coherent"
);
}
validate_network_cfg(&mut config.network)?; validate_network_cfg(&mut config.network)?;
crate::network::dns_overrides::validate_entries(&config.network.dns_overrides)?; crate::network::dns_overrides::validate_entries(&config.network.dns_overrides)?;

View File

@@ -532,6 +532,10 @@ pub struct GeneralConfig {
#[serde(default = "default_upstream_connect_retry_backoff_ms")] #[serde(default = "default_upstream_connect_retry_backoff_ms")]
pub upstream_connect_retry_backoff_ms: u64, pub upstream_connect_retry_backoff_ms: u64,
/// Total wall-clock budget in milliseconds for one upstream connect request across retries.
#[serde(default = "default_upstream_connect_budget_ms")]
pub upstream_connect_budget_ms: u64,
/// Consecutive failed requests before upstream is marked unhealthy. /// Consecutive failed requests before upstream is marked unhealthy.
#[serde(default = "default_upstream_unhealthy_fail_threshold")] #[serde(default = "default_upstream_unhealthy_fail_threshold")]
pub upstream_unhealthy_fail_threshold: u32, pub upstream_unhealthy_fail_threshold: u32,
@@ -774,6 +778,7 @@ impl Default for GeneralConfig {
me_adaptive_floor_recover_grace_secs: default_me_adaptive_floor_recover_grace_secs(), me_adaptive_floor_recover_grace_secs: default_me_adaptive_floor_recover_grace_secs(),
upstream_connect_retry_attempts: default_upstream_connect_retry_attempts(), upstream_connect_retry_attempts: default_upstream_connect_retry_attempts(),
upstream_connect_retry_backoff_ms: default_upstream_connect_retry_backoff_ms(), upstream_connect_retry_backoff_ms: default_upstream_connect_retry_backoff_ms(),
upstream_connect_budget_ms: default_upstream_connect_budget_ms(),
upstream_unhealthy_fail_threshold: default_upstream_unhealthy_fail_threshold(), upstream_unhealthy_fail_threshold: default_upstream_unhealthy_fail_threshold(),
upstream_connect_failfast_hard_errors: default_upstream_connect_failfast_hard_errors(), upstream_connect_failfast_hard_errors: default_upstream_connect_failfast_hard_errors(),
stun_iface_mismatch_ignore: false, stun_iface_mismatch_ignore: false,
@@ -962,6 +967,10 @@ pub struct ServerConfig {
#[serde(default)] #[serde(default)]
pub proxy_protocol: bool, pub proxy_protocol: bool,
/// Timeout in milliseconds for reading and parsing PROXY protocol headers.
#[serde(default = "default_proxy_protocol_header_timeout_ms")]
pub proxy_protocol_header_timeout_ms: u64,
#[serde(default)] #[serde(default)]
pub metrics_port: Option<u16>, pub metrics_port: Option<u16>,
@@ -985,6 +994,7 @@ impl Default for ServerConfig {
listen_unix_sock_perm: None, listen_unix_sock_perm: None,
listen_tcp: None, listen_tcp: None,
proxy_protocol: false, proxy_protocol: false,
proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(),
metrics_port: None, metrics_port: None,
metrics_whitelist: default_metrics_whitelist(), metrics_whitelist: default_metrics_whitelist(),
api: ApiConfig::default(), api: ApiConfig::default(),

View File

@@ -21,6 +21,7 @@ struct SecureRandomInner {
rng: StdRng, rng: StdRng,
cipher: AesCtr, cipher: AesCtr,
buffer: Vec<u8>, buffer: Vec<u8>,
buffer_start: usize,
} }
impl Drop for SecureRandomInner { impl Drop for SecureRandomInner {
@@ -48,6 +49,7 @@ impl SecureRandom {
rng, rng,
cipher, cipher,
buffer: Vec::with_capacity(1024), buffer: Vec::with_capacity(1024),
buffer_start: 0,
}), }),
} }
} }
@@ -59,16 +61,29 @@ impl SecureRandom {
let mut written = 0usize; let mut written = 0usize;
while written < out.len() { while written < out.len() {
if inner.buffer_start >= inner.buffer.len() {
inner.buffer.clear();
inner.buffer_start = 0;
}
if inner.buffer.is_empty() { if inner.buffer.is_empty() {
let mut chunk = vec![0u8; CHUNK_SIZE]; let mut chunk = vec![0u8; CHUNK_SIZE];
inner.rng.fill_bytes(&mut chunk); inner.rng.fill_bytes(&mut chunk);
inner.cipher.apply(&mut chunk); inner.cipher.apply(&mut chunk);
inner.buffer.extend_from_slice(&chunk); inner.buffer.extend_from_slice(&chunk);
inner.buffer_start = 0;
} }
let take = (out.len() - written).min(inner.buffer.len()); let available = inner.buffer.len().saturating_sub(inner.buffer_start);
out[written..written + take].copy_from_slice(&inner.buffer[..take]); let take = (out.len() - written).min(available);
inner.buffer.drain(..take); let start = inner.buffer_start;
let end = start + take;
out[written..written + take].copy_from_slice(&inner.buffer[start..end]);
inner.buffer_start = end;
if inner.buffer_start >= inner.buffer.len() {
inner.buffer.clear();
inner.buffer_start = 0;
}
written += take; written += take;
} }
} }

View File

@@ -464,6 +464,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
config.upstreams.clone(), config.upstreams.clone(),
config.general.upstream_connect_retry_attempts, config.general.upstream_connect_retry_attempts,
config.general.upstream_connect_retry_backoff_ms, config.general.upstream_connect_retry_backoff_ms,
config.general.upstream_connect_budget_ms,
config.general.upstream_unhealthy_fail_threshold, config.general.upstream_unhealthy_fail_threshold,
config.general.upstream_connect_failfast_hard_errors, config.general.upstream_connect_failfast_hard_errors,
stats.clone(), stats.clone(),
@@ -1339,7 +1340,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let (admission_tx, admission_rx) = watch::channel(true); let (admission_tx, admission_rx) = watch::channel(true);
if config.general.use_middle_proxy { if config.general.use_middle_proxy {
if let Some(pool) = me_pool.as_ref() { if let Some(pool) = me_pool.as_ref() {
let initial_open = pool.admission_ready_full_floor().await; let initial_open = pool.admission_ready_conditional_cast().await;
admission_tx.send_replace(initial_open); admission_tx.send_replace(initial_open);
if initial_open { if initial_open {
info!("Conditional-admission gate: open (ME pool ready)"); info!("Conditional-admission gate: open (ME pool ready)");
@@ -1354,7 +1355,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let mut open_streak = if initial_open { 1u32 } else { 0u32 }; let mut open_streak = if initial_open { 1u32 } else { 0u32 };
let mut close_streak = if initial_open { 0u32 } else { 1u32 }; let mut close_streak = if initial_open { 0u32 } else { 1u32 };
loop { loop {
let ready = pool_for_gate.admission_ready_full_floor().await; let ready = pool_for_gate.admission_ready_conditional_cast().await;
if ready { if ready {
open_streak = open_streak.saturating_add(1); open_streak = open_streak.saturating_add(1);
close_streak = 0; close_streak = 0;
@@ -1374,7 +1375,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
admission_tx_gate.send_replace(false); admission_tx_gate.send_replace(false);
warn!( warn!(
close_streak, close_streak,
"Conditional-admission gate closed (ME pool below required floor)" "Conditional-admission gate closed (ME pool has uncovered DC groups)"
); );
} }
} }

View File

@@ -97,8 +97,11 @@ where
.unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap()); .unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
if proxy_protocol_enabled { if proxy_protocol_enabled {
match parse_proxy_protocol(&mut stream, peer).await { let proxy_header_timeout = Duration::from_millis(
Ok(info) => { config.server.proxy_protocol_header_timeout_ms.max(1),
);
match timeout(proxy_header_timeout, parse_proxy_protocol(&mut stream, peer)).await {
Ok(Ok(info)) => {
debug!( debug!(
peer = %peer, peer = %peer,
client = %info.src_addr, client = %info.src_addr,
@@ -110,12 +113,18 @@ where
local_addr = dst; local_addr = dst;
} }
} }
Err(e) => { Ok(Err(e)) => {
stats.increment_connects_bad(); stats.increment_connects_bad();
warn!(peer = %peer, error = %e, "Invalid PROXY protocol header"); warn!(peer = %peer, error = %e, "Invalid PROXY protocol header");
record_beobachten_class(&beobachten, &config, peer.ip(), "other"); record_beobachten_class(&beobachten, &config, peer.ip(), "other");
return Err(e); return Err(e);
} }
Err(_) => {
stats.increment_connects_bad();
warn!(peer = %peer, timeout_ms = proxy_header_timeout.as_millis(), "PROXY protocol header timeout");
record_beobachten_class(&beobachten, &config, peer.ip(), "other");
return Err(ProxyError::InvalidProxyProtocol);
}
} }
} }
@@ -161,7 +170,7 @@ where
let (read_half, write_half) = tokio::io::split(stream); let (read_half, write_half) = tokio::io::split(stream);
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( let (mut tls_reader, tls_writer, tls_user) = match handle_tls_handshake(
&handshake, read_half, write_half, real_peer, &handshake, read_half, write_half, real_peer,
&config, &replay_checker, &rng, tls_cache.clone(), &config, &replay_checker, &rng, tls_cache.clone(),
).await { ).await {
@@ -190,7 +199,7 @@ where
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&mtproto_handshake, tls_reader, tls_writer, real_peer, &mtproto_handshake, tls_reader, tls_writer, real_peer,
&config, &replay_checker, true, &config, &replay_checker, true, Some(tls_user.as_str()),
).await { ).await {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader: _, writer: _ } => { HandshakeResult::BadClient { reader: _, writer: _ } => {
@@ -234,7 +243,7 @@ where
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
&handshake, read_half, write_half, real_peer, &handshake, read_half, write_half, real_peer,
&config, &replay_checker, false, &config, &replay_checker, false, None,
).await { ).await {
HandshakeResult::Success(result) => result, HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => { HandshakeResult::BadClient { reader, writer } => {
@@ -415,8 +424,16 @@ impl RunningClientHandler {
let mut local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; let mut local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
if self.proxy_protocol_enabled { if self.proxy_protocol_enabled {
match parse_proxy_protocol(&mut self.stream, self.peer).await { let proxy_header_timeout = Duration::from_millis(
Ok(info) => { self.config.server.proxy_protocol_header_timeout_ms.max(1),
);
match timeout(
proxy_header_timeout,
parse_proxy_protocol(&mut self.stream, self.peer),
)
.await
{
Ok(Ok(info)) => {
debug!( debug!(
peer = %self.peer, peer = %self.peer,
client = %info.src_addr, client = %info.src_addr,
@@ -428,7 +445,7 @@ impl RunningClientHandler {
local_addr = dst; local_addr = dst;
} }
} }
Err(e) => { Ok(Err(e)) => {
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
warn!(peer = %self.peer, error = %e, "Invalid PROXY protocol header"); warn!(peer = %self.peer, error = %e, "Invalid PROXY protocol header");
record_beobachten_class( record_beobachten_class(
@@ -439,6 +456,21 @@ impl RunningClientHandler {
); );
return Err(e); return Err(e);
} }
Err(_) => {
self.stats.increment_connects_bad();
warn!(
peer = %self.peer,
timeout_ms = proxy_header_timeout.as_millis(),
"PROXY protocol header timeout"
);
record_beobachten_class(
&self.beobachten,
&self.config,
self.peer.ip(),
"other",
);
return Err(ProxyError::InvalidProxyProtocol);
}
} }
} }
@@ -494,7 +526,7 @@ impl RunningClientHandler {
let (read_half, write_half) = self.stream.into_split(); let (read_half, write_half) = self.stream.into_split();
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( let (mut tls_reader, tls_writer, tls_user) = match handle_tls_handshake(
&handshake, &handshake,
read_half, read_half,
write_half, write_half,
@@ -538,6 +570,7 @@ impl RunningClientHandler {
&config, &config,
&replay_checker, &replay_checker,
true, true,
Some(tls_user.as_str()),
) )
.await .await
{ {
@@ -611,6 +644,7 @@ impl RunningClientHandler {
&config, &config,
&replay_checker, &replay_checker,
false, false,
None,
) )
.await .await
{ {

View File

@@ -34,7 +34,7 @@ where
let user = &success.user; let user = &success.user;
let dc_addr = get_dc_addr_static(success.dc_idx, &config)?; let dc_addr = get_dc_addr_static(success.dc_idx, &config)?;
info!( debug!(
user = %user, user = %user,
peer = %success.peer, peer = %success.peer,
dc = success.dc_idx, dc = success.dc_idx,

View File

@@ -6,7 +6,7 @@ use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, warn, trace, info}; use tracing::{debug, warn, trace};
use zeroize::Zeroize; use zeroize::Zeroize;
use crate::crypto::{sha256, AesCtr, SecureRandom}; use crate::crypto::{sha256, AesCtr, SecureRandom};
@@ -19,6 +19,31 @@ use crate::stats::ReplayChecker;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::tls_front::{TlsFrontCache, emulator}; use crate::tls_front::{TlsFrontCache, emulator};
/// Decodes the hex-encoded secret of every configured user into raw bytes.
///
/// When `preferred_user` names a configured user whose secret decodes
/// successfully, that user is placed first in the returned list; all remaining
/// users follow in the iteration order of `config.access.users`. Callers that
/// walk the list in order therefore try the preferred user before anyone else.
///
/// Users whose secret is not valid hex are silently omitted. This also applies
/// to the preferred user: it is never emitted a second time by the general pass,
/// so a preferred user with an undecodable secret is absent from the result.
///
/// Returns `(user name, secret bytes)` pairs.
fn decode_user_secrets(
    config: &ProxyConfig,
    preferred_user: Option<&str>,
) -> Vec<(String, Vec<u8>)> {
    let users = &config.access.users;
    let mut decoded = Vec::with_capacity(users.len());

    // Front-load the preferred user, if it exists and its secret is valid hex.
    let preferred_entry = preferred_user.and_then(|wanted| {
        let secret_hex = users.get(wanted)?;
        let bytes = hex::decode(secret_hex).ok()?;
        Some((wanted.to_string(), bytes))
    });
    decoded.extend(preferred_entry);

    // Append everyone else, skipping the preferred name and undecodable secrets.
    decoded.extend(
        users
            .iter()
            .filter(|(name, _)| preferred_user != Some(name.as_str()))
            .filter_map(|(name, secret_hex)| {
                hex::decode(secret_hex)
                    .ok()
                    .map(|bytes| (name.clone(), bytes))
            }),
    );

    decoded
}
/// Result of successful handshake /// Result of successful handshake
/// ///
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is /// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
@@ -82,11 +107,7 @@ where
return HandshakeResult::BadClient { reader, writer }; return HandshakeResult::BadClient { reader, writer };
} }
let secrets: Vec<(String, Vec<u8>)> = config.access.users.iter() let secrets = decode_user_secrets(config, None);
.filter_map(|(name, hex)| {
hex::decode(hex).ok().map(|bytes| (name.clone(), bytes))
})
.collect();
let validation = match tls::validate_tls_handshake( let validation = match tls::validate_tls_handshake(
handshake, handshake,
@@ -201,7 +222,7 @@ where
return HandshakeResult::Error(ProxyError::Io(e)); return HandshakeResult::Error(ProxyError::Io(e));
} }
info!( debug!(
peer = %peer, peer = %peer,
user = %validation.user, user = %validation.user,
"TLS handshake successful" "TLS handshake successful"
@@ -223,6 +244,7 @@ pub async fn handle_mtproto_handshake<R, W>(
config: &ProxyConfig, config: &ProxyConfig,
replay_checker: &ReplayChecker, replay_checker: &ReplayChecker,
is_tls: bool, is_tls: bool,
preferred_user: Option<&str>,
) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess), R, W> ) -> HandshakeResult<(CryptoReader<R>, CryptoWriter<W>, HandshakeSuccess), R, W>
where where
R: AsyncRead + Unpin + Send, R: AsyncRead + Unpin + Send,
@@ -239,11 +261,9 @@ where
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect(); let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
for (user, secret_hex) in &config.access.users { let decoded_users = decode_user_secrets(config, preferred_user);
let secret = match hex::decode(secret_hex) {
Ok(s) => s, for (user, secret) in decoded_users {
Err(_) => continue,
};
let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_prekey = &dec_prekey_iv[..PREKEY_LEN];
let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..];
@@ -311,7 +331,7 @@ where
is_tls, is_tls,
}; };
info!( debug!(
peer = %peer, peer = %peer,
user = %user, user = %user,
dc = dc_idx, dc = dc_idx,

View File

@@ -8,7 +8,7 @@ use std::time::{Duration, Instant};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use tracing::{debug, info, trace, warn}; use tracing::{debug, trace, warn};
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
@@ -210,7 +210,7 @@ where
let proto_tag = success.proto_tag; let proto_tag = success.proto_tag;
let pool_generation = me_pool.current_generation(); let pool_generation = me_pool.current_generation();
info!( debug!(
user = %user, user = %user,
peer = %peer, peer = %peer,
dc = success.dc_idx, dc = success.dc_idx,

View File

@@ -846,16 +846,30 @@ impl Stats {
if !self.telemetry_user_enabled() { if !self.telemetry_user_enabled() {
return; return;
} }
self.user_stats.entry(user.to_string()).or_default() if let Some(stats) = self.user_stats.get(user) {
.connects.fetch_add(1, Ordering::Relaxed); stats.connects.fetch_add(1, Ordering::Relaxed);
return;
}
self.user_stats
.entry(user.to_string())
.or_default()
.connects
.fetch_add(1, Ordering::Relaxed);
} }
pub fn increment_user_curr_connects(&self, user: &str) { pub fn increment_user_curr_connects(&self, user: &str) {
if !self.telemetry_user_enabled() { if !self.telemetry_user_enabled() {
return; return;
} }
self.user_stats.entry(user.to_string()).or_default() if let Some(stats) = self.user_stats.get(user) {
.curr_connects.fetch_add(1, Ordering::Relaxed); stats.curr_connects.fetch_add(1, Ordering::Relaxed);
return;
}
self.user_stats
.entry(user.to_string())
.or_default()
.curr_connects
.fetch_add(1, Ordering::Relaxed);
} }
pub fn decrement_user_curr_connects(&self, user: &str) { pub fn decrement_user_curr_connects(&self, user: &str) {
@@ -889,32 +903,60 @@ impl Stats {
if !self.telemetry_user_enabled() { if !self.telemetry_user_enabled() {
return; return;
} }
self.user_stats.entry(user.to_string()).or_default() if let Some(stats) = self.user_stats.get(user) {
.octets_from_client.fetch_add(bytes, Ordering::Relaxed); stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
return;
}
self.user_stats
.entry(user.to_string())
.or_default()
.octets_from_client
.fetch_add(bytes, Ordering::Relaxed);
} }
pub fn add_user_octets_to(&self, user: &str, bytes: u64) { pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
if !self.telemetry_user_enabled() { if !self.telemetry_user_enabled() {
return; return;
} }
self.user_stats.entry(user.to_string()).or_default() if let Some(stats) = self.user_stats.get(user) {
.octets_to_client.fetch_add(bytes, Ordering::Relaxed); stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
return;
}
self.user_stats
.entry(user.to_string())
.or_default()
.octets_to_client
.fetch_add(bytes, Ordering::Relaxed);
} }
pub fn increment_user_msgs_from(&self, user: &str) { pub fn increment_user_msgs_from(&self, user: &str) {
if !self.telemetry_user_enabled() { if !self.telemetry_user_enabled() {
return; return;
} }
self.user_stats.entry(user.to_string()).or_default() if let Some(stats) = self.user_stats.get(user) {
.msgs_from_client.fetch_add(1, Ordering::Relaxed); stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
return;
}
self.user_stats
.entry(user.to_string())
.or_default()
.msgs_from_client
.fetch_add(1, Ordering::Relaxed);
} }
pub fn increment_user_msgs_to(&self, user: &str) { pub fn increment_user_msgs_to(&self, user: &str) {
if !self.telemetry_user_enabled() { if !self.telemetry_user_enabled() {
return; return;
} }
self.user_stats.entry(user.to_string()).or_default() if let Some(stats) = self.user_stats.get(user) {
.msgs_to_client.fetch_add(1, Ordering::Relaxed); stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
return;
}
self.user_stats
.entry(user.to_string())
.or_default()
.msgs_to_client
.fetch_add(1, Ordering::Relaxed);
} }
pub fn get_user_total_octets(&self, user: &str) -> u64 { pub fn get_user_total_octets(&self, user: &str) -> u64 {

View File

@@ -387,9 +387,11 @@ impl MePool {
socks_bound_addr.map(|value| value.ip()), socks_bound_addr.map(|value| value.ip()),
client_port_source, client_port_source,
); );
let mut kdf_fingerprint_guard = self.kdf_material_fingerprint.lock().await; let previous_kdf_fingerprint = {
if let Some((prev_fingerprint, prev_client_port)) = let kdf_fingerprint_guard = self.kdf_material_fingerprint.read().await;
kdf_fingerprint_guard.get(&peer_addr_nat).copied() kdf_fingerprint_guard.get(&peer_addr_nat).copied()
};
if let Some((prev_fingerprint, prev_client_port)) = previous_kdf_fingerprint
{ {
if prev_fingerprint != kdf_fingerprint { if prev_fingerprint != kdf_fingerprint {
self.stats.increment_me_kdf_drift_total(); self.stats.increment_me_kdf_drift_total();
@@ -416,6 +418,9 @@ impl MePool {
); );
} }
} }
// Keep fingerprint updates eventually consistent for diagnostics while avoiding
// serializing all concurrent handshakes on a single async mutex.
let mut kdf_fingerprint_guard = self.kdf_material_fingerprint.write().await;
kdf_fingerprint_guard.insert(peer_addr_nat, (kdf_fingerprint, client_port_for_kdf)); kdf_fingerprint_guard.insert(peer_addr_nat, (kdf_fingerprint, client_port_for_kdf));
drop(kdf_fingerprint_guard); drop(kdf_fingerprint_guard);

View File

@@ -119,6 +119,8 @@ pub struct MePool {
pub(super) ping_tracker: Arc<Mutex<HashMap<i64, (std::time::Instant, u64)>>>, pub(super) ping_tracker: Arc<Mutex<HashMap<i64, (std::time::Instant, u64)>>>,
pub(super) rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>, pub(super) rtt_stats: Arc<Mutex<HashMap<u64, (f64, f64)>>>,
pub(super) nat_reflection_cache: Arc<Mutex<NatReflectionCache>>, pub(super) nat_reflection_cache: Arc<Mutex<NatReflectionCache>>,
pub(super) nat_reflection_singleflight_v4: Arc<Mutex<()>>,
pub(super) nat_reflection_singleflight_v6: Arc<Mutex<()>>,
pub(super) writer_available: Arc<Notify>, pub(super) writer_available: Arc<Notify>,
pub(super) refill_inflight: Arc<Mutex<HashSet<SocketAddr>>>, pub(super) refill_inflight: Arc<Mutex<HashSet<SocketAddr>>>,
pub(super) refill_inflight_dc: Arc<Mutex<HashSet<RefillDcKey>>>, pub(super) refill_inflight_dc: Arc<Mutex<HashSet<RefillDcKey>>>,
@@ -132,7 +134,7 @@ pub struct MePool {
pub(super) pending_hardswap_map_hash: AtomicU64, pub(super) pending_hardswap_map_hash: AtomicU64,
pub(super) hardswap: AtomicBool, pub(super) hardswap: AtomicBool,
pub(super) endpoint_quarantine: Arc<Mutex<HashMap<SocketAddr, Instant>>>, pub(super) endpoint_quarantine: Arc<Mutex<HashMap<SocketAddr, Instant>>>,
pub(super) kdf_material_fingerprint: Arc<Mutex<HashMap<SocketAddr, (u64, u16)>>>, pub(super) kdf_material_fingerprint: Arc<RwLock<HashMap<SocketAddr, (u64, u16)>>>,
pub(super) me_pool_drain_ttl_secs: AtomicU64, pub(super) me_pool_drain_ttl_secs: AtomicU64,
pub(super) me_pool_force_close_secs: AtomicU64, pub(super) me_pool_force_close_secs: AtomicU64,
pub(super) me_pool_min_fresh_ratio_permille: AtomicU32, pub(super) me_pool_min_fresh_ratio_permille: AtomicU32,
@@ -323,6 +325,8 @@ impl MePool {
ping_tracker: Arc::new(Mutex::new(HashMap::new())), ping_tracker: Arc::new(Mutex::new(HashMap::new())),
rtt_stats: Arc::new(Mutex::new(HashMap::new())), rtt_stats: Arc::new(Mutex::new(HashMap::new())),
nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())), nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())),
nat_reflection_singleflight_v4: Arc::new(Mutex::new(())),
nat_reflection_singleflight_v6: Arc::new(Mutex::new(())),
writer_available: Arc::new(Notify::new()), writer_available: Arc::new(Notify::new()),
refill_inflight: Arc::new(Mutex::new(HashSet::new())), refill_inflight: Arc::new(Mutex::new(HashSet::new())),
refill_inflight_dc: Arc::new(Mutex::new(HashSet::new())), refill_inflight_dc: Arc::new(Mutex::new(HashSet::new())),
@@ -335,7 +339,7 @@ impl MePool {
pending_hardswap_map_hash: AtomicU64::new(0), pending_hardswap_map_hash: AtomicU64::new(0),
hardswap: AtomicBool::new(hardswap), hardswap: AtomicBool::new(hardswap),
endpoint_quarantine: Arc::new(Mutex::new(HashMap::new())), endpoint_quarantine: Arc::new(Mutex::new(HashMap::new())),
kdf_material_fingerprint: Arc::new(Mutex::new(HashMap::new())), kdf_material_fingerprint: Arc::new(RwLock::new(HashMap::new())),
me_pool_drain_ttl_secs: AtomicU64::new(me_pool_drain_ttl_secs), me_pool_drain_ttl_secs: AtomicU64::new(me_pool_drain_ttl_secs),
me_pool_force_close_secs: AtomicU64::new(me_pool_force_close_secs), me_pool_force_close_secs: AtomicU64::new(me_pool_force_close_secs),
me_pool_min_fresh_ratio_permille: AtomicU32::new(Self::ratio_to_permille( me_pool_min_fresh_ratio_permille: AtomicU32::new(Self::ratio_to_permille(

View File

@@ -14,10 +14,12 @@ use super::pool::MePool;
impl MePool { impl MePool {
pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &Arc<SecureRandom>) -> Result<()> { pub async fn init(self: &Arc<Self>, pool_size: usize, rng: &Arc<SecureRandom>) -> Result<()> {
let family_order = self.family_order(); let family_order = self.family_order();
let connect_concurrency = self.me_reconnect_max_concurrent_per_dc.max(1) as usize;
let ks = self.key_selector().await; let ks = self.key_selector().await;
info!( info!(
me_servers = self.proxy_map_v4.read().await.len(), me_servers = self.proxy_map_v4.read().await.len(),
pool_size, pool_size,
connect_concurrency,
key_selector = format_args!("0x{ks:08x}"), key_selector = format_args!("0x{ks:08x}"),
secret_len = self.proxy_secret.read().await.secret.len(), secret_len = self.proxy_secret.read().await.secret.len(),
"Initializing ME pool" "Initializing ME pool"
@@ -41,23 +43,39 @@ impl MePool {
}) })
.collect(); .collect();
dc_addrs.sort_unstable_by_key(|(dc, _)| *dc); dc_addrs.sort_unstable_by_key(|(dc, _)| *dc);
dc_addrs.sort_by_key(|(_, addrs)| (addrs.len() != 1, addrs.len()));
// Ensure at least one live writer per DC group; run missing DCs in parallel. // Stage 1: build base coverage for conditional-cast.
// Single-endpoint DCs are prefilled first; multi-endpoint DCs require one live writer.
let mut join = tokio::task::JoinSet::new(); let mut join = tokio::task::JoinSet::new();
for (dc, addrs) in dc_addrs.iter().cloned() { for (dc, addrs) in dc_addrs.iter().cloned() {
if addrs.is_empty() { if addrs.is_empty() {
continue; continue;
} }
let target_writers = if addrs.len() == 1 {
self.required_writers_for_dc_with_floor_mode(addrs.len(), false)
} else {
1usize
};
let endpoints: HashSet<SocketAddr> = addrs let endpoints: HashSet<SocketAddr> = addrs
.iter() .iter()
.map(|(ip, port)| SocketAddr::new(*ip, *port)) .map(|(ip, port)| SocketAddr::new(*ip, *port))
.collect(); .collect();
if self.active_writer_count_for_endpoints(&endpoints).await > 0 { if self.active_writer_count_for_endpoints(&endpoints).await >= target_writers {
continue; continue;
} }
let pool = Arc::clone(self); let pool = Arc::clone(self);
let rng_clone = Arc::clone(rng); let rng_clone = Arc::clone(rng);
join.spawn(async move { pool.connect_primary_for_dc(dc, addrs, rng_clone).await }); join.spawn(async move {
pool.connect_primary_for_dc(
dc,
addrs,
target_writers,
rng_clone,
connect_concurrency,
)
.await
});
} }
while join.join_next().await.is_some() {} while join.join_next().await.is_some() {}
@@ -77,47 +95,35 @@ impl MePool {
))); )));
} }
// Warm reserve writers asynchronously so startup does not block after first working pool is ready. // Stage 2: continue saturating multi-endpoint DC groups in background.
let pool = Arc::clone(self); let pool = Arc::clone(self);
let rng_clone = Arc::clone(rng); let rng_clone = Arc::clone(rng);
let dc_addrs_bg = dc_addrs.clone(); let dc_addrs_bg = dc_addrs.clone();
tokio::spawn(async move { tokio::spawn(async move {
if pool.me_warmup_stagger_enabled { let mut join_bg = tokio::task::JoinSet::new();
for (dc, addrs) in &dc_addrs_bg { for (dc, addrs) in dc_addrs_bg {
for (ip, port) in addrs { if addrs.len() <= 1 {
if pool.connection_count() >= pool_size { continue;
break;
}
let addr = SocketAddr::new(*ip, *port);
let jitter = rand::rng()
.random_range(0..=pool.me_warmup_step_jitter.as_millis() as u64);
let delay_ms = pool.me_warmup_step_delay.as_millis() as u64 + jitter;
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
if let Err(e) = pool.connect_one(addr, rng_clone.as_ref()).await {
debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed (staggered)");
}
}
}
} else {
for (dc, addrs) in &dc_addrs_bg {
for (ip, port) in addrs {
if pool.connection_count() >= pool_size {
break;
}
let addr = SocketAddr::new(*ip, *port);
if let Err(e) = pool.connect_one(addr, rng_clone.as_ref()).await {
debug!(%addr, dc = %dc, error = %e, "Extra ME connect failed");
}
}
if pool.connection_count() >= pool_size {
break;
}
} }
let target_writers = pool.required_writers_for_dc_with_floor_mode(addrs.len(), false);
let pool_clone = Arc::clone(&pool);
let rng_clone_local = Arc::clone(&rng_clone);
join_bg.spawn(async move {
pool_clone
.connect_primary_for_dc(
dc,
addrs,
target_writers,
rng_clone_local,
connect_concurrency,
)
.await
});
} }
while join_bg.join_next().await.is_some() {}
debug!( debug!(
target_pool_size = pool_size,
current_pool_size = pool.connection_count(), current_pool_size = pool.connection_count(),
"Background ME reserve warmup finished" "Background ME saturation warmup finished"
); );
}); });
@@ -140,62 +146,85 @@ impl MePool {
self: Arc<Self>, self: Arc<Self>,
dc: i32, dc: i32,
mut addrs: Vec<(IpAddr, u16)>, mut addrs: Vec<(IpAddr, u16)>,
target_writers: usize,
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
connect_concurrency: usize,
) -> bool { ) -> bool {
if addrs.is_empty() { if addrs.is_empty() {
return false; return false;
} }
let target_writers = target_writers.max(1);
addrs.shuffle(&mut rand::rng()); addrs.shuffle(&mut rand::rng());
if addrs.len() > 1 { let endpoints: Vec<SocketAddr> = addrs
let concurrency = 2usize; .iter()
let mut join = tokio::task::JoinSet::new(); .map(|(ip, port)| SocketAddr::new(*ip, *port))
let mut next_idx = 0usize; .collect();
let endpoint_set: HashSet<SocketAddr> = endpoints.iter().copied().collect();
while next_idx < addrs.len() || !join.is_empty() { loop {
while next_idx < addrs.len() && join.len() < concurrency { let alive = self.active_writer_count_for_endpoints(&endpoint_set).await;
let (ip, port) = addrs[next_idx]; if alive >= target_writers {
next_idx += 1; info!(
let addr = SocketAddr::new(ip, port); dc = %dc,
alive,
target_writers,
"ME connected"
);
return true;
}
let missing = target_writers.saturating_sub(alive).max(1);
let concurrency = connect_concurrency.max(1).min(missing);
let mut join = tokio::task::JoinSet::new();
for _ in 0..concurrency {
let pool = Arc::clone(&self); let pool = Arc::clone(&self);
let rng_clone = Arc::clone(&rng); let rng_clone = Arc::clone(&rng);
let endpoints_clone = endpoints.clone();
join.spawn(async move { join.spawn(async move {
(addr, pool.connect_one(addr, rng_clone.as_ref()).await) pool.connect_endpoints_round_robin(&endpoints_clone, rng_clone.as_ref())
.await
}); });
} }
let Some(res) = join.join_next().await else { let mut progress = false;
break; while let Some(res) = join.join_next().await {
};
match res { match res {
Ok((addr, Ok(()))) => { Ok(true) => {
info!(%addr, dc = %dc, "ME connected"); progress = true;
join.abort_all();
while join.join_next().await.is_some() {}
return true;
}
Ok((addr, Err(e))) => {
warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next");
} }
Ok(false) => {}
Err(e) => { Err(e) => {
warn!(dc = %dc, error = %e, "ME connect task failed"); warn!(dc = %dc, error = %e, "ME connect task failed");
} }
} }
} }
warn!(dc = %dc, "All ME servers for DC failed at init");
let alive_after = self.active_writer_count_for_endpoints(&endpoint_set).await;
if alive_after >= target_writers {
info!(
dc = %dc,
alive = alive_after,
target_writers,
"ME connected"
);
return true;
}
if !progress {
warn!(
dc = %dc,
alive = alive_after,
target_writers,
"All ME servers for DC failed at init"
);
return false; return false;
} }
for (ip, port) in addrs { if self.me_warmup_stagger_enabled {
let addr = SocketAddr::new(ip, port); let jitter = rand::rng()
match self.connect_one(addr, rng.as_ref()).await { .random_range(0..=self.me_warmup_step_jitter.as_millis() as u64);
Ok(()) => { let delay_ms = self.me_warmup_step_delay.as_millis() as u64 + jitter;
info!(%addr, dc = %dc, "ME connected"); tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
return true;
}
Err(e) => warn!(%addr, dc = %dc, error = %e, "ME connect failed, trying next"),
} }
} }
warn!(dc = %dc, "All ME servers for DC failed at init");
false
} }
} }

View File

@@ -248,6 +248,43 @@ impl MePool {
} }
} }
let _singleflight_guard = if use_shared_cache {
Some(match family {
IpFamily::V4 => self.nat_reflection_singleflight_v4.lock().await,
IpFamily::V6 => self.nat_reflection_singleflight_v6.lock().await,
})
} else {
None
};
if use_shared_cache
&& let Some(until) = *self.stun_backoff_until.read().await
&& Instant::now() < until
{
if let Ok(cache) = self.nat_reflection_cache.try_lock() {
let slot = match family {
IpFamily::V4 => cache.v4,
IpFamily::V6 => cache.v6,
};
return slot.map(|(_, addr)| addr);
}
return None;
}
if use_shared_cache
&& let Ok(mut cache) = self.nat_reflection_cache.try_lock()
{
let slot = match family {
IpFamily::V4 => &mut cache.v4,
IpFamily::V6 => &mut cache.v6,
};
if let Some((ts, addr)) = slot
&& ts.elapsed() < STUN_CACHE_TTL
{
return Some(*addr);
}
}
let attempt = if use_shared_cache { let attempt = if use_shared_cache {
self.nat_probe_attempts.fetch_add(1, std::sync::atomic::Ordering::Relaxed) self.nat_probe_attempts.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
} else { } else {

View File

@@ -100,6 +100,68 @@ pub(crate) struct MeApiRuntimeSnapshot {
} }
impl MePool { impl MePool {
/// Reports whether every configured DC group currently has at least one live ME writer.
///
/// Endpoints are collected from the IPv4 and/or IPv6 proxy maps (as enabled in
/// `self.decision`) and grouped by the absolute value of the DC index. Entries whose
/// DC index is zero, or whose absolute value does not fit in `i16`, are ignored.
/// Writers whose `draining` flag is set are not counted as live.
///
/// Returns `false` when no DC endpoints are known at all, or when any DC group has
/// no non-draining writer on any of its endpoints; returns `true` otherwise.
pub(crate) async fn admission_ready_conditional_cast(&self) -> bool {
    let mut endpoints_by_dc = BTreeMap::<i16, BTreeSet<SocketAddr>>::new();

    if self.decision.ipv4_me {
        // Iterate over a copy; the read guard is a temporary dropped with this statement.
        let map = self.proxy_map_v4.read().await.clone();
        for (dc, addrs) in map {
            // `checked_abs` avoids the overflow that plain `abs()` hits on the minimum
            // signed value (a panic in debug builds); such an index is skipped.
            let Some(dc_idx) = dc
                .checked_abs()
                .and_then(|abs_dc| i16::try_from(abs_dc).ok())
            else {
                continue;
            };
            if dc_idx == 0 {
                continue;
            }
            endpoints_by_dc
                .entry(dc_idx)
                .or_default()
                .extend(addrs.into_iter().map(|(ip, port)| SocketAddr::new(ip, port)));
        }
    }

    if self.decision.ipv6_me {
        // Same handling as the IPv4 map; both families merge into the same DC groups.
        let map = self.proxy_map_v6.read().await.clone();
        for (dc, addrs) in map {
            let Some(dc_idx) = dc
                .checked_abs()
                .and_then(|abs_dc| i16::try_from(abs_dc).ok())
            else {
                continue;
            };
            if dc_idx == 0 {
                continue;
            }
            endpoints_by_dc
                .entry(dc_idx)
                .or_default()
                .extend(addrs.into_iter().map(|(ip, port)| SocketAddr::new(ip, port)));
        }
    }

    // Without any known DC endpoints there is nothing to be ready for.
    if endpoints_by_dc.is_empty() {
        return false;
    }

    // Count non-draining writers per endpoint address from a snapshot of the writer list.
    let writers = self.writers.read().await.clone();
    let mut live_writers_by_endpoint = HashMap::<SocketAddr, usize>::new();
    for writer in writers {
        if !writer.draining.load(Ordering::Relaxed) {
            *live_writers_by_endpoint.entry(writer.addr).or_default() += 1;
        }
    }

    // Ready only if each DC group has at least one live writer across its endpoints.
    endpoints_by_dc.values().all(|endpoints| {
        let alive: usize = endpoints
            .iter()
            .map(|endpoint| live_writers_by_endpoint.get(endpoint).copied().unwrap_or(0))
            .sum();
        alive > 0
    })
}
#[allow(dead_code)]
pub(crate) async fn admission_ready_full_floor(&self) -> bool { pub(crate) async fn admission_ready_full_floor(&self) -> bool {
let mut endpoints_by_dc = BTreeMap::<i16, BTreeSet<SocketAddr>>::new(); let mut endpoints_by_dc = BTreeMap::<i16, BTreeSet<SocketAddr>>::new();
if self.decision.ipv4_me { if self.decision.ipv4_me {

View File

@@ -124,7 +124,7 @@ pub(crate) async fn reader_loop(
let data = Bytes::copy_from_slice(&body[12..]); let data = Bytes::copy_from_slice(&body[12..]);
trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS"); trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS");
let routed = reg.route(cid, MeResponse::Data { flags, data }).await; let routed = reg.route_nowait(cid, MeResponse::Data { flags, data }).await;
if !matches!(routed, RouteResult::Routed) { if !matches!(routed, RouteResult::Routed) {
match routed { match routed {
RouteResult::NoConn => stats.increment_me_route_drop_no_conn(), RouteResult::NoConn => stats.increment_me_route_drop_no_conn(),
@@ -147,7 +147,7 @@ pub(crate) async fn reader_loop(
let cfm = u32::from_le_bytes(body[8..12].try_into().unwrap()); let cfm = u32::from_le_bytes(body[8..12].try_into().unwrap());
trace!(cid, cfm, "RPC_SIMPLE_ACK"); trace!(cid, cfm, "RPC_SIMPLE_ACK");
let routed = reg.route(cid, MeResponse::Ack(cfm)).await; let routed = reg.route_nowait(cid, MeResponse::Ack(cfm)).await;
if !matches!(routed, RouteResult::Routed) { if !matches!(routed, RouteResult::Routed) {
match routed { match routed {
RouteResult::NoConn => stats.increment_me_route_drop_no_conn(), RouteResult::NoConn => stats.increment_me_route_drop_no_conn(),

View File

@@ -208,6 +208,23 @@ impl ConnRegistry {
} }
} }
/// Delivers `resp` to the connection registered under `id` without waiting for
/// queue capacity.
///
/// The sender is cloned out of the registry map first, so the registry read lock
/// is no longer held when the message is handed to the channel via `try_send`.
///
/// Returns:
/// - `RouteResult::Routed` when the message was queued,
/// - `RouteResult::NoConn` when no entry exists for `id`,
/// - `RouteResult::ChannelClosed` when `try_send` reports the channel as closed,
/// - `RouteResult::QueueFullBase` when `try_send` reports the channel as full.
pub async fn route_nowait(&self, id: u64, resp: MeResponse) -> RouteResult {
    // The read guard is a temporary of this statement and is released once it ends.
    let sender = match self.inner.read().await.map.get(&id) {
        Some(found) => found.clone(),
        None => return RouteResult::NoConn,
    };

    if let Err(send_err) = sender.try_send(resp) {
        return match send_err {
            TrySendError::Full(_) => RouteResult::QueueFullBase,
            TrySendError::Closed(_) => RouteResult::ChannelClosed,
        };
    }
    RouteResult::Routed
}
pub async fn bind_writer( pub async fn bind_writer(
&self, &self,
conn_id: u64, conn_id: u64,

View File

@@ -225,6 +225,7 @@ pub struct UpstreamManager {
upstreams: Arc<RwLock<Vec<UpstreamState>>>, upstreams: Arc<RwLock<Vec<UpstreamState>>>,
connect_retry_attempts: u32, connect_retry_attempts: u32,
connect_retry_backoff: Duration, connect_retry_backoff: Duration,
connect_budget: Duration,
unhealthy_fail_threshold: u32, unhealthy_fail_threshold: u32,
connect_failfast_hard_errors: bool, connect_failfast_hard_errors: bool,
stats: Arc<Stats>, stats: Arc<Stats>,
@@ -235,6 +236,7 @@ impl UpstreamManager {
configs: Vec<UpstreamConfig>, configs: Vec<UpstreamConfig>,
connect_retry_attempts: u32, connect_retry_attempts: u32,
connect_retry_backoff_ms: u64, connect_retry_backoff_ms: u64,
connect_budget_ms: u64,
unhealthy_fail_threshold: u32, unhealthy_fail_threshold: u32,
connect_failfast_hard_errors: bool, connect_failfast_hard_errors: bool,
stats: Arc<Stats>, stats: Arc<Stats>,
@@ -248,6 +250,7 @@ impl UpstreamManager {
upstreams: Arc::new(RwLock::new(states)), upstreams: Arc::new(RwLock::new(states)),
connect_retry_attempts: connect_retry_attempts.max(1), connect_retry_attempts: connect_retry_attempts.max(1),
connect_retry_backoff: Duration::from_millis(connect_retry_backoff_ms), connect_retry_backoff: Duration::from_millis(connect_retry_backoff_ms),
connect_budget: Duration::from_millis(connect_budget_ms.max(1)),
unhealthy_fail_threshold: unhealthy_fail_threshold.max(1), unhealthy_fail_threshold: unhealthy_fail_threshold.max(1),
connect_failfast_hard_errors, connect_failfast_hard_errors,
stats, stats,
@@ -593,11 +596,27 @@ impl UpstreamManager {
let mut last_error: Option<ProxyError> = None; let mut last_error: Option<ProxyError> = None;
let mut attempts_used = 0u32; let mut attempts_used = 0u32;
for attempt in 1..=self.connect_retry_attempts { for attempt in 1..=self.connect_retry_attempts {
let elapsed = connect_started_at.elapsed();
if elapsed >= self.connect_budget {
last_error = Some(ProxyError::ConnectionTimeout {
addr: target.to_string(),
});
break;
}
let remaining_budget = self.connect_budget.saturating_sub(elapsed);
let attempt_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS)
.min(remaining_budget);
if attempt_timeout.is_zero() {
last_error = Some(ProxyError::ConnectionTimeout {
addr: target.to_string(),
});
break;
}
attempts_used = attempt; attempts_used = attempt;
self.stats.increment_upstream_connect_attempt_total(); self.stats.increment_upstream_connect_attempt_total();
let start = Instant::now(); let start = Instant::now();
match self match self
.connect_via_upstream(&upstream, target, bind_rr.clone()) .connect_via_upstream(&upstream, target, bind_rr.clone(), attempt_timeout)
.await .await
{ {
Ok((stream, egress)) => { Ok((stream, egress)) => {
@@ -707,6 +726,7 @@ impl UpstreamManager {
config: &UpstreamConfig, config: &UpstreamConfig,
target: SocketAddr, target: SocketAddr,
bind_rr: Option<Arc<AtomicUsize>>, bind_rr: Option<Arc<AtomicUsize>>,
connect_timeout: Duration,
) -> Result<(TcpStream, UpstreamEgressInfo)> { ) -> Result<(TcpStream, UpstreamEgressInfo)> {
match &config.upstream_type { match &config.upstream_type {
UpstreamType::Direct { interface, bind_addresses } => { UpstreamType::Direct { interface, bind_addresses } => {
@@ -735,7 +755,6 @@ impl UpstreamManager {
let std_stream: std::net::TcpStream = socket.into(); let std_stream: std::net::TcpStream = socket.into();
let stream = TcpStream::from_std(std_stream)?; let stream = TcpStream::from_std(std_stream)?;
let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS);
match tokio::time::timeout(connect_timeout, stream.writable()).await { match tokio::time::timeout(connect_timeout, stream.writable()).await {
Ok(Ok(())) => {} Ok(Ok(())) => {}
Ok(Err(e)) => return Err(ProxyError::Io(e)), Ok(Err(e)) => return Err(ProxyError::Io(e)),
@@ -762,7 +781,6 @@ impl UpstreamManager {
)) ))
}, },
UpstreamType::Socks4 { address, interface, user_id } => { UpstreamType::Socks4 { address, interface, user_id } => {
let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS);
// Try to parse as SocketAddr first (IP:port), otherwise treat as hostname:port // Try to parse as SocketAddr first (IP:port), otherwise treat as hostname:port
let mut stream = if let Ok(proxy_addr) = address.parse::<SocketAddr>() { let mut stream = if let Ok(proxy_addr) = address.parse::<SocketAddr>() {
// IP:port format - use socket with optional interface binding // IP:port format - use socket with optional interface binding
@@ -841,7 +859,6 @@ impl UpstreamManager {
)) ))
}, },
UpstreamType::Socks5 { address, interface, username, password } => { UpstreamType::Socks5 { address, interface, username, password } => {
let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS);
// Try to parse as SocketAddr first (IP:port), otherwise treat as hostname:port // Try to parse as SocketAddr first (IP:port), otherwise treat as hostname:port
let mut stream = if let Ok(proxy_addr) = address.parse::<SocketAddr>() { let mut stream = if let Ok(proxy_addr) = address.parse::<SocketAddr>() {
// IP:port format - use socket with optional interface binding // IP:port format - use socket with optional interface binding
@@ -1165,7 +1182,14 @@ impl UpstreamManager {
target: SocketAddr, target: SocketAddr,
) -> Result<f64> { ) -> Result<f64> {
let start = Instant::now(); let start = Instant::now();
let _ = self.connect_via_upstream(config, target, bind_rr).await?; let _ = self
.connect_via_upstream(
config,
target,
bind_rr,
Duration::from_secs(DC_PING_TIMEOUT_SECS),
)
.await?;
Ok(start.elapsed().as_secs_f64() * 1000.0) Ok(start.elapsed().as_secs_f64() * 1000.0)
} }
@@ -1337,7 +1361,12 @@ impl UpstreamManager {
let start = Instant::now(); let start = Instant::now();
let result = tokio::time::timeout( let result = tokio::time::timeout(
Duration::from_secs(HEALTH_CHECK_CONNECT_TIMEOUT_SECS), Duration::from_secs(HEALTH_CHECK_CONNECT_TIMEOUT_SECS),
self.connect_via_upstream(&config, endpoint, Some(bind_rr.clone())), self.connect_via_upstream(
&config,
endpoint,
Some(bind_rr.clone()),
Duration::from_secs(HEALTH_CHECK_CONNECT_TIMEOUT_SECS),
),
) )
.await; .await;