diff --git a/src/transport/middle_proxy/pool.rs b/src/transport/middle_proxy/pool.rs index 062db67..851c0b7 100644 --- a/src/transport/middle_proxy/pool.rs +++ b/src/transport/middle_proxy/pool.rs @@ -452,7 +452,7 @@ impl MePool { }; self.writers.write().await.push(writer.clone()); self.conn_count.fetch_add(1, Ordering::Relaxed); - self.writer_available.notify_waiters(); + self.writer_available.notify_one(); let reg = self.registry.clone(); let writers_arc = self.writers_arc(); diff --git a/src/transport/middle_proxy/send.rs b/src/transport/middle_proxy/send.rs index 627906d..2ebafea 100644 --- a/src/transport/middle_proxy/send.rs +++ b/src/transport/middle_proxy/send.rs @@ -62,6 +62,8 @@ impl MePool { let mut writers_snapshot = { let ws = self.writers.read().await; if ws.is_empty() { + // Create waiter before recovery attempts so notify_one permits are not missed. + let waiter = self.writer_available.notified(); drop(ws); for family in self.family_order() { let map = match family { @@ -72,13 +74,19 @@ impl MePool { for (ip, port) in addrs { let addr = SocketAddr::new(*ip, *port); if self.connect_one(addr, self.rng.as_ref()).await.is_ok() { - self.writer_available.notify_waiters(); + self.writer_available.notify_one(); break; } } } } - if tokio::time::timeout(Duration::from_secs(3), self.writer_available.notified()).await.is_err() { + if !self.writers.read().await.is_empty() { + continue; + } + if tokio::time::timeout(Duration::from_secs(3), waiter).await.is_err() { + if !self.writers.read().await.is_empty() { + continue; + } return Err(ProxyError::Proxy("All ME connections dead (waited 3s)".into())); } continue; diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs index 0f458f2..6dcc36f 100644 --- a/src/transport/upstream.rs +++ b/src/transport/upstream.rs @@ -394,6 +394,7 @@ impl UpstreamManager { Ok(stream) }, UpstreamType::Socks4 { address, interface, user_id } => { + let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS); // Try to parse as SocketAddr first (IP:port), otherwise treat as hostname:port let mut stream = if let Ok(proxy_addr) = address.parse::() { // IP:port format - use socket with optional interface binding @@ -416,7 +417,15 @@ impl UpstreamManager { let std_stream: std::net::TcpStream = socket.into(); let stream = TcpStream::from_std(std_stream)?; - stream.writable().await?; + match tokio::time::timeout(connect_timeout, stream.writable()).await { + Ok(Ok(())) => {} + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) => { + return Err(ProxyError::ConnectionTimeout { + addr: proxy_addr.to_string(), + }); + } + } if let Some(e) = stream.take_error()? { return Err(ProxyError::Io(e)); } @@ -427,8 +436,15 @@ impl UpstreamManager { if interface.is_some() { warn!("SOCKS4 interface binding is not supported for hostname addresses, ignoring"); } - TcpStream::connect(address).await - .map_err(ProxyError::Io)? + match tokio::time::timeout(connect_timeout, TcpStream::connect(address)).await { + Ok(Ok(stream)) => stream, + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) => { + return Err(ProxyError::ConnectionTimeout { + addr: address.clone(), + }); + } + } }; // replace socks user_id with config.selected_scope, if set @@ -436,10 +452,19 @@ impl UpstreamManager { .filter(|s| !s.is_empty()); let _user_id: Option<&str> = scope.or(user_id.as_deref()); - connect_socks4(&mut stream, target, _user_id).await?; + match tokio::time::timeout(connect_timeout, connect_socks4(&mut stream, target, _user_id)).await { + Ok(Ok(())) => {} + Ok(Err(e)) => return Err(e), + Err(_) => { + return Err(ProxyError::ConnectionTimeout { + addr: target.to_string(), + }); + } + } Ok(stream) }, UpstreamType::Socks5 { address, interface, username, password } => { + let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS); // Try to parse as SocketAddr first (IP:port), otherwise treat as hostname:port let mut stream = if let Ok(proxy_addr) = address.parse::() { // IP:port format - use socket with optional interface binding @@ -462,7 +487,15 @@ impl UpstreamManager { let std_stream: std::net::TcpStream = socket.into(); let stream = TcpStream::from_std(std_stream)?; - stream.writable().await?; + match tokio::time::timeout(connect_timeout, stream.writable()).await { + Ok(Ok(())) => {} + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) => { + return Err(ProxyError::ConnectionTimeout { + addr: proxy_addr.to_string(), + }); + } + } if let Some(e) = stream.take_error()? { return Err(ProxyError::Io(e)); } @@ -473,8 +506,15 @@ impl UpstreamManager { if interface.is_some() { warn!("SOCKS5 interface binding is not supported for hostname addresses, ignoring"); } - TcpStream::connect(address).await - .map_err(ProxyError::Io)? + match tokio::time::timeout(connect_timeout, TcpStream::connect(address)).await { + Ok(Ok(stream)) => stream, + Ok(Err(e)) => return Err(ProxyError::Io(e)), + Err(_) => { + return Err(ProxyError::ConnectionTimeout { + addr: address.clone(), + }); + } + } }; debug!(config = ?config, "Socks5 connection"); @@ -484,7 +524,20 @@ impl UpstreamManager { let _username: Option<&str> = scope.or(username.as_deref()); let _password: Option<&str> = scope.or(password.as_deref()); - connect_socks5(&mut stream, target, _username, _password).await?; + match tokio::time::timeout( + connect_timeout, + connect_socks5(&mut stream, target, _username, _password), + ) + .await + { + Ok(Ok(())) => {} + Ok(Err(e)) => return Err(e), + Err(_) => { + return Err(ProxyError::ConnectionTimeout { + addr: target.to_string(), + }); + } + } Ok(stream) }, }