From e8454ea37084f3f1022aab77758edf5a7b80824b Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Fri, 20 Feb 2026 16:42:40 +0300 Subject: [PATCH] HAProxy PROXY Protocol Fixes Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/config/load.rs | 2 ++ src/config/types.rs | 3 +++ src/main.rs | 11 +++++--- src/protocol/tls.rs | 61 +++++++++++++++++++++++++++++++++++++++++++++ src/proxy/client.rs | 8 ++++-- 5 files changed, 80 insertions(+), 5 deletions(-) diff --git a/src/config/load.rs b/src/config/load.rs index eeda2b0..ec8011a 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -226,6 +226,7 @@ impl ProxyConfig { ip: ipv4, announce: None, announce_ip: None, + proxy_protocol: None, }); } if let Some(ipv6_str) = &config.server.listen_addr_ipv6 { @@ -234,6 +235,7 @@ impl ProxyConfig { ip: ipv6, announce: None, announce_ip: None, + proxy_protocol: None, }); } } diff --git a/src/config/types.rs b/src/config/types.rs index 9aea28a..03529bb 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -513,6 +513,9 @@ pub struct ListenerConfig { /// Migrated to `announce` automatically if `announce` is not set. #[serde(default)] pub announce_ip: Option, + /// Per-listener PROXY protocol override. When set, overrides global server.proxy_protocol. + #[serde(default)] + pub proxy_protocol: Option, } // ============= ShowLink ============= diff --git a/src/main.rs b/src/main.rs index 20e00db..31f4f94 100644 --- a/src/main.rs +++ b/src/main.rs @@ -699,6 +699,8 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai Ok(socket) => { let listener = TcpListener::from_std(socket.into())?; info!("Listening on {}", addr); + let listener_proxy_protocol = + listener_conf.proxy_protocol.unwrap_or(config.server.proxy_protocol); // Resolve the public host for link generation let public_host = if let Some(ref announce) = listener_conf.announce { @@ -724,7 +726,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai print_proxy_links(&public_host, link_port, &config); } - listeners.push(listener); + listeners.push((listener, listener_proxy_protocol)); } Err(e) => { error!("Failed to bind to {}: {}", addr, e); @@ -810,12 +812,13 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai let me_pool = me_pool.clone(); let tls_cache = tls_cache.clone(); let ip_tracker = ip_tracker.clone(); + let proxy_protocol_enabled = config.server.proxy_protocol; tokio::spawn(async move { if let Err(e) = crate::proxy::client::handle_client_stream( stream, fake_peer, config, stats, upstream_manager, replay_checker, buffer_pool, rng, - me_pool, tls_cache, ip_tracker, + me_pool, tls_cache, ip_tracker, proxy_protocol_enabled, ).await { debug!(error = %e, "Unix socket connection error"); } @@ -855,7 +858,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai }); } - for listener in listeners { + for (listener, listener_proxy_protocol) in listeners { let config = config.clone(); let stats = stats.clone(); let upstream_manager = upstream_manager.clone(); @@ -879,6 +882,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai let me_pool = me_pool.clone(); let tls_cache = tls_cache.clone(); let ip_tracker = ip_tracker.clone(); + let proxy_protocol_enabled = listener_proxy_protocol; tokio::spawn(async move { if let Err(e) = ClientHandler::new( @@ -893,6 +897,7 @@ match crate::transport::middle_proxy::fetch_proxy_secret(proxy_secret_path).awai me_pool, tls_cache, ip_tracker, + proxy_protocol_enabled, ) .run() .await diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index c0efc78..fe1e8b6 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -755,4 +755,65 @@ mod tests { // Should return None (no match) but not panic assert!(result.is_none()); } + + fn build_client_hello_with_exts(exts: Vec<(u16, Vec)>, host: &str) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); // legacy version + body.extend_from_slice(&[0u8; 32]); // random + body.push(0); // session id len + body.extend_from_slice(&2u16.to_be_bytes()); // cipher suites len + body.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256 + body.push(1); // compression len + body.push(0); // null compression + + // Build SNI extension + let host_bytes = host.as_bytes(); + let mut sni_ext = Vec::new(); + sni_ext.extend_from_slice(&(host_bytes.len() as u16 + 3).to_be_bytes()); + sni_ext.push(0); + sni_ext.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_ext.extend_from_slice(host_bytes); + + let mut ext_blob = Vec::new(); + for (typ, data) in exts { + ext_blob.extend_from_slice(&typ.to_be_bytes()); + ext_blob.extend_from_slice(&(data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&data); + } + // SNI last + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_ext); + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); // ClientHello + let len_bytes = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&len_bytes[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + record + } + + #[test] + fn test_extract_sni_with_grease_extension() { + // GREASE type 0x0a0a with zero length before SNI + let ch = build_client_hello_with_exts(vec![(0x0a0a, Vec::new())], "example.com"); + let sni = extract_sni_from_client_hello(&ch); + assert_eq!(sni.as_deref(), Some("example.com")); + } + + #[test] + fn test_extract_sni_tolerates_empty_unknown_extension() { + let ch = build_client_hello_with_exts(vec![(0x1234, Vec::new())], "test.local"); + let sni = extract_sni_from_client_hello(&ch); + assert_eq!(sni.as_deref(), Some("test.local")); + } } diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 525c0e9..14b45da 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -51,6 +51,7 @@ pub async fn handle_client_stream( me_pool: Option>, tls_cache: Option>, ip_tracker: Arc, + proxy_protocol_enabled: bool, ) -> Result<()> where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, @@ -58,7 +59,7 @@ where stats.increment_connects_all(); let mut real_peer = normalize_ip(peer); - if config.server.proxy_protocol { + if proxy_protocol_enabled { match parse_proxy_protocol(&mut stream, peer).await { Ok(info) => { debug!( @@ -229,6 +230,7 @@ pub struct RunningClientHandler { me_pool: Option>, tls_cache: Option>, ip_tracker: Arc, + proxy_protocol_enabled: bool, } impl ClientHandler { @@ -244,6 +246,7 @@ impl ClientHandler { me_pool: Option>, tls_cache: Option>, ip_tracker: Arc, + proxy_protocol_enabled: bool, ) -> RunningClientHandler { RunningClientHandler { stream, @@ -257,6 +260,7 @@ impl ClientHandler { me_pool, tls_cache, ip_tracker, + proxy_protocol_enabled, } } } @@ -303,7 +307,7 @@ impl RunningClientHandler { } async fn do_handshake(mut self) -> Result { - if self.config.server.proxy_protocol { + if self.proxy_protocol_enabled { match parse_proxy_protocol(&mut self.stream, self.peer).await { Ok(info) => { debug!(