feat: add mask_proxy_protocol option for PROXY protocol to mask_host

Adds mask_proxy_protocol config option (0 = off, 1 = v1 text, 2 = v2 binary)
that sends a PROXY protocol header when connecting to mask_host. This lets
the backend see the real client IP address.

Particularly useful when the masking site (nginx/HAProxy) runs on the same
host as telemt and listens on a local port — without this, the backend loses
the original client IP entirely.

PROXY protocol header is also sent during TLS emulation fetches so that
backends with proxy_protocol required don't reject the connection.
This commit is contained in:
ivulit 2026-02-26 13:36:33 +03:00
parent 7ead0cd753
commit da684b11fe
No known key found for this signature in database
7 changed files with 83 additions and 19 deletions

View File

@ -135,6 +135,7 @@ mask = true
# mask_host = "www.google.com" # example, defaults to tls_domain when both mask_host/mask_unix_sock are unset # mask_host = "www.google.com" # example, defaults to tls_domain when both mask_host/mask_unix_sock are unset
# mask_unix_sock = "/var/run/nginx.sock" # example, mutually exclusive with mask_host # mask_unix_sock = "/var/run/nginx.sock" # example, mutually exclusive with mask_host
mask_port = 443 mask_port = 443
# mask_proxy_protocol = 0 # Send PROXY protocol header to mask_host: 0 = off, 1 = v1 (text), 2 = v2 (binary)
fake_cert_len = 2048 # if tls_emulation=false and default value is used, loader may randomize this value at runtime fake_cert_len = 2048 # if tls_emulation=false and default value is used, loader may randomize this value at runtime
tls_emulation = true tls_emulation = true
tls_front_dir = "tlsfront" tls_front_dir = "tlsfront"

View File

@ -611,6 +611,12 @@ pub struct AntiCensorshipConfig {
/// Enforce ALPN echo of client preference. /// Enforce ALPN echo of client preference.
#[serde(default = "default_alpn_enforce")] #[serde(default = "default_alpn_enforce")]
pub alpn_enforce: bool, pub alpn_enforce: bool,
/// Send PROXY protocol header when connecting to mask_host.
/// 0 = disabled, 1 = v1 (text), 2 = v2 (binary).
/// Allows the backend to see the real client IP.
#[serde(default)]
pub mask_proxy_protocol: u8,
} }
impl Default for AntiCensorshipConfig { impl Default for AntiCensorshipConfig {
@ -630,6 +636,7 @@ impl Default for AntiCensorshipConfig {
tls_new_session_tickets: default_tls_new_session_tickets(), tls_new_session_tickets: default_tls_new_session_tickets(),
tls_full_cert_ttl_secs: default_tls_full_cert_ttl_secs(), tls_full_cert_ttl_secs: default_tls_full_cert_ttl_secs(),
alpn_enforce: default_alpn_enforce(), alpn_enforce: default_alpn_enforce(),
mask_proxy_protocol: 0,
} }
} }
} }

View File

@ -474,6 +474,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
&domain, &domain,
Duration::from_secs(5), Duration::from_secs(5),
Some(upstream_manager.clone()), Some(upstream_manager.clone()),
config.censorship.mask_proxy_protocol,
) )
.await .await
{ {
@ -486,6 +487,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let cache_clone = cache.clone(); let cache_clone = cache.clone();
let domains = tls_domains.clone(); let domains = tls_domains.clone();
let upstream_for_task = upstream_manager.clone(); let upstream_for_task = upstream_manager.clone();
let proxy_protocol = config.censorship.mask_proxy_protocol;
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
let base_secs = rand::rng().random_range(4 * 3600..=6 * 3600); let base_secs = rand::rng().random_range(4 * 3600..=6 * 3600);
@ -498,6 +500,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
domain, domain,
Duration::from_secs(5), Duration::from_secs(5),
Some(upstream_for_task.clone()), Some(upstream_for_task.clone()),
proxy_protocol,
) )
.await .await
{ {

View File

@ -143,7 +143,7 @@ where
reader, reader,
writer, writer,
&first_bytes, &first_bytes,
real_peer.ip(), real_peer,
&config, &config,
&beobachten, &beobachten,
) )
@ -168,7 +168,7 @@ where
reader, reader,
writer, writer,
&handshake, &handshake,
real_peer.ip(), real_peer,
&config, &config,
&beobachten, &beobachten,
) )
@ -212,7 +212,7 @@ where
reader, reader,
writer, writer,
&first_bytes, &first_bytes,
real_peer.ip(), real_peer,
&config, &config,
&beobachten, &beobachten,
) )
@ -237,7 +237,7 @@ where
reader, reader,
writer, writer,
&handshake, &handshake,
real_peer.ip(), real_peer,
&config, &config,
&beobachten, &beobachten,
) )
@ -462,7 +462,7 @@ impl RunningClientHandler {
reader, reader,
writer, writer,
&first_bytes, &first_bytes,
peer.ip(), peer,
&self.config, &self.config,
&self.beobachten, &self.beobachten,
) )
@ -501,7 +501,7 @@ impl RunningClientHandler {
reader, reader,
writer, writer,
&handshake, &handshake,
peer.ip(), peer,
&config, &config,
&self.beobachten, &self.beobachten,
) )
@ -570,7 +570,7 @@ impl RunningClientHandler {
reader, reader,
writer, writer,
&first_bytes, &first_bytes,
peer.ip(), peer,
&self.config, &self.config,
&self.beobachten, &self.beobachten,
) )
@ -608,7 +608,7 @@ impl RunningClientHandler {
reader, reader,
writer, writer,
&handshake, &handshake,
peer.ip(), peer,
&config, &config,
&self.beobachten, &self.beobachten,
) )

