diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 4b94be6..9851216 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -15,6 +15,7 @@ const DEFAULT_ME_ADAPTIVE_FLOOR_RECOVER_GRACE_SECS: u64 = 180; const DEFAULT_USER_MAX_UNIQUE_IPS_WINDOW_SECS: u64 = 30; const DEFAULT_UPSTREAM_CONNECT_RETRY_ATTEMPTS: u32 = 2; const DEFAULT_UPSTREAM_UNHEALTHY_FAIL_THRESHOLD: u32 = 5; +const DEFAULT_UPSTREAM_CONNECT_BUDGET_MS: u64 = 3000; const DEFAULT_LISTEN_ADDR_IPV6: &str = "::"; const DEFAULT_ACCESS_USER: &str = "default"; const DEFAULT_ACCESS_SECRET: &str = "00000000000000000000000000000000"; @@ -113,6 +114,10 @@ pub(crate) fn default_api_minimal_runtime_cache_ttl_ms() -> u64 { 1000 } +pub(crate) fn default_proxy_protocol_header_timeout_ms() -> u64 { + 500 +} + pub(crate) fn default_prefer_4() -> u8 { 4 } @@ -253,6 +258,10 @@ pub(crate) fn default_upstream_unhealthy_fail_threshold() -> u32 { DEFAULT_UPSTREAM_UNHEALTHY_FAIL_THRESHOLD } +pub(crate) fn default_upstream_connect_budget_ms() -> u64 { + DEFAULT_UPSTREAM_CONNECT_BUDGET_MS +} + pub(crate) fn default_upstream_connect_failfast_hard_errors() -> bool { false } diff --git a/src/config/load.rs b/src/config/load.rs index dcca2a0..470bc37 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -265,6 +265,12 @@ impl ProxyConfig { )); } + if config.general.upstream_connect_budget_ms == 0 { + return Err(ProxyError::Config( + "general.upstream_connect_budget_ms must be > 0".to_string(), + )); + } + if config.general.upstream_unhealthy_fail_threshold == 0 { return Err(ProxyError::Config( "general.upstream_unhealthy_fail_threshold must be > 0".to_string(), @@ -462,6 +468,12 @@ impl ProxyConfig { )); } + if config.server.proxy_protocol_header_timeout_ms == 0 { + return Err(ProxyError::Config( + "server.proxy_protocol_header_timeout_ms must be > 0".to_string(), + )); + } + if config.general.effective_me_pool_force_close_secs() > 0 && config.general.effective_me_pool_force_close_secs() < config.general.me_pool_drain_ttl_secs diff --git a/src/config/types.rs b/src/config/types.rs index 88bf8d3..be238d3 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -532,6 +532,10 @@ pub struct GeneralConfig { #[serde(default = "default_upstream_connect_retry_backoff_ms")] pub upstream_connect_retry_backoff_ms: u64, + /// Total wall-clock budget in milliseconds for one upstream connect request across retries. + #[serde(default = "default_upstream_connect_budget_ms")] + pub upstream_connect_budget_ms: u64, + /// Consecutive failed requests before upstream is marked unhealthy. #[serde(default = "default_upstream_unhealthy_fail_threshold")] pub upstream_unhealthy_fail_threshold: u32, @@ -774,6 +778,7 @@ impl Default for GeneralConfig { me_adaptive_floor_recover_grace_secs: default_me_adaptive_floor_recover_grace_secs(), upstream_connect_retry_attempts: default_upstream_connect_retry_attempts(), upstream_connect_retry_backoff_ms: default_upstream_connect_retry_backoff_ms(), + upstream_connect_budget_ms: default_upstream_connect_budget_ms(), upstream_unhealthy_fail_threshold: default_upstream_unhealthy_fail_threshold(), upstream_connect_failfast_hard_errors: default_upstream_connect_failfast_hard_errors(), stun_iface_mismatch_ignore: false, @@ -962,6 +967,10 @@ pub struct ServerConfig { #[serde(default)] pub proxy_protocol: bool, + /// Timeout in milliseconds for reading and parsing PROXY protocol headers. + #[serde(default = "default_proxy_protocol_header_timeout_ms")] + pub proxy_protocol_header_timeout_ms: u64, + #[serde(default)] pub metrics_port: Option, @@ -985,6 +994,7 @@ impl Default for ServerConfig { listen_unix_sock_perm: None, listen_tcp: None, proxy_protocol: false, + proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(), metrics_port: None, metrics_whitelist: default_metrics_whitelist(), api: ApiConfig::default(), diff --git a/src/crypto/random.rs b/src/crypto/random.rs index 6313610..a88efc6 100644 --- a/src/crypto/random.rs +++ b/src/crypto/random.rs @@ -21,6 +21,7 @@ struct SecureRandomInner { rng: StdRng, cipher: AesCtr, buffer: Vec, + buffer_start: usize, } impl Drop for SecureRandomInner { @@ -48,6 +49,7 @@ impl SecureRandom { rng, cipher, buffer: Vec::with_capacity(1024), + buffer_start: 0, }), } } @@ -59,16 +61,29 @@ impl SecureRandom { let mut written = 0usize; while written < out.len() { + if inner.buffer_start >= inner.buffer.len() { + inner.buffer.clear(); + inner.buffer_start = 0; + } + if inner.buffer.is_empty() { let mut chunk = vec![0u8; CHUNK_SIZE]; inner.rng.fill_bytes(&mut chunk); inner.cipher.apply(&mut chunk); inner.buffer.extend_from_slice(&chunk); + inner.buffer_start = 0; } - let take = (out.len() - written).min(inner.buffer.len()); - out[written..written + take].copy_from_slice(&inner.buffer[..take]); - inner.buffer.drain(..take); + let available = inner.buffer.len().saturating_sub(inner.buffer_start); + let take = (out.len() - written).min(available); + let start = inner.buffer_start; + let end = start + take; + out[written..written + take].copy_from_slice(&inner.buffer[start..end]); + inner.buffer_start = end; + if inner.buffer_start >= inner.buffer.len() { + inner.buffer.clear(); + inner.buffer_start = 0; + } written += take; } } diff --git a/src/main.rs b/src/main.rs index 064df16..7f546d8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -464,6 +464,7 @@ async fn main() -> std::result::Result<(), Box> { config.upstreams.clone(), config.general.upstream_connect_retry_attempts, config.general.upstream_connect_retry_backoff_ms, + config.general.upstream_connect_budget_ms, config.general.upstream_unhealthy_fail_threshold, config.general.upstream_connect_failfast_hard_errors, stats.clone(), diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 2c9fa0c..ebfabcb 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -97,8 +97,11 @@ where .unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap()); if proxy_protocol_enabled { - match parse_proxy_protocol(&mut stream, peer).await { - Ok(info) => { + let proxy_header_timeout = Duration::from_millis( + config.server.proxy_protocol_header_timeout_ms.max(1), + ); + match timeout(proxy_header_timeout, parse_proxy_protocol(&mut stream, peer)).await { + Ok(Ok(info)) => { debug!( peer = %peer, client = %info.src_addr, @@ -110,12 +113,18 @@ where local_addr = dst; } } - Err(e) => { + Ok(Err(e)) => { stats.increment_connects_bad(); warn!(peer = %peer, error = %e, "Invalid PROXY protocol header"); record_beobachten_class(&beobachten, &config, peer.ip(), "other"); return Err(e); } + Err(_) => { + stats.increment_connects_bad(); + warn!(peer = %peer, timeout_ms = proxy_header_timeout.as_millis(), "PROXY protocol header timeout"); + record_beobachten_class(&beobachten, &config, peer.ip(), "other"); + return Err(ProxyError::InvalidProxyProtocol); + } } } @@ -161,7 +170,7 @@ where let (read_half, write_half) = tokio::io::split(stream); - let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( + let (mut tls_reader, tls_writer, tls_user) = match handle_tls_handshake( &handshake, read_half, write_half, real_peer, &config, &replay_checker, &rng, tls_cache.clone(), ).await { @@ -190,7 +199,7 @@ where let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( &mtproto_handshake, tls_reader, tls_writer, real_peer, - &config, &replay_checker, true, + &config, &replay_checker, true, Some(tls_user.as_str()), ).await { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader: _, writer: _ } => { @@ -234,7 +243,7 @@ where let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( &handshake, read_half, write_half, real_peer, - &config, &replay_checker, false, + &config, &replay_checker, false, None, ).await { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { @@ -415,8 +424,16 @@ impl RunningClientHandler { let mut local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; if self.proxy_protocol_enabled { - match parse_proxy_protocol(&mut self.stream, self.peer).await { - Ok(info) => { + let proxy_header_timeout = Duration::from_millis( + self.config.server.proxy_protocol_header_timeout_ms.max(1), + ); + match timeout( + proxy_header_timeout, + parse_proxy_protocol(&mut self.stream, self.peer), + ) + .await + { + Ok(Ok(info)) => { debug!( peer = %self.peer, client = %info.src_addr, @@ -428,7 +445,7 @@ impl RunningClientHandler { local_addr = dst; } } - Err(e) => { + Ok(Err(e)) => { self.stats.increment_connects_bad(); warn!(peer = %self.peer, error = %e, "Invalid PROXY protocol header"); record_beobachten_class( @@ -439,6 +456,21 @@ impl RunningClientHandler { ); return Err(e); } + Err(_) => { + self.stats.increment_connects_bad(); + warn!( + peer = %self.peer, + timeout_ms = proxy_header_timeout.as_millis(), + "PROXY protocol header timeout" + ); + record_beobachten_class( + &self.beobachten, + &self.config, + self.peer.ip(), + "other", + ); + return Err(ProxyError::InvalidProxyProtocol); + } } } @@ -494,7 +526,7 @@ impl RunningClientHandler { let (read_half, write_half) = self.stream.into_split(); - let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( + let (mut tls_reader, tls_writer, tls_user) = match handle_tls_handshake( &handshake, read_half, write_half, @@ -538,6 +570,7 @@ impl RunningClientHandler { &config, &replay_checker, true, + Some(tls_user.as_str()), ) .await { @@ -611,6 +644,7 @@ impl RunningClientHandler { &config, &replay_checker, false, + None, ) .await { diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index a1f4945..1245f34 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -34,7 +34,7 @@ where let user = &success.user; let dc_addr = get_dc_addr_static(success.dc_idx, &config)?; - info!( + debug!( user = %user, peer = %success.peer, dc = success.dc_idx, diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 5c63636..296432f 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -6,7 +6,7 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tracing::{debug, warn, trace, info}; +use tracing::{debug, warn, trace}; use zeroize::Zeroize; use crate::crypto::{sha256, AesCtr, SecureRandom}; @@ -19,6 +19,31 @@ use crate::stats::ReplayChecker; use crate::config::ProxyConfig; use crate::tls_front::{TlsFrontCache, emulator}; +fn decode_user_secrets( + config: &ProxyConfig, + preferred_user: Option<&str>, +) -> Vec<(String, Vec)> { + let mut secrets = Vec::with_capacity(config.access.users.len()); + + if let Some(preferred) = preferred_user + && let Some(secret_hex) = config.access.users.get(preferred) + && let Ok(bytes) = hex::decode(secret_hex) + { + secrets.push((preferred.to_string(), bytes)); + } + + for (name, secret_hex) in &config.access.users { + if preferred_user.is_some_and(|preferred| preferred == name.as_str()) { + continue; + } + if let Ok(bytes) = hex::decode(secret_hex) { + secrets.push((name.clone(), bytes)); + } + } + + secrets +} + /// Result of successful handshake /// /// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is @@ -82,11 +107,7 @@ where return HandshakeResult::BadClient { reader, writer }; } - let secrets: Vec<(String, Vec)> = config.access.users.iter() - .filter_map(|(name, hex)| { - hex::decode(hex).ok().map(|bytes| (name.clone(), bytes)) - }) - .collect(); + let secrets = decode_user_secrets(config, None); let validation = match tls::validate_tls_handshake( handshake, @@ -201,7 +222,7 @@ where return HandshakeResult::Error(ProxyError::Io(e)); } - info!( + debug!( peer = %peer, user = %validation.user, "TLS handshake successful" @@ -223,6 +244,7 @@ pub async fn handle_mtproto_handshake( config: &ProxyConfig, replay_checker: &ReplayChecker, is_tls: bool, + preferred_user: Option<&str>, ) -> HandshakeResult<(CryptoReader, CryptoWriter, HandshakeSuccess), R, W> where R: AsyncRead + Unpin + Send, @@ -239,11 +261,9 @@ where let enc_prekey_iv: Vec = dec_prekey_iv.iter().rev().copied().collect(); - for (user, secret_hex) in &config.access.users { - let secret = match hex::decode(secret_hex) { - Ok(s) => s, - Err(_) => continue, - }; + let decoded_users = decode_user_secrets(config, preferred_user); + + for (user, secret) in decoded_users { let dec_prekey = &dec_prekey_iv[..PREKEY_LEN]; let dec_iv_bytes = &dec_prekey_iv[PREKEY_LEN..]; @@ -311,7 +331,7 @@ where is_tls, }; - info!( + debug!( peer = %peer, user = %user, dc = dc_idx, diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 0690906..8f5fc36 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -8,7 +8,7 @@ use std::time::{Duration, Instant}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::{mpsc, oneshot}; -use tracing::{debug, info, trace, warn}; +use tracing::{debug, trace, warn}; use crate::config::ProxyConfig; use crate::crypto::SecureRandom; @@ -210,7 +210,7 @@ where let proto_tag = success.proto_tag; let pool_generation = me_pool.current_generation(); - info!( + debug!( user = %user, peer = %peer, dc = success.dc_idx, diff --git a/src/stats/mod.rs b/src/stats/mod.rs index eedc7f6..4cc9933 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -846,16 +846,30 @@ impl Stats { if !self.telemetry_user_enabled() { return; } - self.user_stats.entry(user.to_string()).or_default() - .connects.fetch_add(1, Ordering::Relaxed); + if let Some(stats) = self.user_stats.get(user) { + stats.connects.fetch_add(1, Ordering::Relaxed); + return; + } + self.user_stats + .entry(user.to_string()) + .or_default() + .connects + .fetch_add(1, Ordering::Relaxed); } pub fn increment_user_curr_connects(&self, user: &str) { if !self.telemetry_user_enabled() { return; } - self.user_stats.entry(user.to_string()).or_default() - .curr_connects.fetch_add(1, Ordering::Relaxed); + if let Some(stats) = self.user_stats.get(user) { + stats.curr_connects.fetch_add(1, Ordering::Relaxed); + return; + } + self.user_stats + .entry(user.to_string()) + .or_default() + .curr_connects + .fetch_add(1, Ordering::Relaxed); } pub fn decrement_user_curr_connects(&self, user: &str) { @@ -889,32 +903,60 @@ impl Stats { if !self.telemetry_user_enabled() { return; } - self.user_stats.entry(user.to_string()).or_default() - .octets_from_client.fetch_add(bytes, Ordering::Relaxed); + if let Some(stats) = self.user_stats.get(user) { + stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed); + return; + } + self.user_stats + .entry(user.to_string()) + .or_default() + .octets_from_client + .fetch_add(bytes, Ordering::Relaxed); } pub fn add_user_octets_to(&self, user: &str, bytes: u64) { if !self.telemetry_user_enabled() { return; } - self.user_stats.entry(user.to_string()).or_default() - .octets_to_client.fetch_add(bytes, Ordering::Relaxed); + if let Some(stats) = self.user_stats.get(user) { + stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); + return; + } + self.user_stats + .entry(user.to_string()) + .or_default() + .octets_to_client + .fetch_add(bytes, Ordering::Relaxed); } pub fn increment_user_msgs_from(&self, user: &str) { if !self.telemetry_user_enabled() { return; } - self.user_stats.entry(user.to_string()).or_default() - .msgs_from_client.fetch_add(1, Ordering::Relaxed); + if let Some(stats) = self.user_stats.get(user) { + stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); + return; + } + self.user_stats + .entry(user.to_string()) + .or_default() + .msgs_from_client + .fetch_add(1, Ordering::Relaxed); } pub fn increment_user_msgs_to(&self, user: &str) { if !self.telemetry_user_enabled() { return; } - self.user_stats.entry(user.to_string()).or_default() - .msgs_to_client.fetch_add(1, Ordering::Relaxed); + if let Some(stats) = self.user_stats.get(user) { + stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); + return; + } + self.user_stats + .entry(user.to_string()) + .or_default() + .msgs_to_client + .fetch_add(1, Ordering::Relaxed); } pub fn get_user_total_octets(&self, user: &str) -> u64 { diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 1dab2f4..22f40b5 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -119,6 +119,8 @@ pub struct MePool { pub(super) ping_tracker: Arc>>, pub(super) rtt_stats: Arc>>, pub(super) nat_reflection_cache: Arc>, + pub(super) nat_reflection_singleflight_v4: Arc>, + pub(super) nat_reflection_singleflight_v6: Arc>, pub(super) writer_available: Arc, pub(super) refill_inflight: Arc>>, pub(super) refill_inflight_dc: Arc>>, @@ -323,6 +325,8 @@ impl MePool { ping_tracker: Arc::new(Mutex::new(HashMap::new())), rtt_stats: Arc::new(Mutex::new(HashMap::new())), nat_reflection_cache: Arc::new(Mutex::new(NatReflectionCache::default())), + nat_reflection_singleflight_v4: Arc::new(Mutex::new(())), + nat_reflection_singleflight_v6: Arc::new(Mutex::new(())), writer_available: Arc::new(Notify::new()), refill_inflight: Arc::new(Mutex::new(HashSet::new())), refill_inflight_dc: Arc::new(Mutex::new(HashSet::new())), diff --git a/src/transport/middle_proxy/pool_nat.rs b/src/transport/middle_proxy/pool_nat.rs index bfcb0e2..07ae0b8 100644 --- a/src/transport/middle_proxy/pool_nat.rs +++ b/src/transport/middle_proxy/pool_nat.rs @@ -248,6 +248,43 @@ impl MePool { } } + let _singleflight_guard = if use_shared_cache { + Some(match family { + IpFamily::V4 => self.nat_reflection_singleflight_v4.lock().await, + IpFamily::V6 => self.nat_reflection_singleflight_v6.lock().await, + }) + } else { + None + }; + + if use_shared_cache + && let Some(until) = *self.stun_backoff_until.read().await + && Instant::now() < until + { + if let Ok(cache) = self.nat_reflection_cache.try_lock() { + let slot = match family { + IpFamily::V4 => cache.v4, + IpFamily::V6 => cache.v6, + }; + return slot.map(|(_, addr)| addr); + } + return None; + } + + if use_shared_cache + && let Ok(mut cache) = self.nat_reflection_cache.try_lock() + { + let slot = match family { + IpFamily::V4 => &mut cache.v4, + IpFamily::V6 => &mut cache.v6, + }; + if let Some((ts, addr)) = slot + && ts.elapsed() < STUN_CACHE_TTL + { + return Some(*addr); + } + } + let attempt = if use_shared_cache { self.nat_probe_attempts.fetch_add(1, std::sync::atomic::Ordering::Relaxed) } else { diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index e907d25..2a99164 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -124,7 +124,7 @@ pub(crate) async fn reader_loop( let data = Bytes::copy_from_slice(&body[12..]); trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS"); - let routed = reg.route(cid, MeResponse::Data { flags, data }).await; + let routed = reg.route_nowait(cid, MeResponse::Data { flags, data }).await; if !matches!(routed, RouteResult::Routed) { match routed { RouteResult::NoConn => stats.increment_me_route_drop_no_conn(), @@ -147,7 +147,7 @@ pub(crate) async fn reader_loop( let cfm = u32::from_le_bytes(body[8..12].try_into().unwrap()); trace!(cid, cfm, "RPC_SIMPLE_ACK"); - let routed = reg.route(cid, MeResponse::Ack(cfm)).await; + let routed = reg.route_nowait(cid, MeResponse::Ack(cfm)).await; if !matches!(routed, RouteResult::Routed) { match routed { RouteResult::NoConn => stats.increment_me_route_drop_no_conn(), diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index e4d0031..66a7f81 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -208,6 +208,23 @@ impl ConnRegistry { } } + pub async fn route_nowait(&self, id: u64, resp: MeResponse) -> RouteResult { + let tx = { + let inner = self.inner.read().await; + inner.map.get(&id).cloned() + }; + + let Some(tx) = tx else { + return RouteResult::NoConn; + }; + + match tx.try_send(resp) { + Ok(()) => RouteResult::Routed, + Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed, + Err(TrySendError::Full(_)) => RouteResult::QueueFullBase, + } + } + pub async fn bind_writer( &self, conn_id: u64, diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs index d9f0ede..b9db0eb 100644 --- a/src/transport/upstream.rs +++ b/src/transport/upstream.rs @@ -225,6 +225,7 @@ pub struct UpstreamManager { upstreams: Arc>>, connect_retry_attempts: u32, connect_retry_backoff: Duration, + connect_budget: Duration, unhealthy_fail_threshold: u32, connect_failfast_hard_errors: bool, stats: Arc, @@ -235,6 +236,7 @@ impl UpstreamManager { configs: Vec, connect_retry_attempts: u32, connect_retry_backoff_ms: u64, + connect_budget_ms: u64, unhealthy_fail_threshold: u32, connect_failfast_hard_errors: bool, stats: Arc, @@ -248,6 +250,7 @@ impl UpstreamManager { upstreams: Arc::new(RwLock::new(states)), connect_retry_attempts: connect_retry_attempts.max(1), connect_retry_backoff: Duration::from_millis(connect_retry_backoff_ms), + connect_budget: Duration::from_millis(connect_budget_ms.max(1)), unhealthy_fail_threshold: unhealthy_fail_threshold.max(1), connect_failfast_hard_errors, stats, @@ -593,11 +596,27 @@ impl UpstreamManager { let mut last_error: Option = None; let mut attempts_used = 0u32; for attempt in 1..=self.connect_retry_attempts { + let elapsed = connect_started_at.elapsed(); + if elapsed >= self.connect_budget { + last_error = Some(ProxyError::ConnectionTimeout { + addr: target.to_string(), + }); + break; + } + let remaining_budget = self.connect_budget.saturating_sub(elapsed); + let attempt_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS) + .min(remaining_budget); + if attempt_timeout.is_zero() { + last_error = Some(ProxyError::ConnectionTimeout { + addr: target.to_string(), + }); + break; + } attempts_used = attempt; self.stats.increment_upstream_connect_attempt_total(); let start = Instant::now(); match self - .connect_via_upstream(&upstream, target, bind_rr.clone()) + .connect_via_upstream(&upstream, target, bind_rr.clone(), attempt_timeout) .await { Ok((stream, egress)) => { @@ -707,6 +726,7 @@ impl UpstreamManager { config: &UpstreamConfig, target: SocketAddr, bind_rr: Option>, + connect_timeout: Duration, ) -> Result<(TcpStream, UpstreamEgressInfo)> { match &config.upstream_type { UpstreamType::Direct { interface, bind_addresses } => { @@ -735,7 +755,6 @@ impl UpstreamManager { let std_stream: std::net::TcpStream = socket.into(); let stream = TcpStream::from_std(std_stream)?; - let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS); match tokio::time::timeout(connect_timeout, stream.writable()).await { Ok(Ok(())) => {} Ok(Err(e)) => return Err(ProxyError::Io(e)), @@ -762,7 +781,6 @@ impl UpstreamManager { )) }, UpstreamType::Socks4 { address, interface, user_id } => { - let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS); // Try to parse as SocketAddr first (IP:port), otherwise treat as hostname:port let mut stream = if let Ok(proxy_addr) = address.parse::() { // IP:port format - use socket with optional interface binding @@ -841,7 +859,6 @@ impl UpstreamManager { )) }, UpstreamType::Socks5 { address, interface, username, password } => { - let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS); // Try to parse as SocketAddr first (IP:port), otherwise treat as hostname:port let mut stream = if let Ok(proxy_addr) = address.parse::() { // IP:port format - use socket with optional interface binding @@ -1165,7 +1182,14 @@ impl UpstreamManager { target: SocketAddr, ) -> Result { let start = Instant::now(); - let _ = self.connect_via_upstream(config, target, bind_rr).await?; + let _ = self + .connect_via_upstream( + config, + target, + bind_rr, + Duration::from_secs(DC_PING_TIMEOUT_SECS), + ) + .await?; Ok(start.elapsed().as_secs_f64() * 1000.0) } @@ -1337,7 +1361,12 @@ impl UpstreamManager { let start = Instant::now(); let result = tokio::time::timeout( Duration::from_secs(HEALTH_CHECK_CONNECT_TIMEOUT_SECS), - self.connect_via_upstream(&config, endpoint, Some(bind_rr.clone())), + self.connect_via_upstream( + &config, + endpoint, + Some(bind_rr.clone()), + Duration::from_secs(HEALTH_CHECK_CONNECT_TIMEOUT_SECS), + ), ) .await;