Upstreams for ME + Egress-data from UM + ME-over-SOCKS + Bind-aware STUN

This commit is contained in:
Alexey 2026-02-28 01:20:17 +03:00
parent ac064fe773
commit 3d9660f83e
No known key found for this signature in database
10 changed files with 307 additions and 78 deletions

View File

@ -509,6 +509,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
cfg_v6.map.clone(), cfg_v6.map.clone(),
cfg_v4.default_dc.or(cfg_v6.default_dc), cfg_v4.default_dc.or(cfg_v6.default_dc),
decision.clone(), decision.clone(),
Some(upstream_manager.clone()),
rng.clone(), rng.clone(),
stats.clone(), stats.clone(),
config.general.me_keepalive_enabled, config.general.me_keepalive_enabled,

View File

@ -41,16 +41,31 @@ pub async fn stun_probe_dual(stun_addr: &str) -> Result<DualStunResult> {
} }
pub async fn stun_probe_family(stun_addr: &str, family: IpFamily) -> Result<Option<StunProbeResult>> { pub async fn stun_probe_family(stun_addr: &str, family: IpFamily) -> Result<Option<StunProbeResult>> {
stun_probe_family_with_bind(stun_addr, family, None).await
}
pub async fn stun_probe_family_with_bind(
stun_addr: &str,
family: IpFamily,
bind_ip: Option<IpAddr>,
) -> Result<Option<StunProbeResult>> {
use rand::RngCore; use rand::RngCore;
let bind_addr = match family { let bind_addr = match (family, bind_ip) {
IpFamily::V4 => "0.0.0.0:0", (IpFamily::V4, Some(IpAddr::V4(ip))) => SocketAddr::new(IpAddr::V4(ip), 0),
IpFamily::V6 => "[::]:0", (IpFamily::V6, Some(IpAddr::V6(ip))) => SocketAddr::new(IpAddr::V6(ip), 0),
(IpFamily::V4, Some(IpAddr::V6(_))) | (IpFamily::V6, Some(IpAddr::V4(_))) => {
return Ok(None);
}
(IpFamily::V4, None) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
(IpFamily::V6, None) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
}; };
let socket = UdpSocket::bind(bind_addr) let socket = match UdpSocket::bind(bind_addr).await {
.await Ok(socket) => socket,
.map_err(|e| ProxyError::Proxy(format!("STUN bind failed: {e}")))?; Err(_) if bind_ip.is_some() => return Ok(None),
Err(e) => return Err(ProxyError::Proxy(format!("STUN bind failed: {e}"))),
};
let target_addr = resolve_stun_addr(stun_addr, family).await?; let target_addr = resolve_stun_addr(stun_addr, family).await?;
if let Some(addr) = target_addr { if let Some(addr) = target_addr {

View File

@ -17,10 +17,12 @@ use tracing::{debug, info, warn};
use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256}; use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::network::IpFamily; use crate::network::IpFamily;
use crate::network::probe::is_bogon;
use crate::protocol::constants::{ use crate::protocol::constants::{
ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32, ME_CONNECT_TIMEOUT_SECS, ME_HANDSHAKE_TIMEOUT_SECS, RPC_CRYPTO_AES_U32,
RPC_HANDSHAKE_ERROR_U32, rpc_crypto_flags, RPC_HANDSHAKE_ERROR_U32, rpc_crypto_flags,
}; };
use crate::transport::{UpstreamEgressInfo, UpstreamRouteKind};
use super::codec::{ use super::codec::{
RpcChecksumMode, build_handshake_payload, build_nonce_payload, build_rpc_frame, RpcChecksumMode, build_handshake_payload, build_nonce_payload, build_rpc_frame,
@ -43,9 +45,96 @@ pub(crate) struct HandshakeOutput {
} }
impl MePool { impl MePool {
async fn resolve_dc_idx_for_endpoint(&self, addr: SocketAddr) -> Option<i16> {
if addr.is_ipv4() {
let map = self.proxy_map_v4.read().await;
for (dc, addrs) in map.iter() {
if addrs
.iter()
.any(|(ip, port)| SocketAddr::new(*ip, *port) == addr)
{
let abs_dc = dc.abs();
if abs_dc > 0
&& let Ok(dc_idx) = i16::try_from(abs_dc)
{
return Some(dc_idx);
}
}
}
} else {
let map = self.proxy_map_v6.read().await;
for (dc, addrs) in map.iter() {
if addrs
.iter()
.any(|(ip, port)| SocketAddr::new(*ip, *port) == addr)
{
let abs_dc = dc.abs();
if abs_dc > 0
&& let Ok(dc_idx) = i16::try_from(abs_dc)
{
return Some(dc_idx);
}
}
}
}
None
}
fn direct_bind_ip_for_stun(
family: IpFamily,
upstream_egress: Option<UpstreamEgressInfo>,
) -> Option<IpAddr> {
let info = upstream_egress?;
if info.route_kind != UpstreamRouteKind::Direct {
return None;
}
match (family, info.direct_bind_ip) {
(IpFamily::V4, Some(IpAddr::V4(ip))) => Some(IpAddr::V4(ip)),
(IpFamily::V6, Some(IpAddr::V6(ip))) => Some(IpAddr::V6(ip)),
_ => None,
}
}
fn select_socks_bound_addr(
family: IpFamily,
upstream_egress: Option<UpstreamEgressInfo>,
) -> Option<SocketAddr> {
let info = upstream_egress?;
if !matches!(
info.route_kind,
UpstreamRouteKind::Socks4 | UpstreamRouteKind::Socks5
) {
return None;
}
let bound = info.socks_bound_addr?;
let family_matches = matches!(
(family, bound.ip()),
(IpFamily::V4, IpAddr::V4(_)) | (IpFamily::V6, IpAddr::V6(_))
);
if !family_matches || is_bogon(bound.ip()) || bound.ip().is_unspecified() {
return None;
}
Some(bound)
}
/// TCP connect with timeout + return RTT in milliseconds. /// TCP connect with timeout + return RTT in milliseconds.
pub(crate) async fn connect_tcp(&self, addr: SocketAddr) -> Result<(TcpStream, f64)> { pub(crate) async fn connect_tcp(
&self,
addr: SocketAddr,
) -> Result<(TcpStream, f64, Option<UpstreamEgressInfo>)> {
let start = Instant::now(); let start = Instant::now();
let (stream, upstream_egress) = if let Some(upstream) = &self.upstream {
let dc_idx = self.resolve_dc_idx_for_endpoint(addr).await;
let (stream, egress) = timeout(
Duration::from_secs(ME_CONNECT_TIMEOUT_SECS),
upstream.connect_with_details(addr, dc_idx, None),
)
.await
.map_err(|_| ProxyError::ConnectionTimeout {
addr: addr.to_string(),
})??;
(stream, Some(egress))
} else {
let connect_fut = async { let connect_fut = async {
if addr.is_ipv6() if addr.is_ipv6()
&& let Some(v6) = self.detected_ipv6 && let Some(v6) = self.detected_ipv6
@ -69,7 +158,12 @@ impl MePool {
let stream = timeout(Duration::from_secs(ME_CONNECT_TIMEOUT_SECS), connect_fut) let stream = timeout(Duration::from_secs(ME_CONNECT_TIMEOUT_SECS), connect_fut)
.await .await
.map_err(|_| ProxyError::ConnectionTimeout { addr: addr.to_string() })??; .map_err(|_| ProxyError::ConnectionTimeout {
addr: addr.to_string(),
})??;
(stream, None)
};
let connect_ms = start.elapsed().as_secs_f64() * 1000.0; let connect_ms = start.elapsed().as_secs_f64() * 1000.0;
stream.set_nodelay(true).ok(); stream.set_nodelay(true).ok();
if let Err(e) = Self::configure_keepalive(&stream) { if let Err(e) = Self::configure_keepalive(&stream) {
@ -79,7 +173,7 @@ impl MePool {
if let Err(e) = Self::configure_user_timeout(stream.as_raw_fd()) { if let Err(e) = Self::configure_user_timeout(stream.as_raw_fd()) {
warn!(error = %e, "ME TCP_USER_TIMEOUT setup failed"); warn!(error = %e, "ME TCP_USER_TIMEOUT setup failed");
} }
Ok((stream, connect_ms)) Ok((stream, connect_ms, upstream_egress))
} }
fn configure_keepalive(stream: &TcpStream) -> std::io::Result<()> { fn configure_keepalive(stream: &TcpStream) -> std::io::Result<()> {
@ -117,12 +211,14 @@ impl MePool {
&self, &self,
stream: TcpStream, stream: TcpStream,
addr: SocketAddr, addr: SocketAddr,
upstream_egress: Option<UpstreamEgressInfo>,
rng: &SecureRandom, rng: &SecureRandom,
) -> Result<HandshakeOutput> { ) -> Result<HandshakeOutput> {
let hs_start = Instant::now(); let hs_start = Instant::now();
let local_addr = stream.local_addr().map_err(ProxyError::Io)?; let local_addr = stream.local_addr().map_err(ProxyError::Io)?;
let peer_addr = stream.peer_addr().map_err(ProxyError::Io)?; let transport_peer_addr = stream.peer_addr().map_err(ProxyError::Io)?;
let peer_addr = addr;
let _ = self.maybe_detect_nat_ip(local_addr.ip()).await; let _ = self.maybe_detect_nat_ip(local_addr.ip()).await;
let family = if local_addr.ip().is_ipv4() { let family = if local_addr.ip().is_ipv4() {
@ -130,8 +226,12 @@ impl MePool {
} else { } else {
IpFamily::V6 IpFamily::V6
}; };
let reflected = if self.nat_probe { let socks_bound_addr = Self::select_socks_bound_addr(family, upstream_egress);
self.maybe_reflect_public_addr(family).await let reflected = if let Some(bound) = socks_bound_addr {
Some(bound)
} else if self.nat_probe {
let bind_ip = Self::direct_bind_ip_for_stun(family, upstream_egress);
self.maybe_reflect_public_addr(family, bind_ip).await
} else { } else {
None None
}; };
@ -197,7 +297,9 @@ impl MePool {
%local_addr_nat, %local_addr_nat,
reflected_ip = reflected.map(|r| r.ip()).as_ref().map(ToString::to_string), reflected_ip = reflected.map(|r| r.ip()).as_ref().map(ToString::to_string),
%peer_addr, %peer_addr,
%transport_peer_addr,
%peer_addr_nat, %peer_addr_nat,
socks_bound_addr = socks_bound_addr.map(|v| v.to_string()),
key_selector = format_args!("0x{ks:08x}"), key_selector = format_args!("0x{ks:08x}"),
crypto_schema = format_args!("0x{schema:08x}"), crypto_schema = format_args!("0x{schema:08x}"),
skew_secs = skew, skew_secs = skew,
@ -206,7 +308,11 @@ impl MePool {
let ts_bytes = crypto_ts.to_le_bytes(); let ts_bytes = crypto_ts.to_le_bytes();
let server_port_bytes = peer_addr_nat.port().to_le_bytes(); let server_port_bytes = peer_addr_nat.port().to_le_bytes();
let client_port_bytes = local_addr_nat.port().to_le_bytes(); let client_port_for_kdf = socks_bound_addr
.map(|bound| bound.port())
.filter(|port| *port != 0)
.unwrap_or(local_addr_nat.port());
let client_port_bytes = client_port_for_kdf.to_le_bytes();
let server_ip = extract_ip_material(peer_addr_nat); let server_ip = extract_ip_material(peer_addr_nat);
let client_ip = extract_ip_material(local_addr_nat); let client_ip = extract_ip_material(local_addr_nat);

View File

@ -122,9 +122,9 @@ pub async fn run_me_ping(pool: &Arc<MePool>, rng: &SecureRandom) -> Vec<MePingRe
let mut error = None; let mut error = None;
match pool.connect_tcp(addr).await { match pool.connect_tcp(addr).await {
Ok((stream, conn_rtt)) => { Ok((stream, conn_rtt, upstream_egress)) => {
connect_ms = Some(conn_rtt); connect_ms = Some(conn_rtt);
match pool.handshake_only(stream, addr, rng).await { match pool.handshake_only(stream, addr, upstream_egress, rng).await {
Ok(hs) => { Ok(hs) => {
handshake_ms = Some(hs.handshake_ms); handshake_ms = Some(hs.handshake_ms);
// drop halves to close // drop halves to close

View File

@ -10,6 +10,7 @@ use tokio_util::sync::CancellationToken;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::network::IpFamily; use crate::network::IpFamily;
use crate::network::probe::NetworkDecision; use crate::network::probe::NetworkDecision;
use crate::transport::UpstreamManager;
use super::ConnRegistry; use super::ConnRegistry;
use super::codec::WriterCommand; use super::codec::WriterCommand;
@ -33,6 +34,7 @@ pub struct MePool {
pub(super) writers: Arc<RwLock<Vec<MeWriter>>>, pub(super) writers: Arc<RwLock<Vec<MeWriter>>>,
pub(super) rr: AtomicU64, pub(super) rr: AtomicU64,
pub(super) decision: NetworkDecision, pub(super) decision: NetworkDecision,
pub(super) upstream: Option<Arc<UpstreamManager>>,
pub(super) rng: Arc<SecureRandom>, pub(super) rng: Arc<SecureRandom>,
pub(super) proxy_tag: Option<Vec<u8>>, pub(super) proxy_tag: Option<Vec<u8>>,
pub(super) proxy_secret: Arc<RwLock<Vec<u8>>>, pub(super) proxy_secret: Arc<RwLock<Vec<u8>>>,
@ -121,6 +123,7 @@ impl MePool {
proxy_map_v6: HashMap<i32, Vec<(IpAddr, u16)>>, proxy_map_v6: HashMap<i32, Vec<(IpAddr, u16)>>,
default_dc: Option<i32>, default_dc: Option<i32>,
decision: NetworkDecision, decision: NetworkDecision,
upstream: Option<Arc<UpstreamManager>>,
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
stats: Arc<crate::stats::Stats>, stats: Arc<crate::stats::Stats>,
me_keepalive_enabled: bool, me_keepalive_enabled: bool,
@ -148,6 +151,7 @@ impl MePool {
writers: Arc::new(RwLock::new(Vec::new())), writers: Arc::new(RwLock::new(Vec::new())),
rr: AtomicU64::new(0), rr: AtomicU64::new(0),
decision, decision,
upstream,
rng, rng,
proxy_tag, proxy_tag,
proxy_secret: Arc::new(RwLock::new(proxy_secret)), proxy_secret: Arc::new(RwLock::new(proxy_secret)),

View File

@ -8,7 +8,7 @@ use tracing::{debug, info, warn};
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
use crate::network::probe::is_bogon; use crate::network::probe::is_bogon;
use crate::network::stun::{stun_probe_dual, IpFamily, StunProbeResult}; use crate::network::stun::{stun_probe_dual, stun_probe_family_with_bind, IpFamily};
use super::MePool; use super::MePool;
use std::time::Instant; use std::time::Instant;
@ -52,6 +52,7 @@ impl MePool {
servers: &[String], servers: &[String],
family: IpFamily, family: IpFamily,
attempt: u8, attempt: u8,
bind_ip: Option<IpAddr>,
) -> (Vec<String>, Option<std::net::SocketAddr>) { ) -> (Vec<String>, Option<std::net::SocketAddr>) {
let mut join_set = JoinSet::new(); let mut join_set = JoinSet::new();
let mut next_idx = 0usize; let mut next_idx = 0usize;
@ -64,7 +65,11 @@ impl MePool {
let stun_addr = servers[next_idx].clone(); let stun_addr = servers[next_idx].clone();
next_idx += 1; next_idx += 1;
join_set.spawn(async move { join_set.spawn(async move {
let res = timeout(STUN_BATCH_TIMEOUT, stun_probe_dual(&stun_addr)).await; let res = timeout(
STUN_BATCH_TIMEOUT,
stun_probe_family_with_bind(&stun_addr, family, bind_ip),
)
.await;
(stun_addr, res) (stun_addr, res)
}); });
} }
@ -74,12 +79,7 @@ impl MePool {
}; };
match task { match task {
Ok((stun_addr, Ok(Ok(res)))) => { Ok((stun_addr, Ok(Ok(picked)))) => {
let picked: Option<StunProbeResult> = match family {
IpFamily::V4 => res.v4,
IpFamily::V6 => res.v6,
};
if let Some(result) = picked { if let Some(result) = picked {
live_servers.push(stun_addr.clone()); live_servers.push(stun_addr.clone());
let entry = best_by_ip let entry = best_by_ip
@ -207,10 +207,21 @@ impl MePool {
pub(super) async fn maybe_reflect_public_addr( pub(super) async fn maybe_reflect_public_addr(
&self, &self,
family: IpFamily, family: IpFamily,
bind_ip: Option<IpAddr>,
) -> Option<std::net::SocketAddr> { ) -> Option<std::net::SocketAddr> {
const STUN_CACHE_TTL: Duration = Duration::from_secs(600); const STUN_CACHE_TTL: Duration = Duration::from_secs(600);
let use_shared_cache = bind_ip.is_none();
if !use_shared_cache {
match (family, bind_ip) {
(IpFamily::V4, Some(IpAddr::V4(_)))
| (IpFamily::V6, Some(IpAddr::V6(_)))
| (_, None) => {}
_ => return None,
}
}
// Backoff window // Backoff window
if let Some(until) = *self.stun_backoff_until.read().await if use_shared_cache
&& let Some(until) = *self.stun_backoff_until.read().await
&& Instant::now() < until && Instant::now() < until
{ {
if let Ok(cache) = self.nat_reflection_cache.try_lock() { if let Ok(cache) = self.nat_reflection_cache.try_lock() {
@ -223,7 +234,9 @@ impl MePool {
return None; return None;
} }
if let Ok(mut cache) = self.nat_reflection_cache.try_lock() { if use_shared_cache
&& let Ok(mut cache) = self.nat_reflection_cache.try_lock()
{
let slot = match family { let slot = match family {
IpFamily::V4 => &mut cache.v4, IpFamily::V4 => &mut cache.v4,
IpFamily::V6 => &mut cache.v6, IpFamily::V6 => &mut cache.v6,
@ -235,7 +248,11 @@ impl MePool {
} }
} }
let attempt = self.nat_probe_attempts.fetch_add(1, std::sync::atomic::Ordering::Relaxed); let attempt = if use_shared_cache {
self.nat_probe_attempts.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
} else {
0
};
let configured_servers = self.configured_stun_servers(); let configured_servers = self.configured_stun_servers();
let live_snapshot = self.nat_stun_live_servers.read().await.clone(); let live_snapshot = self.nat_stun_live_servers.read().await.clone();
let primary_servers = if live_snapshot.is_empty() { let primary_servers = if live_snapshot.is_empty() {
@ -245,12 +262,12 @@ impl MePool {
}; };
let (mut live_servers, mut selected_reflected) = self let (mut live_servers, mut selected_reflected) = self
.probe_stun_batch_for_family(&primary_servers, family, attempt) .probe_stun_batch_for_family(&primary_servers, family, attempt, bind_ip)
.await; .await;
if selected_reflected.is_none() && !configured_servers.is_empty() && primary_servers != configured_servers { if selected_reflected.is_none() && !configured_servers.is_empty() && primary_servers != configured_servers {
let (rediscovered_live, rediscovered_reflected) = self let (rediscovered_live, rediscovered_reflected) = self
.probe_stun_batch_for_family(&configured_servers, family, attempt) .probe_stun_batch_for_family(&configured_servers, family, attempt, bind_ip)
.await; .await;
live_servers = rediscovered_live; live_servers = rediscovered_live;
selected_reflected = rediscovered_reflected; selected_reflected = rediscovered_reflected;
@ -264,14 +281,18 @@ impl MePool {
} }
if let Some(reflected_addr) = selected_reflected { if let Some(reflected_addr) = selected_reflected {
if use_shared_cache {
self.nat_probe_attempts.store(0, std::sync::atomic::Ordering::Relaxed); self.nat_probe_attempts.store(0, std::sync::atomic::Ordering::Relaxed);
}
info!( info!(
family = ?family, family = ?family,
live_servers = live_server_count, live_servers = live_server_count,
"STUN-Quorum reached, IP: {}", "STUN-Quorum reached, IP: {}",
reflected_addr.ip() reflected_addr.ip()
); );
if let Ok(mut cache) = self.nat_reflection_cache.try_lock() { if use_shared_cache
&& let Ok(mut cache) = self.nat_reflection_cache.try_lock()
{
let slot = match family { let slot = match family {
IpFamily::V4 => &mut cache.v4, IpFamily::V4 => &mut cache.v4,
IpFamily::V6 => &mut cache.v6, IpFamily::V6 => &mut cache.v6,
@ -281,8 +302,10 @@ impl MePool {
return Some(reflected_addr); return Some(reflected_addr);
} }
if use_shared_cache {
let backoff = Duration::from_secs(60 * 2u64.pow((attempt as u32).min(6))); let backoff = Duration::from_secs(60 * 2u64.pow((attempt as u32).min(6)));
*self.stun_backoff_until.write().await = Some(Instant::now() + backoff); *self.stun_backoff_until.write().await = Some(Instant::now() + backoff);
}
None None
} }
} }

View File

@ -47,8 +47,8 @@ impl MePool {
return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into())); return Err(ProxyError::Proxy("proxy-secret too short for ME auth".into()));
} }
let (stream, _connect_ms) = self.connect_tcp(addr).await?; let (stream, _connect_ms, upstream_egress) = self.connect_tcp(addr).await?;
let hs = self.handshake_only(stream, addr, rng).await?; let hs = self.handshake_only(stream, addr, upstream_egress, rng).await?;
let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed); let writer_id = self.next_writer_id.fetch_add(1, Ordering::Relaxed);
let generation = self.current_generation(); let generation = self.current_generation();

View File

@ -14,5 +14,5 @@ pub use socket::*;
#[allow(unused_imports)] #[allow(unused_imports)]
pub use socks::*; pub use socks::*;
#[allow(unused_imports)] #[allow(unused_imports)]
pub use upstream::{DcPingResult, StartupPingResult, UpstreamManager}; pub use upstream::{DcPingResult, StartupPingResult, UpstreamEgressInfo, UpstreamManager, UpstreamRouteKind};
pub mod middle_proxy; pub mod middle_proxy;

View File

@ -5,11 +5,16 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use crate::error::{ProxyError, Result}; use crate::error::{ProxyError, Result};
#[derive(Debug, Clone, Copy)]
pub struct SocksBoundAddr {
pub addr: SocketAddr,
}
pub async fn connect_socks4( pub async fn connect_socks4(
stream: &mut TcpStream, stream: &mut TcpStream,
target: SocketAddr, target: SocketAddr,
user_id: Option<&str>, user_id: Option<&str>,
) -> Result<()> { ) -> Result<SocksBoundAddr> {
let ip = match target.ip() { let ip = match target.ip() {
IpAddr::V4(ip) => ip, IpAddr::V4(ip) => ip,
IpAddr::V6(_) => return Err(ProxyError::Proxy("SOCKS4 does not support IPv6".to_string())), IpAddr::V6(_) => return Err(ProxyError::Proxy("SOCKS4 does not support IPv6".to_string())),
@ -37,7 +42,12 @@ pub async fn connect_socks4(
return Err(ProxyError::Proxy(format!("SOCKS4 request rejected: code {}", resp[1]))); return Err(ProxyError::Proxy(format!("SOCKS4 request rejected: code {}", resp[1])));
} }
Ok(()) let bound_port = u16::from_be_bytes([resp[2], resp[3]]);
let bound_ip = IpAddr::from([resp[4], resp[5], resp[6], resp[7]]);
Ok(SocksBoundAddr {
addr: SocketAddr::new(bound_ip, bound_port),
})
} }
pub async fn connect_socks5( pub async fn connect_socks5(
@ -45,7 +55,7 @@ pub async fn connect_socks5(
target: SocketAddr, target: SocketAddr,
username: Option<&str>, username: Option<&str>,
password: Option<&str>, password: Option<&str>,
) -> Result<()> { ) -> Result<SocksBoundAddr> {
// 1. Auth negotiation // 1. Auth negotiation
// VER (1) | NMETHODS (1) | METHODS (variable) // VER (1) | NMETHODS (1) | METHODS (variable)
let mut methods = vec![0u8]; // No auth let mut methods = vec![0u8]; // No auth
@ -122,24 +132,36 @@ pub async fn connect_socks5(
return Err(ProxyError::Proxy(format!("SOCKS5 request failed: code {}", head[1]))); return Err(ProxyError::Proxy(format!("SOCKS5 request failed: code {}", head[1])));
} }
// Skip address part of response // Parse bound address from response.
match head[3] { let bound_addr = match head[3] {
1 => { // IPv4 1 => { // IPv4
let mut addr = [0u8; 4 + 2]; let mut addr = [0u8; 4 + 2];
stream.read_exact(&mut addr).await.map_err(ProxyError::Io)?; stream.read_exact(&mut addr).await.map_err(ProxyError::Io)?;
let ip = IpAddr::from([addr[0], addr[1], addr[2], addr[3]]);
let port = u16::from_be_bytes([addr[4], addr[5]]);
SocketAddr::new(ip, port)
}, },
3 => { // Domain 3 => { // Domain
let mut len = [0u8; 1]; let mut len = [0u8; 1];
stream.read_exact(&mut len).await.map_err(ProxyError::Io)?; stream.read_exact(&mut len).await.map_err(ProxyError::Io)?;
let mut addr = vec![0u8; len[0] as usize + 2]; let mut addr = vec![0u8; len[0] as usize + 2];
stream.read_exact(&mut addr).await.map_err(ProxyError::Io)?; stream.read_exact(&mut addr).await.map_err(ProxyError::Io)?;
// Domain-bound response is not useful for KDF IP material.
let port_pos = addr.len().saturating_sub(2);
let port = u16::from_be_bytes([addr[port_pos], addr[port_pos + 1]]);
SocketAddr::new(IpAddr::from([0, 0, 0, 0]), port)
}, },
4 => { // IPv6 4 => { // IPv6
let mut addr = [0u8; 16 + 2]; let mut addr = [0u8; 16 + 2];
stream.read_exact(&mut addr).await.map_err(ProxyError::Io)?; stream.read_exact(&mut addr).await.map_err(ProxyError::Io)?;
let ip = IpAddr::from(<[u8; 16]>::try_from(&addr[..16]).map_err(|_| {
ProxyError::Proxy("Invalid SOCKS5 IPv6 bound address".to_string())
})?);
let port = u16::from_be_bytes([addr[16], addr[17]]);
SocketAddr::new(ip, port)
}, },
_ => return Err(ProxyError::Proxy("Invalid address type in SOCKS5 response".to_string())), _ => return Err(ProxyError::Proxy("Invalid address type in SOCKS5 response".to_string())),
} };
Ok(()) Ok(SocksBoundAddr { addr: bound_addr })
} }

View File

@ -151,6 +151,21 @@ pub struct StartupPingResult {
pub both_available: bool, pub both_available: bool,
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpstreamRouteKind {
Direct,
Socks4,
Socks5,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UpstreamEgressInfo {
pub route_kind: UpstreamRouteKind,
pub local_addr: Option<SocketAddr>,
pub direct_bind_ip: Option<IpAddr>,
pub socks_bound_addr: Option<SocketAddr>,
}
// ============= Upstream Manager ============= // ============= Upstream Manager =============
#[derive(Clone)] #[derive(Clone)]
@ -316,6 +331,17 @@ impl UpstreamManager {
/// Connect to target through a selected upstream. /// Connect to target through a selected upstream.
pub async fn connect(&self, target: SocketAddr, dc_idx: Option<i16>, scope: Option<&str>) -> Result<TcpStream> { pub async fn connect(&self, target: SocketAddr, dc_idx: Option<i16>, scope: Option<&str>) -> Result<TcpStream> {
let (stream, _) = self.connect_with_details(target, dc_idx, scope).await?;
Ok(stream)
}
/// Connect to target through a selected upstream and return egress details.
pub async fn connect_with_details(
&self,
target: SocketAddr,
dc_idx: Option<i16>,
scope: Option<&str>,
) -> Result<(TcpStream, UpstreamEgressInfo)> {
let idx = self.select_upstream(dc_idx, scope).await let idx = self.select_upstream(dc_idx, scope).await
.ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?; .ok_or_else(|| ProxyError::Config("No upstreams available".to_string()))?;
@ -337,7 +363,7 @@ impl UpstreamManager {
}; };
match self.connect_via_upstream(&upstream, target, bind_rr).await { match self.connect_via_upstream(&upstream, target, bind_rr).await {
Ok(stream) => { Ok((stream, egress)) => {
let rtt_ms = start.elapsed().as_secs_f64() * 1000.0; let rtt_ms = start.elapsed().as_secs_f64() * 1000.0;
let mut guard = self.upstreams.write().await; let mut guard = self.upstreams.write().await;
if let Some(u) = guard.get_mut(idx) { if let Some(u) = guard.get_mut(idx) {
@ -351,7 +377,7 @@ impl UpstreamManager {
u.dc_latency[di].update(rtt_ms); u.dc_latency[di].update(rtt_ms);
} }
} }
Ok(stream) Ok((stream, egress))
}, },
Err(e) => { Err(e) => {
let mut guard = self.upstreams.write().await; let mut guard = self.upstreams.write().await;
@ -373,7 +399,7 @@ impl UpstreamManager {
config: &UpstreamConfig, config: &UpstreamConfig,
target: SocketAddr, target: SocketAddr,
bind_rr: Option<Arc<AtomicUsize>>, bind_rr: Option<Arc<AtomicUsize>>,
) -> Result<TcpStream> { ) -> Result<(TcpStream, UpstreamEgressInfo)> {
match &config.upstream_type { match &config.upstream_type {
UpstreamType::Direct { interface, bind_addresses } => { UpstreamType::Direct { interface, bind_addresses } => {
let bind_ip = Self::resolve_bind_address( let bind_ip = Self::resolve_bind_address(
@ -414,7 +440,16 @@ impl UpstreamManager {
return Err(ProxyError::Io(e)); return Err(ProxyError::Io(e));
} }
Ok(stream) let local_addr = stream.local_addr().ok();
Ok((
stream,
UpstreamEgressInfo {
route_kind: UpstreamRouteKind::Direct,
local_addr,
direct_bind_ip: bind_ip,
socks_bound_addr: None,
},
))
}, },
UpstreamType::Socks4 { address, interface, user_id } => { UpstreamType::Socks4 { address, interface, user_id } => {
let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS); let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS);
@ -467,16 +502,30 @@ impl UpstreamManager {
.filter(|s| !s.is_empty()); .filter(|s| !s.is_empty());
let _user_id: Option<&str> = scope.or(user_id.as_deref()); let _user_id: Option<&str> = scope.or(user_id.as_deref());
match tokio::time::timeout(connect_timeout, connect_socks4(&mut stream, target, _user_id)).await { let bound = match tokio::time::timeout(
Ok(Ok(())) => {} connect_timeout,
connect_socks4(&mut stream, target, _user_id),
)
.await
{
Ok(Ok(bound)) => bound,
Ok(Err(e)) => return Err(e), Ok(Err(e)) => return Err(e),
Err(_) => { Err(_) => {
return Err(ProxyError::ConnectionTimeout { return Err(ProxyError::ConnectionTimeout {
addr: target.to_string(), addr: target.to_string(),
}); });
} }
} };
Ok(stream) let local_addr = stream.local_addr().ok();
Ok((
stream,
UpstreamEgressInfo {
route_kind: UpstreamRouteKind::Socks4,
local_addr,
direct_bind_ip: None,
socks_bound_addr: Some(bound.addr),
},
))
}, },
UpstreamType::Socks5 { address, interface, username, password } => { UpstreamType::Socks5 { address, interface, username, password } => {
let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS); let connect_timeout = Duration::from_secs(DIRECT_CONNECT_TIMEOUT_SECS);
@ -531,21 +580,30 @@ impl UpstreamManager {
let _username: Option<&str> = scope.or(username.as_deref()); let _username: Option<&str> = scope.or(username.as_deref());
let _password: Option<&str> = scope.or(password.as_deref()); let _password: Option<&str> = scope.or(password.as_deref());
match tokio::time::timeout( let bound = match tokio::time::timeout(
connect_timeout, connect_timeout,
connect_socks5(&mut stream, target, _username, _password), connect_socks5(&mut stream, target, _username, _password),
) )
.await .await
{ {
Ok(Ok(())) => {} Ok(Ok(bound)) => bound,
Ok(Err(e)) => return Err(e), Ok(Err(e)) => return Err(e),
Err(_) => { Err(_) => {
return Err(ProxyError::ConnectionTimeout { return Err(ProxyError::ConnectionTimeout {
addr: target.to_string(), addr: target.to_string(),
}); });
} }
} };
Ok(stream) let local_addr = stream.local_addr().ok();
Ok((
stream,
UpstreamEgressInfo {
route_kind: UpstreamRouteKind::Socks5,
local_addr,
direct_bind_ip: None,
socks_bound_addr: Some(bound.addr),
},
))
}, },
} }
} }
@ -777,7 +835,7 @@ impl UpstreamManager {
target: SocketAddr, target: SocketAddr,
) -> Result<f64> { ) -> Result<f64> {
let start = Instant::now(); let start = Instant::now();
let _stream = self.connect_via_upstream(config, target, bind_rr).await?; let _ = self.connect_via_upstream(config, target, bind_rr).await?;
Ok(start.elapsed().as_secs_f64() * 1000.0) Ok(start.elapsed().as_secs_f64() * 1000.0)
} }