View File

@ -1,7 +1,7 @@
//! Masking - forward unrecognized traffic to mask host //! Masking - forward unrecognized traffic to mask host
use std::str; use std::str;
use std::net::IpAddr; use std::net::SocketAddr;
use std::time::Duration; use std::time::Duration;
use tokio::net::TcpStream; use tokio::net::TcpStream;
#[cfg(unix)] #[cfg(unix)]
@ -11,6 +11,7 @@ use tokio::time::timeout;
use tracing::debug; use tracing::debug;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
const MASK_TIMEOUT: Duration = Duration::from_secs(5); const MASK_TIMEOUT: Duration = Duration::from_secs(5);
/// Maximum duration for the entire masking relay. /// Maximum duration for the entire masking relay.
@ -52,7 +53,7 @@ pub async fn handle_bad_client<R, W>(
reader: R, reader: R,
writer: W, writer: W,
initial_data: &[u8], initial_data: &[u8],
peer_ip: IpAddr, peer: SocketAddr,
config: &ProxyConfig, config: &ProxyConfig,
beobachten: &BeobachtenStore, beobachten: &BeobachtenStore,
) )
@ -63,7 +64,7 @@ where
let client_type = detect_client_type(initial_data); let client_type = detect_client_type(initial_data);
if config.general.beobachten { if config.general.beobachten {
let ttl = Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60)); let ttl = Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60));
beobachten.record(client_type, peer_ip, ttl); beobachten.record(client_type, peer.ip(), ttl);
} }
if !config.censorship.mask { if !config.censorship.mask {
@ -119,7 +120,37 @@ where
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await; let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
match connect_result { match connect_result {
Ok(Ok(stream)) => { Ok(Ok(stream)) => {
let (mask_read, mask_write) = stream.into_split(); let proxy_header: Option<Vec<u8>> = match config.censorship.mask_proxy_protocol {
0 => None,
version => {
let header = if let Ok(local_addr) = stream.local_addr() {
match version {
2 => ProxyProtocolV2Builder::new().with_addrs(peer, local_addr).build(),
_ => match (peer, local_addr) {
(SocketAddr::V4(src), SocketAddr::V4(dst)) =>
ProxyProtocolV1Builder::new().tcp4(src.into(), dst.into()).build(),
(SocketAddr::V6(src), SocketAddr::V6(dst)) =>
ProxyProtocolV1Builder::new().tcp6(src.into(), dst.into()).build(),
_ =>
ProxyProtocolV1Builder::new().build(),
},
}
} else {
match version {
2 => ProxyProtocolV2Builder::new().build(),
_ => ProxyProtocolV1Builder::new().build(),
}
};
Some(header)
}
};
let (mask_read, mut mask_write) = stream.into_split();
if let Some(header) = proxy_header {
if mask_write.write_all(&header).await.is_err() {
return;
}
}
if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() { if timeout(MASK_RELAY_TIMEOUT, relay_to_mask(reader, writer, mask_read, mask_write, initial_data)).await.is_err() {
debug!("Mask relay timed out"); debug!("Mask relay timed out");
} }

View File

@ -19,6 +19,7 @@ use x509_parser::certificate::X509Certificate;
use crate::crypto::SecureRandom; use crate::crypto::SecureRandom;
use crate::protocol::constants::{TLS_RECORD_APPLICATION, TLS_RECORD_HANDSHAKE}; use crate::protocol::constants::{TLS_RECORD_APPLICATION, TLS_RECORD_HANDSHAKE};
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
use crate::tls_front::types::{ use crate::tls_front::types::{
ParsedCertificateInfo, ParsedCertificateInfo,
ParsedServerHello, ParsedServerHello,
@ -366,6 +367,7 @@ async fn fetch_via_raw_tls(
port: u16, port: u16,
sni: &str, sni: &str,
connect_timeout: Duration, connect_timeout: Duration,
proxy_protocol: u8,
) -> Result<TlsFetchResult> { ) -> Result<TlsFetchResult> {
let addr = format!("{host}:{port}"); let addr = format!("{host}:{port}");
let mut stream = timeout(connect_timeout, TcpStream::connect(addr)).await??; let mut stream = timeout(connect_timeout, TcpStream::connect(addr)).await??;
@ -373,6 +375,13 @@ async fn fetch_via_raw_tls(
let rng = SecureRandom::new(); let rng = SecureRandom::new();
let client_hello = build_client_hello(sni, &rng); let client_hello = build_client_hello(sni, &rng);
timeout(connect_timeout, async { timeout(connect_timeout, async {
if proxy_protocol > 0 {
let header = match proxy_protocol {
2 => ProxyProtocolV2Builder::new().build(),
_ => ProxyProtocolV1Builder::new().build(),
};
stream.write_all(&header).await?;
}
stream.write_all(&client_hello).await?; stream.write_all(&client_hello).await?;
stream.flush().await?; stream.flush().await?;
Ok::<(), std::io::Error>(()) Ok::<(), std::io::Error>(())
@ -424,9 +433,10 @@ async fn fetch_via_rustls(
sni: &str, sni: &str,
connect_timeout: Duration, connect_timeout: Duration,
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>, upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
proxy_protocol: u8,
) -> Result<TlsFetchResult> { ) -> Result<TlsFetchResult> {
// rustls handshake path for certificate and basic negotiated metadata. // rustls handshake path for certificate and basic negotiated metadata.
let stream = if let Some(manager) = upstream { let mut stream = if let Some(manager) = upstream {
// Resolve host to SocketAddr // Resolve host to SocketAddr
if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await { if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await {
if let Some(addr) = addrs.find(|a| a.is_ipv4()) { if let Some(addr) = addrs.find(|a| a.is_ipv4()) {
@ -447,6 +457,15 @@ async fn fetch_via_rustls(
timeout(connect_timeout, TcpStream::connect((host, port))).await?? timeout(connect_timeout, TcpStream::connect((host, port))).await??
}; };
if proxy_protocol > 0 {
let header = match proxy_protocol {
2 => ProxyProtocolV2Builder::new().build(),
_ => ProxyProtocolV1Builder::new().build(),
};
stream.write_all(&header).await?;
stream.flush().await?;
}
let config = build_client_config(); let config = build_client_config();
let connector = TlsConnector::from(config); let connector = TlsConnector::from(config);
@ -527,8 +546,9 @@ pub async fn fetch_real_tls(
sni: &str, sni: &str,
connect_timeout: Duration, connect_timeout: Duration,
upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>, upstream: Option<std::sync::Arc<crate::transport::UpstreamManager>>,
proxy_protocol: u8,
) -> Result<TlsFetchResult> { ) -> Result<TlsFetchResult> {
let raw_result = match fetch_via_raw_tls(host, port, sni, connect_timeout).await { let raw_result = match fetch_via_raw_tls(host, port, sni, connect_timeout, proxy_protocol).await {
Ok(res) => Some(res), Ok(res) => Some(res),
Err(e) => { Err(e) => {
warn!(sni = %sni, error = %e, "Raw TLS fetch failed"); warn!(sni = %sni, error = %e, "Raw TLS fetch failed");
@ -536,7 +556,7 @@ pub async fn fetch_real_tls(
} }
}; };
match fetch_via_rustls(host, port, sni, connect_timeout, upstream).await { match fetch_via_rustls(host, port, sni, connect_timeout, upstream, proxy_protocol).await {
Ok(rustls_result) => { Ok(rustls_result) => {
if let Some(mut raw) = raw_result { if let Some(mut raw) = raw_result {
raw.cert_info = rustls_result.cert_info; raw.cert_info = rustls_result.cert_info;

View File

@ -233,14 +233,12 @@ async fn parse_v2<R: AsyncRead + Unpin>(
} }
/// Builder for PROXY protocol v1 header /// Builder for PROXY protocol v1 header
#[allow(dead_code)]
pub struct ProxyProtocolV1Builder { pub struct ProxyProtocolV1Builder {
family: &'static str, family: &'static str,
src_addr: Option<SocketAddr>, src_addr: Option<SocketAddr>,
dst_addr: Option<SocketAddr>, dst_addr: Option<SocketAddr>,
} }
#[allow(dead_code)]
impl ProxyProtocolV1Builder { impl ProxyProtocolV1Builder {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
@ -288,13 +286,17 @@ impl Default for ProxyProtocolV1Builder {
} }
/// Builder for PROXY protocol v2 header /// Builder for PROXY protocol v2 header
#[allow(dead_code)]
pub struct ProxyProtocolV2Builder { pub struct ProxyProtocolV2Builder {
src: Option<SocketAddr>, src: Option<SocketAddr>,
dst: Option<SocketAddr>, dst: Option<SocketAddr>,
} }
#[allow(dead_code)] impl Default for ProxyProtocolV2Builder {
fn default() -> Self {
Self::new()
}
}
impl ProxyProtocolV2Builder { impl ProxyProtocolV2Builder {
pub fn new() -> Self { pub fn new() -> Self {
Self { src: None, dst: None } Self { src: None, dst: None }