fix: pass correct dst address to outgoing PROXY protocol header

Previously handle_bad_client used stream.local_addr() (the ephemeral
socket to the mask backend) as the dst in the outgoing PROXY protocol
header. This is wrong: the dst should be the address telemt is listening
on, or the dst from the incoming PROXY protocol header if one was present.

- handle_bad_client now receives local_addr from the caller
- handle_client_stream resolves local_addr from PROXY protocol info.dst_addr
  or falls back to a synthetic address based on config.server.port
- RunningClientHandler.do_handshake resolves local_addr from stream.local_addr()
  overridden by PROXY protocol info.dst_addr when present, and passes it
  down to handle_tls_client / handle_direct_client
- masking.rs uses the caller-supplied local_addr directly, eliminating the
  stream.local_addr() call
This commit is contained in:
ivulit 2026-02-28 19:09:32 +03:00
parent c455869ef5
commit ffca94b60a
No known key found for this signature in database
2 changed files with 36 additions and 28 deletions

View File

@ -91,6 +91,11 @@ where
stats.increment_connects_all(); stats.increment_connects_all();
let mut real_peer = normalize_ip(peer); let mut real_peer = normalize_ip(peer);
// For non-TCP streams, use a synthetic local address; may be overridden by PROXY protocol dst
let mut local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
.parse()
.unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
if proxy_protocol_enabled { if proxy_protocol_enabled {
match parse_proxy_protocol(&mut stream, peer).await { match parse_proxy_protocol(&mut stream, peer).await {
Ok(info) => { Ok(info) => {
@ -101,6 +106,9 @@ where
"PROXY protocol header parsed" "PROXY protocol header parsed"
); );
real_peer = normalize_ip(info.src_addr); real_peer = normalize_ip(info.src_addr);
if let Some(dst) = info.dst_addr {
local_addr = dst;
}
} }
Err(e) => { Err(e) => {
stats.increment_connects_bad(); stats.increment_connects_bad();
@ -119,11 +127,6 @@ where
let beobachten_for_timeout = beobachten.clone(); let beobachten_for_timeout = beobachten.clone();
let peer_for_timeout = real_peer.ip(); let peer_for_timeout = real_peer.ip();
// For non-TCP streams, use a synthetic local address
let local_addr: SocketAddr = format!("0.0.0.0:{}", config.server.port)
.parse()
.unwrap_or_else(|_| "0.0.0.0:443".parse().unwrap());
// Phase 1: handshake (with timeout) // Phase 1: handshake (with timeout)
let outcome = match timeout(handshake_timeout, async { let outcome = match timeout(handshake_timeout, async {
let mut first_bytes = [0u8; 5]; let mut first_bytes = [0u8; 5];
@ -144,6 +147,7 @@ where
writer, writer,
&first_bytes, &first_bytes,
real_peer, real_peer,
local_addr,
&config, &config,
&beobachten, &beobachten,
) )
@ -169,6 +173,7 @@ where
writer, writer,
&handshake, &handshake,
real_peer, real_peer,
local_addr,
&config, &config,
&beobachten, &beobachten,
) )
@ -213,6 +218,7 @@ where
writer, writer,
&first_bytes, &first_bytes,
real_peer, real_peer,
local_addr,
&config, &config,
&beobachten, &beobachten,
) )
@ -238,6 +244,7 @@ where
writer, writer,
&handshake, &handshake,
real_peer, real_peer,
local_addr,
&config, &config,
&beobachten, &beobachten,
) )
@ -405,6 +412,8 @@ impl RunningClientHandler {
} }
async fn do_handshake(mut self) -> Result<HandshakeOutcome> { async fn do_handshake(mut self) -> Result<HandshakeOutcome> {
let mut local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
if self.proxy_protocol_enabled { if self.proxy_protocol_enabled {
match parse_proxy_protocol(&mut self.stream, self.peer).await { match parse_proxy_protocol(&mut self.stream, self.peer).await {
Ok(info) => { Ok(info) => {
@ -415,6 +424,9 @@ impl RunningClientHandler {
"PROXY protocol header parsed" "PROXY protocol header parsed"
); );
self.peer = normalize_ip(info.src_addr); self.peer = normalize_ip(info.src_addr);
if let Some(dst) = info.dst_addr {
local_addr = dst;
}
} }
Err(e) => { Err(e) => {
self.stats.increment_connects_bad(); self.stats.increment_connects_bad();
@ -440,13 +452,13 @@ impl RunningClientHandler {
debug!(peer = %peer, is_tls = is_tls, "Handshake type detected"); debug!(peer = %peer, is_tls = is_tls, "Handshake type detected");
if is_tls { if is_tls {
self.handle_tls_client(first_bytes).await self.handle_tls_client(first_bytes, local_addr).await
} else { } else {
self.handle_direct_client(first_bytes).await self.handle_direct_client(first_bytes, local_addr).await
} }
} }
async fn handle_tls_client(mut self, first_bytes: [u8; 5]) -> Result<HandshakeOutcome> { async fn handle_tls_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result<HandshakeOutcome> {
let peer = self.peer; let peer = self.peer;
let _ip_tracker = self.ip_tracker.clone(); let _ip_tracker = self.ip_tracker.clone();
@ -463,6 +475,7 @@ impl RunningClientHandler {
writer, writer,
&first_bytes, &first_bytes,
peer, peer,
local_addr,
&self.config, &self.config,
&self.beobachten, &self.beobachten,
) )
@ -479,7 +492,6 @@ impl RunningClientHandler {
let stats = self.stats.clone(); let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone(); let buffer_pool = self.buffer_pool.clone();
let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
let (read_half, write_half) = self.stream.into_split(); let (read_half, write_half) = self.stream.into_split();
let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake( let (mut tls_reader, tls_writer, _tls_user) = match handle_tls_handshake(
@ -502,6 +514,7 @@ impl RunningClientHandler {
writer, writer,
&handshake, &handshake,
peer, peer,
local_addr,
&config, &config,
&self.beobachten, &self.beobachten,
) )
@ -558,7 +571,7 @@ impl RunningClientHandler {
))) )))
} }
async fn handle_direct_client(mut self, first_bytes: [u8; 5]) -> Result<HandshakeOutcome> { async fn handle_direct_client(mut self, first_bytes: [u8; 5], local_addr: SocketAddr) -> Result<HandshakeOutcome> {
let peer = self.peer; let peer = self.peer;
let _ip_tracker = self.ip_tracker.clone(); let _ip_tracker = self.ip_tracker.clone();
@ -571,6 +584,7 @@ impl RunningClientHandler {
writer, writer,
&first_bytes, &first_bytes,
peer, peer,
local_addr,
&self.config, &self.config,
&self.beobachten, &self.beobachten,
) )
@ -587,7 +601,6 @@ impl RunningClientHandler {
let stats = self.stats.clone(); let stats = self.stats.clone();
let buffer_pool = self.buffer_pool.clone(); let buffer_pool = self.buffer_pool.clone();
let local_addr = self.stream.local_addr().map_err(ProxyError::Io)?;
let (read_half, write_half) = self.stream.into_split(); let (read_half, write_half) = self.stream.into_split();
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake( let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
@ -609,6 +622,7 @@ impl RunningClientHandler {
writer, writer,
&handshake, &handshake,
peer, peer,
local_addr,
&config, &config,
&self.beobachten, &self.beobachten,
) )

View File

@ -55,6 +55,7 @@ pub async fn handle_bad_client<R, W>(
writer: W, writer: W,
initial_data: &[u8], initial_data: &[u8],
peer: SocketAddr, peer: SocketAddr,
local_addr: SocketAddr,
config: &ProxyConfig, config: &ProxyConfig,
beobachten: &BeobachtenStore, beobachten: &BeobachtenStore,
) )
@ -126,23 +127,16 @@ where
let proxy_header: Option<Vec<u8>> = match config.censorship.mask_proxy_protocol { let proxy_header: Option<Vec<u8>> = match config.censorship.mask_proxy_protocol {
0 => None, 0 => None,
version => { version => {
let header = if let Ok(local_addr) = stream.local_addr() { let header = match version {
match version { 2 => ProxyProtocolV2Builder::new().with_addrs(peer, local_addr).build(),
2 => ProxyProtocolV2Builder::new().with_addrs(peer, local_addr).build(), _ => match (peer, local_addr) {
_ => match (peer, local_addr) { (SocketAddr::V4(src), SocketAddr::V4(dst)) =>
(SocketAddr::V4(src), SocketAddr::V4(dst)) => ProxyProtocolV1Builder::new().tcp4(src.into(), dst.into()).build(),
ProxyProtocolV1Builder::new().tcp4(src.into(), dst.into()).build(), (SocketAddr::V6(src), SocketAddr::V6(dst)) =>
(SocketAddr::V6(src), SocketAddr::V6(dst)) => ProxyProtocolV1Builder::new().tcp6(src.into(), dst.into()).build(),
ProxyProtocolV1Builder::new().tcp6(src.into(), dst.into()).build(), _ =>
_ => ProxyProtocolV1Builder::new().build(),
ProxyProtocolV1Builder::new().build(), },
},
}
} else {
match version {
2 => ProxyProtocolV2Builder::new().build(),
_ => ProxyProtocolV1Builder::new().build(),
}
}; };
Some(header) Some(header)
} }