From 073eacbb37f6b9fd256763f23d907ecec4989f40 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Fri, 17 Apr 2026 10:42:58 +0300 Subject: [PATCH] PROXY Protocol V2 UNKNOWN/LOCAL misuse fixes for TLS-Fetcher by #713 Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/tls_front/fetcher.rs | 118 +++++++++++++++++++++++++++++++++------ 1 file changed, 100 insertions(+), 18 deletions(-) diff --git a/src/tls_front/fetcher.rs b/src/tls_front/fetcher.rs index 45d56ce..2c37c34 100644 --- a/src/tls_front/fetcher.rs +++ b/src/tls_front/fetcher.rs @@ -1,6 +1,7 @@ #![allow(clippy::too_many_arguments)] use dashmap::DashMap; +use std::net::SocketAddr; use std::sync::Arc; use std::sync::OnceLock; use std::time::{Duration, Instant}; @@ -793,6 +794,45 @@ async fn connect_tcp_with_upstream( )) } +fn socket_addrs_from_upstream_stream(stream: &UpstreamStream) -> (Option, Option) { + match stream { + UpstreamStream::Tcp(tcp) => (tcp.local_addr().ok(), tcp.peer_addr().ok()), + UpstreamStream::Shadowsocks(_) => (None, None), + } +} + +fn build_tls_fetch_proxy_header( + proxy_protocol: u8, + src_addr: Option, + dst_addr: Option, +) -> Option> { + match proxy_protocol { + 0 => None, + 2 => { + let header = match (src_addr, dst_addr) { + (Some(src @ SocketAddr::V4(_)), Some(dst @ SocketAddr::V4(_))) + | (Some(src @ SocketAddr::V6(_)), Some(dst @ SocketAddr::V6(_))) => { + ProxyProtocolV2Builder::new().with_addrs(src, dst).build() + } + _ => ProxyProtocolV2Builder::new().build(), + }; + Some(header) + } + _ => { + let header = match (src_addr, dst_addr) { + (Some(SocketAddr::V4(src)), Some(SocketAddr::V4(dst))) => ProxyProtocolV1Builder::new() + .tcp4(src.into(), dst.into()) + .build(), + (Some(SocketAddr::V6(src)), Some(SocketAddr::V6(dst))) => ProxyProtocolV1Builder::new() + .tcp6(src.into(), dst.into()) + .build(), + _ => ProxyProtocolV1Builder::new().build(), + }; + Some(header) + } + } +} + fn encode_tls13_certificate_message(cert_chain_der: &[Vec]) -> Option> { if cert_chain_der.is_empty() { return None; @@ -824,7 +864,7 @@ async fn fetch_via_raw_tls_stream( mut stream: S, sni: &str, connect_timeout: Duration, - proxy_protocol: u8, + proxy_header: Option>, profile: TlsFetchProfile, grease_enabled: bool, deterministic: bool, @@ -835,11 +875,7 @@ where let rng = SecureRandom::new(); let client_hello = build_client_hello(sni, &rng, profile, grease_enabled, deterministic); timeout(connect_timeout, async { - if proxy_protocol > 0 { - let header = match proxy_protocol { - 2 => ProxyProtocolV2Builder::new().build(), - _ => ProxyProtocolV1Builder::new().build(), - }; + if let Some(header) = proxy_header.as_ref() { stream.write_all(&header).await?; } stream.write_all(&client_hello).await?; @@ -921,11 +957,12 @@ async fn fetch_via_raw_tls( sock = %sock_path, "Raw TLS fetch using mask unix socket" ); + let proxy_header = build_tls_fetch_proxy_header(proxy_protocol, None, None); return fetch_via_raw_tls_stream( stream, sni, connect_timeout, - proxy_protocol, + proxy_header, profile, grease_enabled, deterministic, @@ -956,11 +993,13 @@ async fn fetch_via_raw_tls( let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route) .await?; + let (src_addr, dst_addr) = socket_addrs_from_upstream_stream(&stream); + let proxy_header = build_tls_fetch_proxy_header(proxy_protocol, src_addr, dst_addr); fetch_via_raw_tls_stream( stream, sni, connect_timeout, - proxy_protocol, + proxy_header, profile, grease_enabled, deterministic, @@ -972,17 +1011,13 @@ async fn fetch_via_rustls_stream( mut stream: S, host: &str, sni: &str, - proxy_protocol: u8, + proxy_header: Option>, ) -> Result where S: AsyncRead + AsyncWrite + Unpin, { // rustls handshake path for certificate and basic negotiated metadata. - if proxy_protocol > 0 { - let header = match proxy_protocol { - 2 => ProxyProtocolV2Builder::new().build(), - _ => ProxyProtocolV1Builder::new().build(), - }; + if let Some(header) = proxy_header.as_ref() { stream.write_all(&header).await?; stream.flush().await?; } @@ -1082,7 +1117,8 @@ async fn fetch_via_rustls( sock = %sock_path, "Rustls fetch using mask unix socket" ); - return fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await; + let proxy_header = build_tls_fetch_proxy_header(proxy_protocol, None, None); + return fetch_via_rustls_stream(stream, host, sni, proxy_header).await; } Ok(Err(e)) => { warn!( @@ -1108,7 +1144,9 @@ async fn fetch_via_rustls( let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route) .await?; - fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await + let (src_addr, dst_addr) = socket_addrs_from_upstream_stream(&stream); + let proxy_header = build_tls_fetch_proxy_header(proxy_protocol, src_addr, dst_addr); + fetch_via_rustls_stream(stream, host, sni, proxy_header).await } /// Fetch real TLS metadata with an adaptive multi-profile strategy. @@ -1278,11 +1316,13 @@ pub async fn fetch_real_tls( #[cfg(test)] mod tests { + use std::net::SocketAddr; use std::time::{Duration, Instant}; use super::{ - ProfileCacheValue, TlsFetchStrategy, build_client_hello, derive_behavior_profile, - encode_tls13_certificate_message, order_profiles, profile_cache, profile_cache_key, + ProfileCacheValue, TlsFetchStrategy, build_client_hello, build_tls_fetch_proxy_header, + derive_behavior_profile, encode_tls13_certificate_message, order_profiles, profile_cache, + profile_cache_key, }; use crate::config::TlsFetchProfile; use crate::crypto::SecureRandom; @@ -1423,4 +1463,46 @@ mod tests { assert_eq!(first, second); } + + #[test] + fn test_build_tls_fetch_proxy_header_v2_with_tcp_addrs() { + let src: SocketAddr = "198.51.100.10:42000".parse().expect("valid src"); + let dst: SocketAddr = "203.0.113.20:443".parse().expect("valid dst"); + let header = build_tls_fetch_proxy_header(2, Some(src), Some(dst)).expect("header"); + + assert_eq!( + &header[..12], + &[0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a] + ); + assert_eq!(header[12], 0x21); + assert_eq!(header[13], 0x11); + assert_eq!(u16::from_be_bytes([header[14], header[15]]), 12); + assert_eq!(&header[16..20], &[198, 51, 100, 10]); + assert_eq!(&header[20..24], &[203, 0, 113, 20]); + assert_eq!(u16::from_be_bytes([header[24], header[25]]), 42000); + assert_eq!(u16::from_be_bytes([header[26], header[27]]), 443); + } + + #[test] + fn test_build_tls_fetch_proxy_header_v2_mixed_family_falls_back_to_local_command() { + let src: SocketAddr = "198.51.100.10:42000".parse().expect("valid src"); + let dst: SocketAddr = "[2001:db8::20]:443".parse().expect("valid dst"); + let header = build_tls_fetch_proxy_header(2, Some(src), Some(dst)).expect("header"); + + assert_eq!(header[12], 0x20); + assert_eq!(header[13], 0x00); + assert_eq!(u16::from_be_bytes([header[14], header[15]]), 0); + } + + #[test] + fn test_build_tls_fetch_proxy_header_v1_with_tcp_addrs() { + let src: SocketAddr = "198.51.100.10:42000".parse().expect("valid src"); + let dst: SocketAddr = "203.0.113.20:443".parse().expect("valid dst"); + let header = build_tls_fetch_proxy_header(1, Some(src), Some(dst)).expect("header"); + + assert_eq!( + header, + b"PROXY TCP4 198.51.100.10 203.0.113.20 42000 443\r\n" + ); + } }