From ffca94b60af3f8e90804709a037d14d6aa0ce005 Mon Sep 17 00:00:00 2001 From: ivulit Date: Sat, 28 Feb 2026 19:09:32 +0300 Subject: [PATCH] fix: pass correct dst address to outgoing PROXY protocol header Previously handle_bad_client used stream.local_addr() (the ephemeral socket to the mask backend) as the dst in the outgoing PROXY protocol header. This is wrong: the dst should be the address telemt is listening on, or the dst from the incoming PROXY protocol header if one was present. - handle_bad_client now receives local_addr from the caller - handle_client_stream resolves local_addr from PROXY protocol info.dst_addr or falls back to a synthetic address based on config.server.port - RunningClientHandler.do_handshake resolves local_addr from stream.local_addr() overridden by PROXY protocol info.dst_addr when present, and passes it down to handle_tls_client / handle_direct_client - masking.rs uses the caller-supplied local_addr directly, eliminating the stream.local_addr() call --- src/proxy/client.rs | 36 +++++++++++++++++++++++++----------- src/proxy/masking.rs | 28 +++++++++++----------------- 2 files changed, 36 insertions(+), 28 deletions(-) diff --git a/src/proxy/client.rs b/src/proxy/client.rs index d8bbc48..4bc4b65 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -91,6 +91,11 @@ where stats.increment_connects_all(); let mut real_peer = normalize_ip(peer); + // For non-TCP streams, use a synthetic local address; may be overridden by PROXY protocol dst + let mut local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port) + .parse() + .unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap()); + if proxy_protocol_enabled { match parse_proxy_protocol(&mut stream, peer).await { Ok(info) => { @@ -101,6 +106,9 @@ where "PROXY protocol header parsed" ); real_peer = normalize_ip(info.src_addr); + if let Some(dst) = info.dst_addr { + local_addr = dst; + } } Err(e) => { stats.increment_connects_bad(); @@ -119,11 +127,6 @@ where let beobachten_for_timeout = beobachten.clone(); let peer_for_timeout = real_peer.ip(); - // For non-TCP streams, use a synthetic local address - let local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port) - .parse() - .unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap()); - // Phase 1: handshake (with timeout) let outcome = match timeout(handshake_timeout, async { let mut first_bytes = [0u8; 5]; @@ -144,6 +147,7 @@ where writer, &first_bytes, real_peer, + local_addr, &config, &beobachten, ) @@ -169,6 +173,7 @@ where writer, &handshake, real_peer, + local_addr, &config, &beobachten, ) @@ -213,6 +218,7 @@ where writer, &first_bytes, real_peer, + local_addr, &config, &beobachten, ) @@ -238,6 +244,7 @@ where writer, &handshake, real_peer, + local_addr, &config, &beobachten, ) @@ -405,6 +412,8 @@ impl RunningClientHandler { } async fn do_handshake(mut self) -> Result { + let mut local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; + if self.proxy_protocol_enabled { match parse_proxy_protocol(&mut self.stream, self.peer).await { Ok(info) => { @@ -415,6 +424,9 @@ impl RunningClientHandler { "PROXY protocol header parsed" ); self.peer = normalize_ip(info.src_addr); + if let Some(dst) = info.dst_addr { + local_addr = dst; + } } Err(e) => { self.stats.increment_connects_bad(); @@ -440,13 +452,13 @@ impl RunningClientHandler { debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); if is_tls { - self.handle_tls_client(first_bytes).await + self.handle_tls_client(first_bytes, local_addr).await } else { - self.handle_direct_client(first_bytes).await + self.handle_direct_client(first_bytes, local_addr).await } } - async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result { + async fn handle_tls_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result { let peer = self.peer; let _ip_tracker = self.ip_tracker.clone(); @@ -463,6 +475,7 @@ impl RunningClientHandler { writer, &first_bytes, peer, + local_addr, &self.config, &self.beobachten, ) @@ -479,7 +492,6 @@ impl RunningClientHandler { let stats = self.stats.clone(); let buffer_pool = self.buffer_pool.clone(); - let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; let (read_half, write_half) = self.stream.into_split(); let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( @@ -502,6 +514,7 @@ impl RunningClientHandler { writer, &handshake, peer, + local_addr, &config, &self.beobachten, ) @@ -558,7 +571,7 @@ impl RunningClientHandler { ))) } - async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result { + async fn handle_direct_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result { let peer = self.peer; let _ip_tracker = self.ip_tracker.clone(); @@ -571,6 +584,7 @@ impl RunningClientHandler { writer, &first_bytes, peer, + local_addr, &self.config, &self.beobachten, ) @@ -587,7 +601,6 @@ impl RunningClientHandler { let stats = self.stats.clone(); let buffer_pool = self.buffer_pool.clone(); - let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?; let (read_half, write_half) = self.stream.into_split(); let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( @@ -609,6 +622,7 @@ impl RunningClientHandler { writer, &handshake, peer, + local_addr, &config, &self.beobachten, ) diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 8f19b40..b1e69d4 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -55,6 +55,7 @@ pub async fn handle_bad_client( writer: W, initial_data: &[u8], peer: SocketAddr, + local_addr: SocketAddr, config: &ProxyConfig, beobachten: &BeobachtenStore, ) @@ -126,23 +127,16 @@ where let proxy_header: Option> = match config.censorship.mask_proxy_protocol { 0 => None, version => { - let header = if let Ok(local_addr) = stream.local_addr() { - match version { - 2 => ProxyProtocolV2Builder::new().with_addrs(peer, local_addr).build(), - _ => match (peer, local_addr) { - (SocketAddr::V4(src), SocketAddr::V4(dst)) => - ProxyProtocolV1Builder::new().tcp4(src.into(), dst.into()).build(), - (SocketAddr::V6(src), SocketAddr::V6(dst)) => - ProxyProtocolV1Builder::new().tcp6(src.into(), dst.into()).build(), - _ => - ProxyProtocolV1Builder::new().build(), - }, - } - } else { - match version { - 2 => ProxyProtocolV2Builder::new().build(), - _ => ProxyProtocolV1Builder::new().build(), - } + let header = match version { + 2 => ProxyProtocolV2Builder::new().with_addrs(peer, local_addr).build(), + _ => match (peer, local_addr) { + (SocketAddr::V4(src), SocketAddr::V4(dst)) => + ProxyProtocolV1Builder::new().tcp4(src.into(), dst.into()).build(), + (SocketAddr::V6(src), SocketAddr::V6(dst)) => + ProxyProtocolV1Builder::new().tcp6(src.into(), dst.into()).build(), + _ => + ProxyProtocolV1Builder::new().build(), + }, }; Some(header) }