diff --git a/docs/FAQ.en.md b/docs/FAQ.en.md index 4af1c34..d61ee6e 100644 --- a/docs/FAQ.en.md +++ b/docs/FAQ.en.md @@ -122,11 +122,11 @@ enabled = true ``` #### Shadowsocks as Upstream -Requires `use_middle_proxy = false`. +Works with both `use_middle_proxy = false` and `use_middle_proxy = true`. ```toml [general] -use_middle_proxy = false +use_middle_proxy = true [[upstreams]] type = "shadowsocks" diff --git a/docs/TUNING.en.md b/docs/TUNING.en.md index 6a6a320..a62925d 100644 --- a/docs/TUNING.en.md +++ b/docs/TUNING.en.md @@ -117,7 +117,7 @@ Defaults below are code defaults (used when a key is omitted), not necessarily v 8. In ME mode, the selected upstream is also used for ME TCP dial path. 9. In ME mode for `direct` upstream with bind/interface, STUN reflection logic is bind-aware for KDF source material. 10. In ME mode for SOCKS upstream, SOCKS `BND.ADDR/BND.PORT` is used for KDF when it is valid/public for the same family. -11. `shadowsocks` upstreams require `general.use_middle_proxy = false`. Config load fails fast if ME mode is enabled. +11. `shadowsocks` upstreams work in both Direct and ME modes. In ME mode, the connected local Shadowsocks address is reused for bind-aware STUN reflection when available. ## Upstream Configuration Examples @@ -157,7 +157,7 @@ enabled = true ```toml [general] -use_middle_proxy = false +use_middle_proxy = true [[upstreams]] type = "shadowsocks" diff --git a/src/config/load.rs b/src/config/load.rs index c797637..590489a 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -129,16 +129,6 @@ fn sanitize_ad_tag(ad_tag: &mut Option) { } fn validate_upstreams(config: &ProxyConfig) -> Result<()> { - let has_enabled_shadowsocks = config.upstreams.iter().any(|upstream| { - upstream.enabled && matches!(upstream.upstream_type, UpstreamType::Shadowsocks { .. }) - }); - - if has_enabled_shadowsocks && config.general.use_middle_proxy { - return Err(ProxyError::Config( - "shadowsocks upstreams require general.use_middle_proxy = false".to_string(), - )); - } - for upstream in &config.upstreams { if let UpstreamType::Shadowsocks { url, .. } = &upstream.upstream_type { let parsed = ShadowsocksServerConfig::from_url(url) @@ -2275,7 +2265,7 @@ mod tests { } #[test] - fn shadowsocks_requires_direct_mode() { + fn shadowsocks_is_allowed_with_middle_proxy() { let toml = format!( r#" [general] @@ -2294,11 +2284,11 @@ mod tests { url = TEST_SHADOWSOCKS_URL, ); let dir = std::env::temp_dir(); - let path = dir.join("telemt_shadowsocks_me_reject_test.toml"); + let path = dir.join("telemt_shadowsocks_me_allow_test.toml"); std::fs::write(&path, toml).unwrap(); - let err = ProxyConfig::load(&path).unwrap_err().to_string(); + let loaded = ProxyConfig::load(&path); - assert!(err.contains("shadowsocks upstreams require general.use_middle_proxy = false")); + assert!(loaded.is_ok()); let _ = std::fs::remove_file(path); } diff --git a/src/transport/middle_proxy/codec.rs b/src/transport/middle_proxy/codec.rs index 7f51aaa..66abe34 100644 --- a/src/transport/middle_proxy/codec.rs +++ b/src/transport/middle_proxy/codec.rs @@ -4,8 +4,9 @@ use bytes::Bytes; use crate::crypto::{AesCbc, crc32, crc32c}; use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; +use crate::transport::UpstreamStream; -/// Commands sent to dedicated writer tasks to avoid mutex contention on TCP writes. +/// Commands sent to dedicated writer tasks to avoid mutex contention on upstream writes. pub(crate) enum WriterCommand { Data(Bytes), DataAndFlush(Bytes), @@ -213,7 +214,7 @@ pub(crate) fn cbc_decrypt_inplace( } pub(crate) struct RpcWriter { - pub(crate) writer: tokio::io::WriteHalf, + pub(crate) writer: tokio::io::WriteHalf, pub(crate) key: [u8; 32], pub(crate) iv: [u8; 16], pub(crate) seq_no: i32, diff --git a/src/transport/middle_proxy/handshake.rs b/src/transport/middle_proxy/handshake.rs index 39e34d7..4a40f8d 100644 --- a/src/transport/middle_proxy/handshake.rs +++ b/src/transport/middle_proxy/handshake.rs @@ -1,13 +1,15 @@ +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; use std::net::{IpAddr, SocketAddr}; use std::sync::atomic::Ordering; use std::time::{Duration, Instant}; -use std::collections::hash_map::DefaultHasher; -use std::hash::{Hash, Hasher}; use socket2::{SockRef, TcpKeepalive}; #[cfg(target_os = "linux")] use libc; +#[cfg(unix)] +use std::os::fd::BorrowedFd; #[cfg(target_os = "linux")] -use std::os::fd::{AsRawFd, RawFd}; +use std::os::fd::RawFd; #[cfg(target_os = "linux")] use std::os::raw::c_int; @@ -26,7 +28,7 @@ use crate::protocol::constants::{ ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32, RPC_HANDSHAKE_ERROR_U32, rpc_crypto_flags, }; -use crate::transport::{UpstreamEgressInfo, UpstreamRouteKind}; +use crate::transport::{UpstreamEgressInfo, UpstreamRouteKind, UpstreamStream}; use super::codec::{ RpcChecksumMode, build_handshake_payload, build_nonce_payload, build_rpc_frame, @@ -57,8 +59,8 @@ impl KdfClientPortSource { /// Result of a successful ME handshake with timings. pub(crate) struct HandshakeOutput { - pub rd: ReadHalf, - pub wr: WriteHalf, + pub rd: ReadHalf, + pub wr: WriteHalf, pub source_ip: IpAddr, pub read_key: [u8; 32], pub read_iv: [u8; 16], @@ -89,15 +91,19 @@ impl MePool { i16::try_from(self.resolve_dc_for_endpoint(addr).await).ok() } - fn direct_bind_ip_for_stun( + fn non_socks_bind_ip_for_stun( family: IpFamily, upstream_egress: Option, ) -> Option { let info = upstream_egress?; - if info.route_kind != UpstreamRouteKind::Direct { - return None; - } - match (family, info.direct_bind_ip) { + let bind_ip = match info.route_kind { + UpstreamRouteKind::Direct => info + .direct_bind_ip + .or_else(|| info.local_addr.map(|addr| addr.ip())), + UpstreamRouteKind::Shadowsocks => info.local_addr.map(|addr| addr.ip()), + UpstreamRouteKind::Socks4 | UpstreamRouteKind::Socks5 => None, + }; + match (family, bind_ip) { (IpFamily::V4, Some(IpAddr::V4(ip))) => Some(IpAddr::V4(ip)), (IpFamily::V6, Some(IpAddr::V6(ip))) => Some(IpAddr::V6(ip)), _ => None, @@ -141,12 +147,12 @@ impl MePool { } } - /// TCP connect with timeout + return RTT in milliseconds. + /// Connect to a middle-proxy endpoint and return RTT in milliseconds. pub(crate) async fn connect_tcp( &self, addr: SocketAddr, dc_idx_override: Option, - ) -> Result<(TcpStream, f64, Option)> { + ) -> Result<(UpstreamStream, f64, Option)> { let start = Instant::now(); let (stream, upstream_egress) = if let Some(upstream) = &self.upstream { let dc_idx = if let Some(dc_idx) = dc_idx_override { @@ -154,7 +160,9 @@ impl MePool { } else { self.resolve_dc_idx_for_endpoint(addr).await }; - let (stream, egress) = upstream.connect_with_details(addr, dc_idx, None).await?; + let (stream, egress) = upstream + .connect_stream_with_details(addr, dc_idx, None) + .await?; (stream, Some(egress)) } else { let connect_fut = async { @@ -183,7 +191,7 @@ impl MePool { .map_err(|_| ProxyError::ConnectionTimeout { addr: addr.to_string(), })??; - (stream, None) + (UpstreamStream::Tcp(stream), None) }; let connect_ms = start.elapsed().as_secs_f64() * 1000.0; @@ -192,14 +200,22 @@ impl MePool { warn!(error = %e, "ME keepalive setup failed"); } #[cfg(target_os = "linux")] - if let Err(e) = Self::configure_user_timeout(stream.as_raw_fd()) { + if let Err(e) = Self::configure_user_timeout(stream.raw_fd()) { warn!(error = %e, "ME TCP_USER_TIMEOUT setup failed"); } Ok((stream, connect_ms, upstream_egress)) } - fn configure_keepalive(stream: &TcpStream) -> std::io::Result<()> { - let sock = SockRef::from(stream); + fn configure_keepalive(stream: &UpstreamStream) -> std::io::Result<()> { + #[cfg(unix)] + let borrowed = unsafe { BorrowedFd::borrow_raw(stream.raw_fd()) }; + #[cfg(unix)] + let sock = SockRef::from(&borrowed); + #[cfg(not(unix))] + let sock = match stream { + UpstreamStream::Tcp(stream) => SockRef::from(stream), + UpstreamStream::Shadowsocks(_) => return Ok(()), + }; let ka = TcpKeepalive::new().with_time(Duration::from_secs(30)); // Mirror socket2 v0.5.10 target gate for with_retries(), the stricter method. @@ -243,11 +259,11 @@ impl MePool { Ok(()) } - /// Perform full ME RPC handshake on an established TCP stream. + /// Perform full ME RPC handshake on an established upstream stream. /// Returns cipher keys/ivs and split halves; does not register writer. pub(crate) async fn handshake_only( &self, - stream: TcpStream, + stream: UpstreamStream, addr: SocketAddr, upstream_egress: Option, rng: &SecureRandom, @@ -300,7 +316,7 @@ impl MePool { MeSocksKdfPolicy::Compat => { self.stats.increment_me_socks_kdf_compat_fallback(); if self.nat_probe { - let bind_ip = Self::direct_bind_ip_for_stun(family, upstream_egress); + let bind_ip = Self::non_socks_bind_ip_for_stun(family, upstream_egress); self.maybe_reflect_public_addr(family, bind_ip).await } else { None @@ -308,7 +324,7 @@ impl MePool { } } } else if self.nat_probe { - let bind_ip = Self::direct_bind_ip_for_stun(family, upstream_egress); + let bind_ip = Self::non_socks_bind_ip_for_stun(family, upstream_egress); self.maybe_reflect_public_addr(family, bind_ip).await } else { None @@ -722,6 +738,8 @@ mod tests { use std::io::ErrorKind; use tokio::net::{TcpListener, TcpStream}; + use crate::transport::{UpstreamEgressInfo, UpstreamRouteKind, UpstreamStream}; + #[tokio::test] async fn test_configure_keepalive_loopback() { let listener = match TcpListener::bind("127.0.0.1:0").await { @@ -740,6 +758,7 @@ mod tests { Err(error) if error.kind() == ErrorKind::PermissionDenied => return, Err(error) => panic!("connect failed: {error}"), }; + let stream = UpstreamStream::Tcp(stream); if let Err(error) = MePool::configure_keepalive(&stream) { if error.kind() == ErrorKind::PermissionDenied { @@ -777,4 +796,56 @@ mod tests { .with_interval(Duration::from_secs(10)) .with_retries(3); } + + #[test] + fn direct_route_prefers_explicit_bind_ip_for_stun() { + let bind_ip: IpAddr = "198.51.100.10".parse().unwrap(); + let local_addr: SocketAddr = "198.51.100.20:40000".parse().unwrap(); + let egress = UpstreamEgressInfo { + upstream_id: 0, + route_kind: UpstreamRouteKind::Direct, + local_addr: Some(local_addr), + direct_bind_ip: Some(bind_ip), + socks_bound_addr: None, + socks_proxy_addr: None, + }; + + let selected = MePool::non_socks_bind_ip_for_stun(IpFamily::V4, Some(egress)); + + assert_eq!(selected, Some(bind_ip)); + } + + #[test] + fn shadowsocks_route_uses_local_addr_for_stun() { + let local_addr: SocketAddr = "198.51.100.30:40001".parse().unwrap(); + let egress = UpstreamEgressInfo { + upstream_id: 1, + route_kind: UpstreamRouteKind::Shadowsocks, + local_addr: Some(local_addr), + direct_bind_ip: None, + socks_bound_addr: None, + socks_proxy_addr: None, + }; + + let selected = MePool::non_socks_bind_ip_for_stun(IpFamily::V4, Some(egress)); + + assert_eq!(selected, Some(local_addr.ip())); + } + + #[test] + fn socks_route_keeps_compat_fallback_unbound() { + let local_addr: SocketAddr = "198.51.100.40:40002".parse().unwrap(); + let egress = UpstreamEgressInfo { + upstream_id: 2, + route_kind: UpstreamRouteKind::Socks5, + local_addr: Some(local_addr), + direct_bind_ip: None, + socks_bound_addr: None, + socks_proxy_addr: Some("198.51.100.50:1080".parse().unwrap()), + }; + + let selected = MePool::non_socks_bind_ip_for_stun(IpFamily::V4, Some(egress)); + + assert_eq!(selected, None); + } } diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index 8b15fc1..ed4940f 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -6,7 +6,6 @@ use std::time::Instant; use bytes::{Bytes, BytesMut}; use tokio::io::AsyncReadExt; -use tokio::net::TcpStream; use tokio::sync::{Mutex, mpsc}; use tokio::sync::mpsc::error::TrySendError; use tokio_util::sync::CancellationToken; @@ -16,13 +15,14 @@ use crate::crypto::AesCbc; use crate::error::{ProxyError, Result}; use crate::protocol::constants::*; use crate::stats::Stats; +use crate::transport::UpstreamStream; use super::codec::{RpcChecksumMode, WriterCommand, rpc_crc}; use super::registry::RouteResult; use super::{ConnRegistry, MeResponse}; pub(crate) async fn reader_loop( - mut rd: tokio::io::ReadHalf, + mut rd: tokio::io::ReadHalf, dk: [u8; 32], mut div: [u8; 16], crc_mode: RpcChecksumMode, diff --git a/src/transport/upstream.rs b/src/transport/upstream.rs index b0d82b1..666ca7d 100644 --- a/src/transport/upstream.rs +++ b/src/transport/upstream.rs @@ -12,6 +12,8 @@ use std::sync::Arc; use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::task::{Context, Poll}; use std::time::Duration; +#[cfg(unix)] +use std::os::fd::{AsRawFd, RawFd}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::TcpStream; use tokio::sync::RwLock; @@ -173,6 +175,7 @@ pub struct StartupPingResult { pub both_available: bool, } +/// Transport stream returned by an upstream connection attempt. pub enum UpstreamStream { Tcp(TcpStream), Shadowsocks(Box), @@ -188,6 +191,35 @@ impl std::fmt::Debug for UpstreamStream { } impl UpstreamStream { + pub(crate) fn local_addr(&self) -> std::io::Result { + match self { + Self::Tcp(stream) => stream.local_addr(), + Self::Shadowsocks(stream) => stream.get_ref().local_addr(), + } + } + + pub(crate) fn peer_addr(&self) -> std::io::Result { + match self { + Self::Tcp(stream) => stream.peer_addr(), + Self::Shadowsocks(stream) => stream.get_ref().peer_addr(), + } + } + + pub(crate) fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> { + match self { + Self::Tcp(stream) => stream.set_nodelay(nodelay), + Self::Shadowsocks(stream) => stream.get_ref().set_nodelay(nodelay), + } + } + + #[cfg(unix)] + pub(crate) fn raw_fd(&self) -> RawFd { + match self { + Self::Tcp(stream) => stream.as_raw_fd(), + Self::Shadowsocks(stream) => stream.get_ref().as_raw_fd(), + } + } + pub fn into_tcp(self) -> Result { match self { Self::Tcp(stream) => Ok(stream), @@ -744,13 +776,13 @@ impl UpstreamManager { Ok(stream) } - /// Connect to target through a selected upstream and return egress details. - pub async fn connect_with_details( + /// Connect to target through a selected upstream and return transport egress details. + pub(crate) async fn connect_stream_with_details( &self, target: SocketAddr, dc_idx: Option, scope: Option<&str>, - ) -> Result<(TcpStream, UpstreamEgressInfo)> { + ) -> Result<(UpstreamStream, UpstreamEgressInfo)> { let idx = self .select_upstream(dc_idx, scope) .await @@ -774,6 +806,19 @@ impl UpstreamManager { let (stream, egress) = self .connect_selected_upstream(idx, upstream, target, dc_idx, bind_rr) .await?; + Ok((stream, egress)) + } + + /// Connect to target through a selected upstream and return egress details. + pub async fn connect_with_details( + &self, + target: SocketAddr, + dc_idx: Option, + scope: Option<&str>, + ) -> Result<(TcpStream, UpstreamEgressInfo)> { + let (stream, egress) = self + .connect_stream_with_details(target, dc_idx, scope) + .await?; Ok((stream.into_tcp()?, egress)) }