Hardened KDF-Tuple + NAT Probing + Paddings

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
Alexey
2026-06-14 16:15:41 +03:00
parent b153782597
commit d414c73c9b
14 changed files with 355 additions and 192 deletions

View File

@@ -208,6 +208,8 @@ pub(crate) async fn initialize_me_pool(
me_nat_probe,
None,
config.network.stun_servers.clone(),
config.network.stun_tcp_fallback,
config.network.http_ip_detect_urls.clone(),
config.general.stun_nat_probe_concurrency,
probe.detected_ipv6,
config.timeouts.me_one_retry,

View File

@@ -12,7 +12,7 @@ use tracing::{debug, info, warn};
use crate::config::{NetworkConfig, UpstreamConfig, UpstreamType};
use crate::error::Result;
use crate::network::stun::{
DualStunResult, IpFamily, StunProbeResult, stun_probe_family_with_bind,
DualStunResult, IpFamily, StunProbeResult, stun_probe_family_with_bind_and_tcp_fallback,
};
use crate::transport::UpstreamManager;
@@ -58,6 +58,7 @@ impl NetworkDecision {
}
const STUN_BATCH_TIMEOUT: Duration = Duration::from_secs(5);
const STUN_BATCH_TCP_FALLBACK_TIMEOUT: Duration = Duration::from_secs(12);
pub async fn run_probe(
config: &NetworkConfig,
@@ -81,8 +82,14 @@ pub async fn run_probe(
warn!("STUN probe is enabled but network.stun_servers is empty");
DualStunResult::default()
} else {
probe_stun_servers_parallel(&servers, stun_nat_probe_concurrency.max(1), None, None)
.await
probe_stun_servers_parallel(
&servers,
stun_nat_probe_concurrency.max(1),
None,
None,
config.stun_tcp_fallback,
)
.await
}
} else if nat_probe {
info!("STUN probe is disabled by network.stun_use=false");
@@ -163,6 +170,7 @@ pub async fn run_probe(
stun_nat_probe_concurrency.max(1),
bind_v4,
bind_v6,
config.stun_tcp_fallback,
)
.await;
if let Some(reflected) = direct_stun_res.v4.map(|r| r.reflected_addr) {
@@ -234,7 +242,7 @@ pub async fn run_probe(
Ok(probe)
}
async fn detect_public_ipv4_http(urls: &[String]) -> Option<Ipv4Addr> {
pub(crate) async fn detect_public_ipv4_http(urls: &[String]) -> Option<Ipv4Addr> {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(3))
.build()
@@ -277,6 +285,7 @@ async fn probe_stun_servers_parallel(
concurrency: usize,
bind_v4: Option<IpAddr>,
bind_v6: Option<IpAddr>,
tcp_fallback: bool,
) -> DualStunResult {
let mut join_set = JoinSet::new();
let mut next_idx = 0usize;
@@ -288,9 +297,26 @@ async fn probe_stun_servers_parallel(
let stun_addr = servers[next_idx].clone();
next_idx += 1;
join_set.spawn(async move {
let res = timeout(STUN_BATCH_TIMEOUT, async {
let v4 = stun_probe_family_with_bind(&stun_addr, IpFamily::V4, bind_v4).await?;
let v6 = stun_probe_family_with_bind(&stun_addr, IpFamily::V6, bind_v6).await?;
let batch_timeout = if tcp_fallback {
STUN_BATCH_TCP_FALLBACK_TIMEOUT
} else {
STUN_BATCH_TIMEOUT
};
let res = timeout(batch_timeout, async {
let v4 = stun_probe_family_with_bind_and_tcp_fallback(
&stun_addr,
IpFamily::V4,
bind_v4,
tcp_fallback,
)
.await?;
let v6 = stun_probe_family_with_bind_and_tcp_fallback(
&stun_addr,
IpFamily::V6,
bind_v6,
tcp_fallback,
)
.await?;
Ok::<DualStunResult, crate::error::ProxyError>(DualStunResult { v4, v6 })
})
.await;

View File

@@ -4,7 +4,8 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::OnceLock;
use tokio::net::{UdpSocket, lookup_host};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpSocket, UdpSocket, lookup_host};
use tokio::time::{Duration, sleep, timeout};
use crate::crypto::SecureRandom;
@@ -36,9 +37,16 @@ pub struct DualStunResult {
}
pub async fn stun_probe_dual(stun_addr: &str) -> Result<DualStunResult> {
stun_probe_dual_with_tcp_fallback(stun_addr, false).await
}
pub async fn stun_probe_dual_with_tcp_fallback(
stun_addr: &str,
tcp_fallback: bool,
) -> Result<DualStunResult> {
let (v4, v6) = tokio::join!(
stun_probe_family(stun_addr, IpFamily::V4),
stun_probe_family(stun_addr, IpFamily::V6),
stun_probe_family_with_tcp_fallback(stun_addr, IpFamily::V4, tcp_fallback),
stun_probe_family_with_tcp_fallback(stun_addr, IpFamily::V6, tcp_fallback),
);
Ok(DualStunResult { v4: v4?, v6: v6? })
@@ -48,13 +56,44 @@ pub async fn stun_probe_family(
stun_addr: &str,
family: IpFamily,
) -> Result<Option<StunProbeResult>> {
stun_probe_family_with_bind(stun_addr, family, None).await
stun_probe_family_with_tcp_fallback(stun_addr, family, false).await
}
pub async fn stun_probe_family_with_tcp_fallback(
stun_addr: &str,
family: IpFamily,
tcp_fallback: bool,
) -> Result<Option<StunProbeResult>> {
stun_probe_family_with_bind_and_tcp_fallback(stun_addr, family, None, tcp_fallback).await
}
pub async fn stun_probe_family_with_bind(
stun_addr: &str,
family: IpFamily,
bind_ip: Option<IpAddr>,
) -> Result<Option<StunProbeResult>> {
stun_probe_family_with_bind_and_tcp_fallback(stun_addr, family, bind_ip, false).await
}
pub async fn stun_probe_family_with_bind_and_tcp_fallback(
stun_addr: &str,
family: IpFamily,
bind_ip: Option<IpAddr>,
tcp_fallback: bool,
) -> Result<Option<StunProbeResult>> {
let udp_attempts = if tcp_fallback { 1 } else { 3 };
let udp_result = stun_probe_family_udp(stun_addr, family, bind_ip, udp_attempts).await?;
if udp_result.is_some() || !tcp_fallback {
return Ok(udp_result);
}
stun_probe_family_tcp(stun_addr, family, bind_ip).await
}
async fn stun_probe_family_udp(
stun_addr: &str,
family: IpFamily,
bind_ip: Option<IpAddr>,
max_attempts: u8,
) -> Result<Option<StunProbeResult>> {
let bind_addr = match (family, bind_ip) {
(IpFamily::V4, Some(IpAddr::V4(ip))) => SocketAddr::new(IpAddr::V4(ip), 0),
@@ -94,12 +133,7 @@ pub async fn stun_probe_family_with_bind(
return Ok(None);
}
let mut req = [0u8; 20];
req[0..2].copy_from_slice(&0x0001u16.to_be_bytes()); // Binding Request
req[2..4].copy_from_slice(&0u16.to_be_bytes()); // length
req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes()); // magic cookie
stun_rng().fill(&mut req[8..20]); // transaction ID
let req = build_binding_request();
let mut buf = [0u8; 256];
let mut attempt = 0;
let mut backoff = Duration::from_secs(1);
@@ -115,7 +149,7 @@ pub async fn stun_probe_family_with_bind(
Ok(Err(e)) => return Err(ProxyError::Proxy(format!("STUN recv failed: {e}"))),
Err(_) => {
attempt += 1;
if attempt >= 3 {
if attempt >= max_attempts {
return Ok(None);
}
sleep(backoff).await;
@@ -128,19 +162,139 @@ pub async fn stun_probe_family_with_bind(
return Ok(None);
}
let magic = 0x2112A442u32.to_be_bytes();
let txid = &req[8..20];
let mut idx = 20;
while idx + 4 <= n {
let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().unwrap());
let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().unwrap()) as usize;
idx += 4;
if idx + alen > n {
break;
}
if let Some(reflected_addr) = parse_reflected_addr(&buf[..n], txid) {
let local_addr = socket
.local_addr()
.map_err(|e| ProxyError::Proxy(format!("STUN local_addr failed: {e}")))?;
return Ok(Some(StunProbeResult {
local_addr,
reflected_addr,
family,
}));
}
}
match atype {
0x0020 /* XOR-MAPPED-ADDRESS */ | 0x0001 /* MAPPED-ADDRESS */ => {
Ok(None)
}
async fn stun_probe_family_tcp(
stun_addr: &str,
family: IpFamily,
bind_ip: Option<IpAddr>,
) -> Result<Option<StunProbeResult>> {
let target_addr = match resolve_stun_addr(stun_addr, family).await? {
Some(addr) => addr,
None => return Ok(None),
};
let socket = match family {
IpFamily::V4 => TcpSocket::new_v4(),
IpFamily::V6 => TcpSocket::new_v6(),
}
.map_err(|e| ProxyError::Proxy(format!("STUN TCP socket failed: {e}")))?;
match (family, bind_ip) {
(IpFamily::V4, Some(IpAddr::V4(ip))) => {
if socket.bind(SocketAddr::new(IpAddr::V4(ip), 0)).is_err() {
return Ok(None);
}
}
(IpFamily::V6, Some(IpAddr::V6(ip))) => {
if socket.bind(SocketAddr::new(IpAddr::V6(ip), 0)).is_err() {
return Ok(None);
}
}
(IpFamily::V4, Some(IpAddr::V6(_))) | (IpFamily::V6, Some(IpAddr::V4(_))) => {
return Ok(None);
}
(_, None) => {}
}
let connect_res = timeout(Duration::from_secs(3), socket.connect(target_addr)).await;
let mut stream = match connect_res {
Ok(Ok(stream)) => stream,
Ok(Err(e))
if family == IpFamily::V6
&& matches!(
e.kind(),
std::io::ErrorKind::NetworkUnreachable
| std::io::ErrorKind::HostUnreachable
| std::io::ErrorKind::Unsupported
| std::io::ErrorKind::NetworkDown
) =>
{
return Ok(None);
}
Ok(Err(e)) => return Err(ProxyError::Proxy(format!("STUN TCP connect failed: {e}"))),
Err(_) => return Ok(None),
};
let req = build_binding_request();
timeout(Duration::from_secs(3), stream.write_all(&req))
.await
.map_err(|_| ProxyError::Proxy("STUN TCP send timeout".to_string()))?
.map_err(|e| ProxyError::Proxy(format!("STUN TCP send failed: {e}")))?;
let mut header = [0u8; 20];
timeout(Duration::from_secs(3), stream.read_exact(&mut header))
.await
.map_err(|_| ProxyError::Proxy("STUN TCP header timeout".to_string()))?
.map_err(|e| ProxyError::Proxy(format!("STUN TCP header read failed: {e}")))?;
let body_len = u16::from_be_bytes([header[2], header[3]]) as usize;
if body_len > 236 {
return Ok(None);
}
let mut buf = [0u8; 256];
buf[..20].copy_from_slice(&header);
if body_len > 0 {
timeout(
Duration::from_secs(3),
stream.read_exact(&mut buf[20..20 + body_len]),
)
.await
.map_err(|_| ProxyError::Proxy("STUN TCP body timeout".to_string()))?
.map_err(|e| ProxyError::Proxy(format!("STUN TCP body read failed: {e}")))?;
}
let txid = &req[8..20];
let Some(reflected_addr) = parse_reflected_addr(&buf[..20 + body_len], txid) else {
return Ok(None);
};
let local_addr = stream
.local_addr()
.map_err(|e| ProxyError::Proxy(format!("STUN TCP local_addr failed: {e}")))?;
Ok(Some(StunProbeResult {
local_addr,
reflected_addr,
family,
}))
}
fn build_binding_request() -> [u8; 20] {
let mut req = [0u8; 20];
req[0..2].copy_from_slice(&0x0001u16.to_be_bytes());
req[2..4].copy_from_slice(&0u16.to_be_bytes());
req[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes());
stun_rng().fill(&mut req[8..20]);
req
}
fn parse_reflected_addr(buf: &[u8], txid: &[u8]) -> Option<SocketAddr> {
if buf.len() < 20 {
return None;
}
let magic = 0x2112A442u32.to_be_bytes();
let mut idx = 20;
while idx + 4 <= buf.len() {
let atype = u16::from_be_bytes(buf[idx..idx + 2].try_into().ok()?);
let alen = u16::from_be_bytes(buf[idx + 2..idx + 4].try_into().ok()?) as usize;
idx += 4;
if idx + alen > buf.len() {
break;
}
match atype {
0x0020 | 0x0001 => {
if alen < 8 {
break;
}
@@ -157,7 +311,6 @@ pub async fn stun_probe_family_with_bind(
let raw_ip = &buf[idx + 4..idx + 4 + len_check];
let mut port = u16::from_be_bytes(port_bytes);
let reflected_ip = if atype == 0x0020 {
port ^= ((magic[0] as u16) << 8) | magic[1] as u16;
match family_byte {
@@ -172,7 +325,9 @@ pub async fn stun_probe_family_with_bind(
}
0x02 => {
let mut ip = [0u8; 16];
let xor_key = [magic.as_slice(), txid].concat();
let mut xor_key = [0u8; 16];
xor_key[..4].copy_from_slice(&magic);
xor_key[4..].copy_from_slice(txid.get(..12)?);
for (i, b) in raw_ip.iter().enumerate().take(16) {
ip[i] = *b ^ xor_key[i];
}
@@ -185,34 +340,24 @@ pub async fn stun_probe_family_with_bind(
}
} else {
match family_byte {
0x01 => IpAddr::V4(Ipv4Addr::new(raw_ip[0], raw_ip[1], raw_ip[2], raw_ip[3])),
0x02 => IpAddr::V6(Ipv6Addr::from(<[u8; 16]>::try_from(raw_ip).unwrap())),
0x01 => {
IpAddr::V4(Ipv4Addr::new(raw_ip[0], raw_ip[1], raw_ip[2], raw_ip[3]))
}
0x02 => IpAddr::V6(Ipv6Addr::from(<[u8; 16]>::try_from(raw_ip).ok()?)),
_ => {
idx += (alen + 3) & !3;
continue;
}
}
};
let reflected_addr = SocketAddr::new(reflected_ip, port);
let local_addr = socket
.local_addr()
.map_err(|e| ProxyError::Proxy(format!("STUN local_addr failed: {e}")))?;
return Ok(Some(StunProbeResult {
local_addr,
reflected_addr,
family,
}));
return Some(SocketAddr::new(reflected_ip, port));
}
_ => {}
}
idx += (alen + 3) & !3;
}
idx += (alen + 3) & !3;
}
Ok(None)
None
}
async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result<Option<SocketAddr>> {
@@ -245,3 +390,52 @@ async fn resolve_stun_addr(stun_addr: &str, family: IpFamily) -> Result<Option<S
});
Ok(target)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_reflected_addr_reads_mapped_ipv4() {
let txid = [0u8; 12];
let mut response = [0u8; 32];
response[0..2].copy_from_slice(&0x0101u16.to_be_bytes());
response[2..4].copy_from_slice(&12u16.to_be_bytes());
response[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes());
response[20..22].copy_from_slice(&0x0001u16.to_be_bytes());
response[22..24].copy_from_slice(&8u16.to_be_bytes());
response[25] = 0x01;
response[26..28].copy_from_slice(&443u16.to_be_bytes());
response[28..32].copy_from_slice(&[203, 0, 113, 9]);
let reflected = parse_reflected_addr(&response, &txid).unwrap();
assert_eq!(reflected, SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 9)), 443));
}
#[test]
fn parse_reflected_addr_reads_xor_mapped_ipv4() {
let txid = [0u8; 12];
let magic = 0x2112A442u32.to_be_bytes();
let port = 443u16;
let ip = [203u8, 0, 113, 9];
let xport = port ^ (((magic[0] as u16) << 8) | magic[1] as u16);
let xip = [
ip[0] ^ magic[0],
ip[1] ^ magic[1],
ip[2] ^ magic[2],
ip[3] ^ magic[3],
];
let mut response = [0u8; 32];
response[0..2].copy_from_slice(&0x0101u16.to_be_bytes());
response[2..4].copy_from_slice(&12u16.to_be_bytes());
response[4..8].copy_from_slice(&0x2112A442u32.to_be_bytes());
response[20..22].copy_from_slice(&0x0020u16.to_be_bytes());
response[22..24].copy_from_slice(&8u16.to_be_bytes());
response[25] = 0x01;
response[26..28].copy_from_slice(&xport.to_be_bytes());
response[28..32].copy_from_slice(&xip);
let reflected = parse_reflected_addr(&response, &txid).unwrap();
assert_eq!(reflected, SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 9)), 443));
}
}

View File

@@ -236,7 +236,8 @@ pub fn is_valid_secure_payload_len(data_len: usize) -> bool {
}
/// Compute Secure Intermediate payload length from wire length.
/// Secure mode strips up to 3 random tail bytes by truncating to 4-byte boundary.
/// Secure mode cannot distinguish full-word padding from payload, so only the
/// non-aligned tail bytes are stripped.
pub fn secure_payload_len_from_wire_len(wire_len: usize) -> Option<usize> {
if wire_len < 4 {
return None;
@@ -245,13 +246,13 @@ pub fn secure_payload_len_from_wire_len(wire_len: usize) -> Option<usize> {
}
/// Generate padding length for Secure Intermediate protocol.
/// Data must be 4-byte aligned; padding is 1..=3 so total is never divisible by 4.
/// Telegram Desktop uses a 4-bit random padding length for VersionD packets.
pub fn secure_padding_len(data_len: usize, rng: &SecureRandom) -> usize {
debug_assert!(
is_valid_secure_payload_len(data_len),
"Secure payload must be 4-byte aligned, got {data_len}"
);
rng.range(3) + 1
rng.range(16)
}
// ============= Timeouts =============
@@ -424,21 +425,15 @@ mod tests {
}
#[test]
fn secure_padding_never_produces_aligned_total() {
fn secure_padding_matches_tdesktop_range() {
let rng = SecureRandom::new();
for data_len in (0..1000).step_by(4) {
for _ in 0..100 {
let padding = secure_padding_len(data_len, &rng);
assert!(
padding <= 3,
padding <= 15,
"padding out of range: data_len={data_len}, padding={padding}"
);
assert_ne!(
(data_len + padding) % 4,
0,
"invariant violated: data_len={data_len}, padding={padding}, total={}",
data_len + padding
);
}
}
}
@@ -454,6 +449,16 @@ mod tests {
}
}
#[test]
fn secure_wire_len_preserves_full_word_tail() {
let payload_len = 64;
for padding in [4usize, 8, 12] {
let wire_len = payload_len + padding;
let recovered = secure_payload_len_from_wire_len(wire_len);
assert_eq!(recovered, Some(wire_len));
}
}
#[test]
fn secure_wire_len_rejects_too_short_frames() {
assert_eq!(secure_payload_len_from_wire_len(0), None);

View File

@@ -18,7 +18,7 @@ use tokio::time::timeout;
use tracing::{debug, info, warn};
use crate::config::MeSocksKdfPolicy;
use crate::crypto::{SecureRandom, build_middleproxy_prekey, derive_middleproxy_keys, sha256};
use crate::crypto::{SecureRandom, derive_middleproxy_keys};
use crate::error::{ProxyError, Result};
use crate::network::IpFamily;
use crate::network::probe::is_bogon;
@@ -292,14 +292,15 @@ impl MePool {
BndPortStatus::Error
};
record_bnd_status(bnd_addr_status, bnd_port_status, raw_socks_bound_addr);
let reflected = if let Some(bound) = socks_bound_addr {
let socks_bound_kdf_addr = socks_bound_addr.filter(|bound| bound.port() != 0);
let reflected = if let Some(bound) = socks_bound_kdf_addr {
Some(bound)
} else if is_socks_route {
match self.socks_kdf_policy() {
MeSocksKdfPolicy::Strict => {
self.stats.increment_me_socks_kdf_strict_reject();
return Err(ProxyError::InvalidHandshake(
"SOCKS route returned no valid BND.ADDR for ME KDF (strict policy)"
"SOCKS route returned no valid BND tuple for ME KDF (strict policy)"
.to_string(),
));
}
@@ -323,16 +324,14 @@ impl MePool {
let local_addr_nat = self.translate_our_addr_with_reflection(local_addr, reflected);
let peer_addr_nat =
SocketAddr::new(self.translate_ip_for_nat(peer_addr.ip()), peer_addr.port());
let client_addr_for_kdf = socks_bound_kdf_addr.unwrap_or(local_addr_nat);
if let Some(upstream_info) = upstream_egress {
let client_ip_for_kdf = socks_bound_addr
.map(|value| value.ip())
.unwrap_or(local_addr_nat.ip());
record_upstream_bnd_status(
upstream_info.upstream_id,
bnd_addr_status,
bnd_port_status,
raw_socks_bound_addr,
Some(client_ip_for_kdf),
Some(client_addr_for_kdf.ip()),
);
}
let (mut rd, mut wr) = tokio::io::split(stream);
@@ -409,6 +408,7 @@ impl MePool {
info!(
%local_addr,
%local_addr_nat,
%client_addr_for_kdf,
reflected_ip = reflected.map(|r| r.ip()).as_ref().map(ToString::to_string),
%peer_addr,
%transport_peer_addr,
@@ -422,16 +422,14 @@ impl MePool {
let ts_bytes = crypto_ts.to_le_bytes();
let server_port_bytes = peer_addr_nat.port().to_le_bytes();
let socks_bound_port = socks_bound_addr
.map(|bound| bound.port())
.filter(|port| *port != 0);
let client_port_for_kdf = socks_bound_port.unwrap_or(local_addr_nat.port());
let socks_bound_port = socks_bound_kdf_addr.map(|bound| bound.port());
let client_port_for_kdf = client_addr_for_kdf.port();
let client_port_source = KdfClientPortSource::from_socks_bound_port(socks_bound_port);
let kdf_fingerprint = Self::kdf_material_fingerprint(
local_addr_nat.ip(),
client_addr_for_kdf.ip(),
peer_addr_nat,
reflected.map(|value| value.ip()),
socks_bound_addr.map(|value| value.ip()),
socks_bound_kdf_addr.map(|value| value.ip()),
client_port_source,
);
let previous_kdf_fingerprint = {
@@ -473,7 +471,7 @@ impl MePool {
let client_port_bytes = client_port_for_kdf.to_le_bytes();
let server_ip = extract_ip_material(peer_addr_nat);
let client_ip = extract_ip_material(local_addr_nat);
let client_ip = extract_ip_material(client_addr_for_kdf);
let (srv_ip_opt, clt_ip_opt, clt_v6_opt, srv_v6_opt, hs_our_ip, hs_peer_ip) =
match (server_ip, client_ip) {
@@ -494,38 +492,6 @@ impl MePool {
}
};
let diag_level: u8 = std::env::var("ME_DIAG")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(0);
let prekey_client = build_middleproxy_prekey(
&srv_nonce,
&my_nonce,
&ts_bytes,
srv_ip_opt.as_ref().map(|x| &x[..]),
&client_port_bytes,
b"CLIENT",
clt_ip_opt.as_ref().map(|x| &x[..]),
&server_port_bytes,
&secret,
clt_v6_opt.as_ref(),
srv_v6_opt.as_ref(),
);
let prekey_server = build_middleproxy_prekey(
&srv_nonce,
&my_nonce,
&ts_bytes,
srv_ip_opt.as_ref().map(|x| &x[..]),
&client_port_bytes,
b"SERVER",
clt_ip_opt.as_ref().map(|x| &x[..]),
&server_port_bytes,
&secret,
clt_v6_opt.as_ref(),
srv_v6_opt.as_ref(),
);
let (wk, wi) = derive_middleproxy_keys(
&srv_nonce,
&my_nonce,
@@ -556,47 +522,14 @@ impl MePool {
let requested_crc_mode = RpcChecksumMode::Crc32c;
let hs_payload = build_handshake_payload(
hs_our_ip,
local_addr.port(),
client_port_for_kdf,
hs_peer_ip,
peer_addr.port(),
peer_addr_nat.port(),
requested_crc_mode.advertised_flags(),
);
let hs_frame = build_rpc_frame(-1, &hs_payload, RpcChecksumMode::Crc32);
if diag_level >= 1 {
info!(
write_key = %hex_dump(&wk),
write_iv = %hex_dump(&wi),
read_key = %hex_dump(&rk),
read_iv = %hex_dump(&ri),
srv_ip = %srv_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(),
clt_ip = %clt_ip_opt.map(|ip| hex_dump(&ip)).unwrap_or_default(),
srv_port = %hex_dump(&server_port_bytes),
clt_port = %hex_dump(&client_port_bytes),
crypto_ts = %hex_dump(&ts_bytes),
nonce_srv = %hex_dump(&srv_nonce),
nonce_clt = %hex_dump(&my_nonce),
prekey_sha256_client = %hex_dump(&sha256(&prekey_client)),
prekey_sha256_server = %hex_dump(&sha256(&prekey_server)),
hs_plain = %hex_dump(&hs_frame),
proxy_secret_sha256 = %hex_dump(&sha256(&secret)),
"ME diag: derived keys and handshake plaintext"
);
}
if diag_level >= 2 {
info!(
prekey_client = %hex_dump(&prekey_client),
prekey_server = %hex_dump(&prekey_server),
"ME diag: full prekey buffers"
);
}
let (encrypted_hs, write_iv) = cbc_encrypt_padded(&wk, &wi, &hs_frame)?;
if diag_level >= 1 {
info!(
hs_cipher = %hex_dump(&encrypted_hs),
"ME diag: handshake ciphertext"
);
}
wr.write_all(&encrypted_hs).await.map_err(ProxyError::Io)?;
wr.flush().await.map_err(ProxyError::Io)?;

View File

@@ -1728,6 +1728,8 @@ mod tests {
false,
None,
Vec::new(),
false,
Vec::new(),
1,
None,
12,

View File

@@ -336,6 +336,8 @@ pub(super) struct NatRuntimeCore {
pub(super) nat_probe: bool,
pub(super) nat_stun: Option<String>,
pub(super) nat_stun_servers: Vec<String>,
pub(super) stun_tcp_fallback: bool,
pub(super) http_ip_detect_urls: Vec<String>,
pub(super) nat_stun_live_servers: Arc<RwLock<Vec<String>>>,
pub(super) nat_probe_concurrency: usize,
pub(super) detected_ipv6: Option<Ipv6Addr>,
@@ -484,6 +486,8 @@ impl MePool {
nat_probe: bool,
nat_stun: Option<String>,
nat_stun_servers: Vec<String>,
stun_tcp_fallback: bool,
http_ip_detect_urls: Vec<String>,
nat_probe_concurrency: usize,
detected_ipv6: Option<Ipv6Addr>,
me_one_retry: u8,
@@ -706,6 +710,8 @@ impl MePool {
nat_probe,
nat_stun,
nat_stun_servers,
stun_tcp_fallback,
http_ip_detect_urls,
nat_stun_live_servers: Arc::new(RwLock::new(Vec::new())),
nat_probe_concurrency: nat_probe_concurrency.max(1),
detected_ipv6,

View File

@@ -1,19 +1,22 @@
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr};
use std::net::IpAddr;
use std::time::Duration;
use tokio::task::JoinSet;
use tokio::time::timeout;
use tracing::{debug, info, warn};
use tracing::{debug, info};
use crate::error::{ProxyError, Result};
use crate::network::probe::is_bogon;
use crate::network::stun::{IpFamily, stun_probe_dual, stun_probe_family_with_bind};
use crate::network::probe::{detect_public_ipv4_http, is_bogon};
use crate::network::stun::{
IpFamily, stun_probe_dual_with_tcp_fallback, stun_probe_family_with_bind_and_tcp_fallback,
};
use super::MePool;
use std::time::Instant;
const STUN_BATCH_TIMEOUT: Duration = Duration::from_secs(5);
const STUN_BATCH_TCP_FALLBACK_TIMEOUT: Duration = Duration::from_secs(12);
#[allow(dead_code)]
pub async fn stun_probe(stun_addr: Option<String>) -> Result<crate::network::stun::DualStunResult> {
@@ -28,15 +31,14 @@ pub async fn stun_probe(stun_addr: Option<String>) -> Result<crate::network::stu
"STUN server is not configured".to_string(),
));
}
stun_probe_dual(&stun_addr).await
stun_probe_dual_with_tcp_fallback(&stun_addr, false).await
}
#[allow(dead_code)]
pub async fn detect_public_ip() -> Option<IpAddr> {
fetch_public_ipv4_with_retry()
let urls = crate::config::defaults::default_http_ip_detect_urls();
detect_public_ipv4_http(&urls)
.await
.ok()
.flatten()
.map(IpAddr::V4)
}
@@ -65,15 +67,26 @@ impl MePool {
let mut live_servers = Vec::new();
let mut best_by_ip: HashMap<IpAddr, (usize, std::net::SocketAddr)> = HashMap::new();
let concurrency = self.nat_runtime.nat_probe_concurrency.max(1);
let tcp_fallback = self.nat_runtime.stun_tcp_fallback;
while next_idx < servers.len() || !join_set.is_empty() {
while next_idx < servers.len() && join_set.len() < concurrency {
let stun_addr = servers[next_idx].clone();
next_idx += 1;
join_set.spawn(async move {
let batch_timeout = if tcp_fallback {
STUN_BATCH_TCP_FALLBACK_TIMEOUT
} else {
STUN_BATCH_TIMEOUT
};
let res = timeout(
STUN_BATCH_TIMEOUT,
stun_probe_family_with_bind(&stun_addr, family, bind_ip),
batch_timeout,
stun_probe_family_with_bind_and_tcp_fallback(
&stun_addr,
family,
bind_ip,
tcp_fallback,
),
)
.await;
(stun_addr, res)
@@ -193,6 +206,10 @@ impl MePool {
return self.nat_runtime.nat_ip_cfg;
}
if !self.nat_runtime.nat_probe {
return None;
}
if !(is_bogon(local_ip) || local_ip.is_loopback() || local_ip.is_unspecified()) {
return None;
}
@@ -201,21 +218,15 @@ impl MePool {
return Some(ip);
}
match fetch_public_ipv4_with_retry().await {
Ok(Some(ip)) => {
{
let mut guard = self.nat_runtime.nat_ip_detected.write().await;
*guard = Some(IpAddr::V4(ip));
}
info!(public_ip = %ip, "Auto-detected public IP for NAT translation");
Some(IpAddr::V4(ip))
}
Ok(None) => None,
Err(e) => {
warn!(error = %e, "Failed to auto-detect public IP");
None
}
let Some(ip) = detect_public_ipv4_http(&self.nat_runtime.http_ip_detect_urls).await else {
return None;
};
{
let mut guard = self.nat_runtime.nat_ip_detected.write().await;
*guard = Some(IpAddr::V4(ip));
}
info!(public_ip = %ip, "Auto-detected public IP for NAT translation");
Some(IpAddr::V4(ip))
}
pub(super) async fn maybe_reflect_public_addr(
@@ -365,31 +376,3 @@ impl MePool {
None
}
}
async fn fetch_public_ipv4_with_retry() -> Result<Option<Ipv4Addr>> {
let providers = [
"https://checkip.amazonaws.com",
"http://v4.ident.me",
"http://ipv4.icanhazip.com",
];
for url in providers {
if let Ok(Some(ip)) = fetch_public_ipv4_once(url).await {
return Ok(Some(ip));
}
}
Ok(None)
}
async fn fetch_public_ipv4_once(url: &str) -> Result<Option<Ipv4Addr>> {
let res = reqwest::get(url)
.await
.map_err(|e| ProxyError::Proxy(format!("public IP detection request failed: {e}")))?;
let text = res
.text()
.await
.map_err(|e| ProxyError::Proxy(format!("public IP detection read failed: {e}")))?;
let ip = text.trim().parse().ok();
Ok(ip)
}

View File

@@ -38,6 +38,8 @@ async fn make_pool(
false,
None,
Vec::new(),
false,
Vec::new(),
1,
None,
12,

View File

@@ -36,6 +36,8 @@ async fn make_pool(
false,
None,
Vec::new(),
false,
Vec::new(),
1,
None,
12,

View File

@@ -31,6 +31,8 @@ async fn make_pool(me_pool_drain_threshold: u64) -> Arc<MePool> {
false,
None,
Vec::new(),
false,
Vec::new(),
1,
None,
12,

View File

@@ -20,6 +20,8 @@ async fn make_pool() -> Arc<MePool> {
false,
None,
Vec::new(),
false,
Vec::new(),
1,
None,
12,

View File

@@ -25,6 +25,8 @@ async fn make_pool() -> Arc<MePool> {
false,
None,
Vec::new(),
false,
Vec::new(),
1,
None,
12,

View File

@@ -31,6 +31,8 @@ async fn make_pool() -> (Arc<MePool>, Arc<SecureRandom>) {
false,
None,
Vec::new(),
false,
Vec::new(),
1,
None,
12,