Refactor TLS fallback tests to remove unnecessary client hello assertions

- Removed assertions for expected client hello messages in multiple TLS fallback tests to streamline the test logic.
- Updated the tests to focus on verifying the trailing TLS records received after the fallback.
- Enhanced the masking functionality by adding shape hardening features, including dynamic padding based on sent data size.
- Modified the relay_to_mask function to accommodate new parameters for shape hardening.
- Updated masking security tests to reflect changes in the relay_to_mask function signature.
This commit is contained in:
David Osipov
2026-03-20 22:44:39 +04:00
parent 3abde52de8
commit 0eca535955
16 changed files with 3354 additions and 346 deletions

View File

@@ -8,6 +8,7 @@ use std::sync::OnceLock;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use ipnetwork::IpNetwork;
use rand::Rng;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::net::TcpStream;
use tokio::time::timeout;
@@ -20,8 +21,8 @@ type PostHandshakeFuture = Pin<Box<dyn Future<Output = Result<()>> + Send>>;
enum HandshakeOutcome {
/// Handshake succeeded, relay work to do (outside timeout)
NeedsRelay(PostHandshakeFuture),
/// Already fully handled (bad client masking, etc.)
Handled,
/// Handshake failed and masking must run outside handshake timeout budget
NeedsMasking(PostHandshakeFuture),
}
#[must_use = "UserConnectionReservation must be kept alive to retain user/IP reservation until release or drop"]
@@ -130,6 +131,24 @@ async fn read_with_progress<R: AsyncRead + Unpin>(reader: &mut R, mut buf: &mut
Ok(total)
}
async fn maybe_apply_mask_reject_delay(config: &ProxyConfig) {
let min = config.censorship.server_hello_delay_min_ms;
let max = config.censorship.server_hello_delay_max_ms;
if max == 0 {
return;
}
let delay_ms = if min >= max {
max
} else {
rand::rng().random_range(min..=max)
};
if delay_ms > 0 {
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
}
}
fn handshake_timeout_with_mask_grace(config: &ProxyConfig) -> Duration {
let base = Duration::from_secs(config.timeouts.client_handshake);
if config.censorship.mask {
@@ -139,6 +158,34 @@ fn handshake_timeout_with_mask_grace(config: &ProxyConfig) -> Duration {
}
}
fn masking_outcome<R, W>(
reader: R,
writer: W,
initial_data: Vec<u8>,
peer: SocketAddr,
local_addr: SocketAddr,
config: Arc<ProxyConfig>,
beobachten: Arc<BeobachtenStore>,
) -> HandshakeOutcome
where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
{
HandshakeOutcome::NeedsMasking(Box::pin(async move {
handle_bad_client(
reader,
writer,
&initial_data,
peer,
local_addr,
&config,
&beobachten,
)
.await;
Ok(())
}))
}
fn record_beobachten_class(
beobachten: &BeobachtenStore,
config: &ProxyConfig,
@@ -283,18 +330,17 @@ where
if !tls_clienthello_len_in_bounds(tls_len) {
debug!(peer = %real_peer, tls_len = tls_len, max_tls_len = MAX_TLS_PLAINTEXT_SIZE, "TLS handshake length out of bounds");
stats.increment_connects_bad();
maybe_apply_mask_reject_delay(&config).await;
let (reader, writer) = tokio::io::split(stream);
handle_bad_client(
return Ok(masking_outcome(
reader,
writer,
&first_bytes,
first_bytes.to_vec(),
real_peer,
local_addr,
&config,
&beobachten,
)
.await;
return Ok(HandshakeOutcome::Handled);
config.clone(),
beobachten.clone(),
));
}
let mut handshake = vec![0u8; 5 + tls_len];
@@ -304,38 +350,36 @@ where
Err(e) => {
debug!(peer = %real_peer, error = %e, tls_len = tls_len, "TLS ClientHello body read failed; engaging masking fallback");
stats.increment_connects_bad();
maybe_apply_mask_reject_delay(&config).await;
let initial_len = 5;
let (reader, writer) = tokio::io::split(stream);
handle_bad_client(
return Ok(masking_outcome(
reader,
writer,
&handshake[..initial_len],
handshake[..initial_len].to_vec(),
real_peer,
local_addr,
&config,
&beobachten,
)
.await;
return Ok(HandshakeOutcome::Handled);
config.clone(),
beobachten.clone(),
));
}
};
if body_read < tls_len {
debug!(peer = %real_peer, got = body_read, expected = tls_len, "Truncated in-range TLS ClientHello; engaging masking fallback");
stats.increment_connects_bad();
maybe_apply_mask_reject_delay(&config).await;
let initial_len = 5 + body_read;
let (reader, writer) = tokio::io::split(stream);
handle_bad_client(
return Ok(masking_outcome(
reader,
writer,
&handshake[..initial_len],
handshake[..initial_len].to_vec(),
real_peer,
local_addr,
&config,
&beobachten,
)
.await;
return Ok(HandshakeOutcome::Handled);
config.clone(),
beobachten.clone(),
));
}
let (read_half, write_half) = tokio::io::split(stream);
@@ -347,17 +391,15 @@ where
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
handle_bad_client(
return Ok(masking_outcome(
reader,
writer,
&handshake,
handshake.clone(),
real_peer,
local_addr,
&config,
&beobachten,
)
.await;
return Ok(HandshakeOutcome::Handled);
config.clone(),
beobachten.clone(),
));
}
HandshakeResult::Error(e) => return Err(e),
};
@@ -389,17 +431,15 @@ where
peer = %peer,
"Authenticated TLS session failed MTProto validation; engaging masking fallback"
);
handle_bad_client(
return Ok(masking_outcome(
reader,
writer,
&handshake,
Vec::new(),
real_peer,
local_addr,
&config,
&beobachten,
)
.await;
return Ok(HandshakeOutcome::Handled);
config.clone(),
beobachten.clone(),
));
}
HandshakeResult::Error(e) => return Err(e),
};
@@ -416,18 +456,17 @@ where
if !config.general.modes.classic && !config.general.modes.secure {
debug!(peer = %real_peer, "Non-TLS modes disabled");
stats.increment_connects_bad();
maybe_apply_mask_reject_delay(&config).await;
let (reader, writer) = tokio::io::split(stream);
handle_bad_client(
return Ok(masking_outcome(
reader,
writer,
&first_bytes,
first_bytes.to_vec(),
real_peer,
local_addr,
&config,
&beobachten,
)
.await;
return Ok(HandshakeOutcome::Handled);
config.clone(),
beobachten.clone(),
));
}
let mut handshake = [0u8; HANDSHAKE_LEN];
@@ -443,17 +482,15 @@ where
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
handle_bad_client(
return Ok(masking_outcome(
reader,
writer,
&handshake,
handshake.to_vec(),
real_peer,
local_addr,
&config,
&beobachten,
)
.await;
return Ok(HandshakeOutcome::Handled);
config.clone(),
beobachten.clone(),
));
}
HandshakeResult::Error(e) => return Err(e),
};
@@ -503,8 +540,7 @@ where
// Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts)
match outcome {
HandshakeOutcome::NeedsRelay(fut) => fut.await,
HandshakeOutcome::Handled => Ok(()),
HandshakeOutcome::NeedsRelay(fut) | HandshakeOutcome::NeedsMasking(fut) => fut.await,
}
}
@@ -617,8 +653,7 @@ impl RunningClientHandler {
// Phase 2: relay (WITHOUT handshake timeout — relay has its own activity timeouts)
match outcome {
HandshakeOutcome::NeedsRelay(fut) => fut.await,
HandshakeOutcome::Handled => Ok(()),
HandshakeOutcome::NeedsRelay(fut) | HandshakeOutcome::NeedsMasking(fut) => fut.await,
}
}
@@ -731,18 +766,17 @@ impl RunningClientHandler {
if !tls_clienthello_len_in_bounds(tls_len) {
debug!(peer = %peer, tls_len = tls_len, max_tls_len = MAX_TLS_PLAINTEXT_SIZE, "TLS handshake length out of bounds");
self.stats.increment_connects_bad();
maybe_apply_mask_reject_delay(&self.config).await;
let (reader, writer) = self.stream.into_split();
handle_bad_client(
return Ok(masking_outcome(
reader,
writer,
&first_bytes,
first_bytes.to_vec(),
peer,
local_addr,
&self.config,
&self.beobachten,
)
.await;
return Ok(HandshakeOutcome::Handled);
self.config.clone(),
self.beobachten.clone(),
));
}
let mut handshake = vec![0u8; 5 + tls_len];
@@ -752,37 +786,35 @@ impl RunningClientHandler {
Err(e) => {
debug!(peer = %peer, error = %e, tls_len = tls_len, "TLS ClientHello body read failed; engaging masking fallback");
self.stats.increment_connects_bad();
maybe_apply_mask_reject_delay(&self.config).await;
let (reader, writer) = self.stream.into_split();
handle_bad_client(
return Ok(masking_outcome(
reader,
writer,
&handshake[..5],
handshake[..5].to_vec(),
peer,
local_addr,
&self.config,
&self.beobachten,
)
.await;
return Ok(HandshakeOutcome::Handled);
self.config.clone(),
self.beobachten.clone(),
));
}
};
if body_read < tls_len {
debug!(peer = %peer, got = body_read, expected = tls_len, "Truncated in-range TLS ClientHello; engaging masking fallback");
self.stats.increment_connects_bad();
maybe_apply_mask_reject_delay(&self.config).await;
let initial_len = 5 + body_read;
let (reader, writer) = self.stream.into_split();
handle_bad_client(
return Ok(masking_outcome(
reader,
writer,
&handshake[..initial_len],
handshake[..initial_len].to_vec(),
peer,
local_addr,
&self.config,
&self.beobachten,
)
.await;
return Ok(HandshakeOutcome::Handled);
self.config.clone(),
self.beobachten.clone(),
));
}
let config = self.config.clone();
@@ -807,17 +839,15 @@ impl RunningClientHandler {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
handle_bad_client(
return Ok(masking_outcome(
reader,
writer,
&handshake,
handshake.clone(),
peer,
local_addr,
&config,
&self.beobachten,
)
.await;
return Ok(HandshakeOutcome::Handled);
config.clone(),
self.beobachten.clone(),
));
}
HandshakeResult::Error(e) => return Err(e),
};
@@ -858,17 +888,15 @@ impl RunningClientHandler {
peer = %peer,
"Authenticated TLS session failed MTProto validation; engaging masking fallback"
);
handle_bad_client(
return Ok(masking_outcome(
reader,
writer,
&handshake,
Vec::new(),
peer,
local_addr,
&config,
&self.beobachten,
)
.await;
return Ok(HandshakeOutcome::Handled);
config.clone(),
self.beobachten.clone(),
));
}
HandshakeResult::Error(e) => return Err(e),
};
@@ -898,18 +926,17 @@ impl RunningClientHandler {
if !self.config.general.modes.classic && !self.config.general.modes.secure {
debug!(peer = %peer, "Non-TLS modes disabled");
self.stats.increment_connects_bad();
maybe_apply_mask_reject_delay(&self.config).await;
let (reader, writer) = self.stream.into_split();
handle_bad_client(
return Ok(masking_outcome(
reader,
writer,
&first_bytes,
first_bytes.to_vec(),
peer,
local_addr,
&self.config,
&self.beobachten,
)
.await;
return Ok(HandshakeOutcome::Handled);
self.config.clone(),
self.beobachten.clone(),
));
}
let mut handshake = [0u8; HANDSHAKE_LEN];
@@ -938,17 +965,15 @@ impl RunningClientHandler {
HandshakeResult::Success(result) => result,
HandshakeResult::BadClient { reader, writer } => {
stats.increment_connects_bad();
handle_bad_client(
return Ok(masking_outcome(
reader,
writer,
&handshake,
handshake.to_vec(),
peer,
local_addr,
&config,
&self.beobachten,
)
.await;
return Ok(HandshakeOutcome::Handled);
config.clone(),
self.beobachten.clone(),
));
}
HandshakeResult::Error(e) => return Err(e),
};
@@ -1208,3 +1233,31 @@ mod tls_clienthello_truncation_adversarial_tests;
#[cfg(test)]
#[path = "client_timing_profile_adversarial_tests.rs"]
mod timing_profile_adversarial_tests;
#[cfg(test)]
#[path = "client_masking_budget_security_tests.rs"]
mod masking_budget_security_tests;
#[cfg(test)]
#[path = "client_masking_redteam_expected_fail_tests.rs"]
mod masking_redteam_expected_fail_tests;
#[cfg(test)]
#[path = "client_masking_hard_adversarial_tests.rs"]
mod masking_hard_adversarial_tests;
#[cfg(test)]
#[path = "client_masking_stress_adversarial_tests.rs"]
mod masking_stress_adversarial_tests;
#[cfg(test)]
#[path = "client_masking_blackhat_campaign_tests.rs"]
mod masking_blackhat_campaign_tests;
#[cfg(test)]
#[path = "client_masking_diagnostics_security_tests.rs"]
mod masking_diagnostics_security_tests;
#[cfg(test)]
#[path = "client_masking_shape_hardening_security_tests.rs"]
mod masking_shape_hardening_security_tests;