feat(tls): add boot time timestamp constant and validation for SNI hostnames

- Introduced `BOOT_TIME_MAX_SECS` constant to define the maximum accepted boot-time timestamp.
- Updated `validate_tls_handshake_at_time` to utilize the new boot time constant for timestamp validation.
- Enhanced `extract_sni_from_client_hello` to validate SNI hostnames against specified criteria, rejecting invalid hostnames.
- Added tests to ensure proper handling of boot time timestamps and SNI validation.

feat(handshake): improve user secret decoding and ALPN enforcement

- Refactored user secret decoding to provide better error handling and logging for invalid secrets.
- Added tests for concurrent identical handshakes to ensure replay protection works as expected.
- Implemented ALPN enforcement in handshake processing, rejecting unsupported protocols and allowing valid ones.

fix(masking): implement timeout handling for masking operations

- Added timeout handling for writing proxy headers and consuming client data in masking.
- Adjusted timeout durations for testing to ensure faster feedback during unit tests.
- Introduced tests to verify behavior when masking is disabled and when proxy header writes exceed the timeout.

test(masking): add tests for slowloris connections and proxy header timeouts

- Created tests to validate that slowloris connections are closed by consume timeout when masking is disabled.
- Added a test for proxy header write timeout to ensure it returns false when the write operation does not complete.
This commit is contained in:
David Osipov 2026-03-16 21:37:59 +04:00
parent 213ce4555a
commit e4a50f9286
No known key found for this signature in database
GPG Key ID: 0E55C4A47454E82E
7 changed files with 895 additions and 25 deletions

View File

@ -29,6 +29,8 @@ pub const TLS_DIGEST_HALF_LEN: usize = 16;
/// Time skew limits for anti-replay (in seconds) /// Time skew limits for anti-replay (in seconds)
pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before
pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after
/// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced.
pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60;
// ============= Private Constants ============= // ============= Private Constants =============
@ -364,7 +366,7 @@ fn validate_tls_handshake_at_time(
if !ignore_time_skew { if !ignore_time_skew {
// Allow very small timestamps (boot time instead of unix time) // Allow very small timestamps (boot time instead of unix time)
// This is a quirk in some clients that use uptime instead of real time // This is a quirk in some clients that use uptime instead of real time
let is_boot_time = timestamp < 60 * 60 * 24 * 1000; // < ~2.7 years in seconds let is_boot_time = timestamp < BOOT_TIME_MAX_SECS;
if !is_boot_time { if !is_boot_time {
let time_diff = now - i64::from(timestamp); let time_diff = now - i64::from(timestamp);
if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) { if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) {
@ -563,7 +565,9 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
if name_type == 0 && name_len > 0 if name_type == 0 && name_len > 0
&& let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len]) && let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len])
{ {
return Some(host.to_string()); if is_valid_sni_hostname(host) {
return Some(host.to_string());
}
} }
sn_pos += name_len; sn_pos += name_len;
} }
@ -574,6 +578,35 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
None None
} }
fn is_valid_sni_hostname(host: &str) -> bool {
if host.is_empty() || host.len() > 253 {
return false;
}
if host.starts_with('.') || host.ends_with('.') {
return false;
}
if host.parse::<std::net::IpAddr>().is_ok() {
return false;
}
for label in host.split('.') {
if label.is_empty() || label.len() > 63 {
return false;
}
if label.starts_with('-') || label.ends_with('-') {
return false;
}
if !label
.bytes()
.all(|b| b.is_ascii_alphanumeric() || b == b'-')
{
return false;
}
}
true
}
/// Extract ALPN protocol list from ClientHello, return in offered order. /// Extract ALPN protocol list from ClientHello, return in offered order.
pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec<Vec<u8>> { pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec<Vec<u8>> {
let mut pos = 5; // after record header let mut pos = 5; // after record header

View File

@ -286,8 +286,8 @@ fn boot_time_timestamp_accepted_without_ignore_flag() {
// Timestamps below the boot-time threshold are treated as client uptime, // Timestamps below the boot-time threshold are treated as client uptime,
// not real wall-clock time. The proxy allows them regardless of skew. // not real wall-clock time. The proxy allows them regardless of skew.
let secret = b"boot_time_test"; let secret = b"boot_time_test";
// 86_400_000 / 2 is well below the boot-time threshold (~2.74 years worth of seconds). // Keep this safely below BOOT_TIME_MAX_SECS to assert bypass behavior.
let boot_ts: u32 = 86_400_000 / 2; let boot_ts: u32 = BOOT_TIME_MAX_SECS / 2;
let handshake = make_valid_tls_handshake(secret, boot_ts); let handshake = make_valid_tls_handshake(secret, boot_ts);
let secrets = vec![("u".to_string(), secret.to_vec())]; let secrets = vec![("u".to_string(), secret.to_vec())];
assert!( assert!(
@ -611,13 +611,13 @@ fn zero_length_session_id_accepted() {
// Boot-time threshold — exact boundary precision // Boot-time threshold — exact boundary precision
// ------------------------------------------------------------------ // ------------------------------------------------------------------
/// timestamp = 86_399_999 is the last value inside the boot-time window. /// timestamp = BOOT_TIME_MAX_SECS - 1 is the last value inside the boot-time window.
/// is_boot_time = true → skew check is skipped entirely → accepted even /// is_boot_time = true → skew check is skipped entirely → accepted even
/// when `now` is far from the timestamp. /// when `now` is far from the timestamp.
#[test] #[test]
fn timestamp_one_below_boot_threshold_bypasses_skew_check() { fn timestamp_one_below_boot_threshold_bypasses_skew_check() {
let secret = b"boot_last_value_test"; let secret = b"boot_last_value_test";
let ts: u32 = 86_400_000 - 1; let ts: u32 = BOOT_TIME_MAX_SECS - 1;
let h = make_valid_tls_handshake(secret, ts); let h = make_valid_tls_handshake(secret, ts);
let secrets = vec![("u".to_string(), secret.to_vec())]; let secrets = vec![("u".to_string(), secret.to_vec())];
@ -625,17 +625,17 @@ fn timestamp_one_below_boot_threshold_bypasses_skew_check() {
// Boot-time bypass must prevent the skew check from running. // Boot-time bypass must prevent the skew check from running.
assert!( assert!(
validate_tls_handshake_at_time(&h, &secrets, false, 0).is_some(), validate_tls_handshake_at_time(&h, &secrets, false, 0).is_some(),
"ts=86_399_999 must bypass skew check regardless of now" "ts=BOOT_TIME_MAX_SECS-1 must bypass skew check regardless of now"
); );
} }
/// timestamp = 86_400_000 is the first value outside the boot-time window. /// timestamp = BOOT_TIME_MAX_SECS is the first value outside the boot-time window.
/// is_boot_time = false → skew check IS applied. Two sub-cases confirm this: /// is_boot_time = false → skew check IS applied. Two sub-cases confirm this:
/// once with now chosen so the skew passes (accepted) and once where it fails. /// once with now chosen so the skew passes (accepted) and once where it fails.
#[test] #[test]
fn timestamp_at_boot_threshold_triggers_skew_check() { fn timestamp_at_boot_threshold_triggers_skew_check() {
let secret = b"boot_exact_value_test"; let secret = b"boot_exact_value_test";
let ts: u32 = 86_400_000; let ts: u32 = BOOT_TIME_MAX_SECS;
let h = make_valid_tls_handshake(secret, ts); let h = make_valid_tls_handshake(secret, ts);
let secrets = vec![("u".to_string(), secret.to_vec())]; let secrets = vec![("u".to_string(), secret.to_vec())];
@ -643,14 +643,14 @@ fn timestamp_at_boot_threshold_triggers_skew_check() {
let now_valid: i64 = ts as i64 + 50; let now_valid: i64 = ts as i64 + 50;
assert!( assert!(
validate_tls_handshake_at_time(&h, &secrets, false, now_valid).is_some(), validate_tls_handshake_at_time(&h, &secrets, false, now_valid).is_some(),
"ts=86_400_000 within skew window must be accepted via skew check" "ts=BOOT_TIME_MAX_SECS within skew window must be accepted via skew check"
); );
// now = 0 → time_diff = -86_400_000, outside window → rejected. // now = 0 → time_diff = -86_400_000, outside window → rejected.
// If the boot-time bypass were wrongly applied here this would pass. // If the boot-time bypass were wrongly applied here this would pass.
assert!( assert!(
validate_tls_handshake_at_time(&h, &secrets, false, 0).is_none(), validate_tls_handshake_at_time(&h, &secrets, false, 0).is_none(),
"ts=86_400_000 far from now must be rejected — no boot-time bypass" "ts=BOOT_TIME_MAX_SECS far from now must be rejected — no boot-time bypass"
); );
} }
@ -675,7 +675,7 @@ fn u32_max_timestamp_accepted_with_ignore_time_skew() {
); );
} }
/// u32::MAX > 86_400_000 so the skew check runs. With any realistic `now` /// u32::MAX > BOOT_TIME_MAX_SECS so the skew check runs. With any realistic `now`
/// (~1.7 billion), time_diff = now - u32::MAX is deeply negative — far outside /// (~1.7 billion), time_diff = now - u32::MAX is deeply negative — far outside
/// [-1200, 600] — so the handshake must be rejected without overflow. /// [-1200, 600] — so the handshake must be rejected without overflow.
#[test] #[test]
@ -1109,6 +1109,25 @@ fn extract_sni_rejects_zero_length_host_name() {
assert!(extract_sni_from_client_hello(&ch).is_none()); assert!(extract_sni_from_client_hello(&ch).is_none());
} }
#[test]
fn extract_sni_rejects_raw_ipv4_literals() {
let ch = build_client_hello_with_exts(Vec::new(), "203.0.113.10");
assert!(extract_sni_from_client_hello(&ch).is_none());
}
#[test]
fn extract_sni_rejects_invalid_label_characters() {
let ch = build_client_hello_with_exts(Vec::new(), "exa_mple.com");
assert!(extract_sni_from_client_hello(&ch).is_none());
}
#[test]
fn extract_sni_rejects_oversized_label() {
let oversized = format!("{}.example.com", "a".repeat(64));
let ch = build_client_hello_with_exts(Vec::new(), &oversized);
assert!(extract_sni_from_client_hello(&ch).is_none());
}
#[test] #[test]
fn extract_sni_rejects_when_extension_block_is_truncated() { fn extract_sni_rejects_when_extension_block_is_truncated() {
let mut ext_blob = Vec::new(); let mut ext_blob = Vec::new();

View File

@ -92,6 +92,71 @@ async fn short_tls_probe_is_masked_through_client_pipeline() {
accept_task.await.unwrap(); accept_task.await.unwrap();
} }
#[tokio::test]
async fn partial_tls_header_stall_triggers_handshake_timeout() {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.timeouts.client_handshake = 1;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let peer: SocketAddr = "198.51.100.170:55201".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
));
client_side
.write_all(&[0x16, 0x03, 0x01, 0x02, 0x00])
.await
.unwrap();
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(matches!(result, Err(ProxyError::TgHandshakeTimeout)));
}
fn make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len: usize) -> Vec<u8> { fn make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len: usize) -> Vec<u8> {
assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header");
@ -122,6 +187,66 @@ fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32) -> Vec<u8> {
make_valid_tls_client_hello_with_len(secret, timestamp, 600) make_valid_tls_client_hello_with_len(secret, timestamp, 600)
} }
fn make_valid_tls_client_hello_with_alpn(
secret: &[u8],
timestamp: u32,
alpn_protocols: &[&[u8]],
) -> Vec<u8> {
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
body.push(32);
body.extend_from_slice(&[0x42u8; 32]);
body.extend_from_slice(&2u16.to_be_bytes());
body.extend_from_slice(&[0x13, 0x01]);
body.push(1);
body.push(0);
let mut ext_blob = Vec::new();
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
alpn_list.push(proto.len() as u8);
alpn_list.extend_from_slice(proto);
}
let mut alpn_data = Vec::new();
alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
alpn_data.extend_from_slice(&alpn_list);
ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&alpn_data);
}
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
body.extend_from_slice(&ext_blob);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut record = Vec::new();
record.push(0x16);
record.extend_from_slice(&[0x03, 0x01]);
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
record.extend_from_slice(&handshake);
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &record);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
record
}
fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> { fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
let mut record = Vec::with_capacity(5 + payload.len()); let mut record = Vec::with_capacity(5 + payload.len());
record.push(0x17); record.push(0x17);
@ -439,6 +564,304 @@ async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() {
.unwrap(); .unwrap();
} }
#[tokio::test]
async fn alpn_mismatch_tls_probe_is_masked_through_client_pipeline() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let secret = [0x66u8; 16];
let probe = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
let accept_task = tokio::spawn({
let probe = probe.clone();
let backend_reply = backend_reply.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; probe.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, probe);
stream.write_all(&backend_reply).await.unwrap();
}
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0;
cfg.censorship.alpn_enforce = true;
cfg.access.ignore_time_skew = true;
cfg.access
.users
.insert("user".to_string(), "66666666666666666666666666666666".to_string());
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(8192);
let peer: SocketAddr = "198.51.100.66:55211".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
));
client_side.write_all(&probe).await.unwrap();
let mut observed = vec![0u8; backend_reply.len()];
client_side.read_exact(&mut observed).await.unwrap();
assert_eq!(observed, backend_reply);
drop(client_side);
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
accept_task.await.unwrap();
}
#[tokio::test]
async fn invalid_hmac_tls_probe_is_masked_through_client_pipeline() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let secret = [0x77u8; 16];
let mut probe = make_valid_tls_client_hello(&secret, 0);
probe[tls::TLS_DIGEST_POS] ^= 0x01;
let accept_task = tokio::spawn({
let probe = probe.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; probe.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, probe);
}
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true;
cfg.access
.users
.insert("user".to_string(), "77777777777777777777777777777777".to_string());
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(8192);
let peer: SocketAddr = "198.51.100.77:55212".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
));
client_side.write_all(&probe).await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
drop(client_side);
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
}
#[tokio::test]
async fn burst_invalid_tls_probes_are_masked_verbatim() {
const N: usize = 12;
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let secret = [0x88u8; 16];
let mut probe = make_valid_tls_client_hello(&secret, 0);
probe[tls::TLS_DIGEST_POS + 1] ^= 0x01;
let accept_task = tokio::spawn({
let probe = probe.clone();
async move {
for _ in 0..N {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; probe.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, probe);
}
}
});
let mut handlers = Vec::with_capacity(N);
for i in 0..N {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true;
cfg.access
.users
.insert("user".to_string(), "88888888888888888888888888888888".to_string());
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(8192);
let peer: SocketAddr = format!("198.51.100.{}:{}", 100 + i, 56000 + i)
.parse()
.unwrap();
let probe_bytes = probe.clone();
let h = tokio::spawn(async move {
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
));
client_side.write_all(&probe_bytes).await.unwrap();
drop(client_side);
tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap()
.unwrap();
});
handlers.push(h);
}
for h in handlers {
tokio::time::timeout(Duration::from_secs(5), h)
.await
.unwrap()
.unwrap();
}
tokio::time::timeout(Duration::from_secs(5), accept_task)
.await
.unwrap()
.unwrap();
}
#[test] #[test]
fn unexpected_eof_is_classified_without_string_matching() { fn unexpected_eof_is_classified_without_string_matching() {
let beobachten = BeobachtenStore::new(); let beobachten = BeobachtenStore::new();

View File

@ -3,7 +3,9 @@
#![allow(dead_code)] #![allow(dead_code)]
use std::net::SocketAddr; use std::net::SocketAddr;
use std::collections::HashSet;
use std::sync::Arc; use std::sync::Arc;
use std::sync::{Mutex, OnceLock};
use std::time::Duration; use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, warn, trace}; use tracing::{debug, warn, trace};
@ -20,6 +22,56 @@ use crate::config::ProxyConfig;
use crate::tls_front::{TlsFrontCache, emulator}; use crate::tls_front::{TlsFrontCache, emulator};
const ACCESS_SECRET_BYTES: usize = 16; const ACCESS_SECRET_BYTES: usize = 16;
static INVALID_SECRET_WARNED: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option<usize>) {
let key = format!("{}:{}", name, reason);
let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new()));
let should_warn = match warned.lock() {
Ok(mut guard) => guard.insert(key),
Err(_) => true,
};
if !should_warn {
return;
}
match got {
Some(actual) => {
warn!(
user = %name,
expected = expected,
got = actual,
"Skipping user: access secret has unexpected length"
);
}
None => {
warn!(
user = %name,
"Skipping user: access secret is not valid hex"
);
}
}
}
fn decode_user_secret(name: &str, secret_hex: &str) -> Option<Vec<u8>> {
match hex::decode(secret_hex) {
Ok(bytes) if bytes.len() == ACCESS_SECRET_BYTES => Some(bytes),
Ok(bytes) => {
warn_invalid_secret_once(
name,
"invalid_length",
ACCESS_SECRET_BYTES,
Some(bytes.len()),
);
None
}
Err(_) => {
warn_invalid_secret_once(name, "invalid_hex", ACCESS_SECRET_BYTES, None);
None
}
}
}
// Decide whether a client-supplied proto tag is allowed given the configured // Decide whether a client-supplied proto tag is allowed given the configured
// proxy modes and the transport that carried the handshake. // proxy modes and the transport that carried the handshake.
@ -51,8 +103,7 @@ fn decode_user_secrets(
if let Some(preferred) = preferred_user if let Some(preferred) = preferred_user
&& let Some(secret_hex) = config.access.users.get(preferred) && let Some(secret_hex) = config.access.users.get(preferred)
&& let Ok(bytes) = hex::decode(secret_hex) && let Some(bytes) = decode_user_secret(preferred, secret_hex)
&& bytes.len() == ACCESS_SECRET_BYTES
{ {
secrets.push((preferred.to_string(), bytes)); secrets.push((preferred.to_string(), bytes));
} }
@ -61,9 +112,7 @@ fn decode_user_secrets(
if preferred_user.is_some_and(|preferred| preferred == name.as_str()) { if preferred_user.is_some_and(|preferred| preferred == name.as_str()) {
continue; continue;
} }
if let Ok(bytes) = hex::decode(secret_hex) if let Some(bytes) = decode_user_secret(name, secret_hex) {
&& bytes.len() == ACCESS_SECRET_BYTES
{
secrets.push((name.clone(), bytes)); secrets.push((name.clone(), bytes));
} }
} }
@ -193,6 +242,9 @@ where
Some(b"h2".to_vec()) Some(b"h2".to_vec())
} else if alpn_list.iter().any(|p| p == b"http/1.1") { } else if alpn_list.iter().any(|p| p == b"http/1.1") {
Some(b"http/1.1".to_vec()) Some(b"http/1.1".to_vec())
} else if !alpn_list.is_empty() {
debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback");
return HandshakeResult::BadClient { reader, writer };
} else { } else {
None None
} }

View File

@ -1,6 +1,7 @@
use super::*; use super::*;
use crate::crypto::sha256_hmac; use crate::crypto::sha256_hmac;
use std::time::Duration; use std::sync::Arc;
use std::time::{Duration, Instant};
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> { fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
let session_id_len: usize = 32; let session_id_len: usize = 32;
@ -22,6 +23,64 @@ fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
handshake handshake
} }
fn make_valid_tls_client_hello_with_alpn(
secret: &[u8],
timestamp: u32,
alpn_protocols: &[&[u8]],
) -> Vec<u8> {
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
body.push(32);
body.extend_from_slice(&[0x42u8; 32]);
body.extend_from_slice(&2u16.to_be_bytes());
body.extend_from_slice(&[0x13, 0x01]);
body.push(1);
body.push(0);
let mut ext_blob = Vec::new();
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
alpn_list.push(proto.len() as u8);
alpn_list.extend_from_slice(proto);
}
let mut alpn_data = Vec::new();
alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
alpn_data.extend_from_slice(&alpn_list);
ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&alpn_data);
}
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
body.extend_from_slice(&ext_blob);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut record = Vec::new();
record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&[0x03, 0x01]);
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
record.extend_from_slice(&handshake);
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &record);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
record
}
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default(); let mut cfg = ProxyConfig::default();
cfg.access.users.clear(); cfg.access.users.clear();
@ -251,6 +310,51 @@ async fn tls_replay_second_identical_handshake_is_rejected() {
assert!(matches!(second, HandshakeResult::BadClient { .. })); assert!(matches!(second, HandshakeResult::BadClient { .. }));
} }
#[tokio::test]
async fn tls_replay_concurrent_identical_handshake_allows_exactly_one_success() {
let secret = [0x77u8; 16];
let config = Arc::new(test_config_with_secret_hex("77777777777777777777777777777777"));
let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60)));
let rng = Arc::new(SecureRandom::new());
let handshake = Arc::new(make_valid_tls_handshake(&secret, 0));
let mut tasks = Vec::new();
for _ in 0..50 {
let config = config.clone();
let replay_checker = replay_checker.clone();
let rng = rng.clone();
let handshake = handshake.clone();
tasks.push(tokio::spawn(async move {
handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
"127.0.0.1:45000".parse().unwrap(),
&config,
&replay_checker,
&rng,
None,
)
.await
}));
}
let mut success_count = 0usize;
for task in tasks {
let result = task.await.unwrap();
if matches!(result, HandshakeResult::Success(_)) {
success_count += 1;
} else {
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
}
assert_eq!(
success_count, 1,
"Concurrent replay attempts must allow exactly one successful handshake"
);
}
#[tokio::test] #[tokio::test]
async fn invalid_tls_probe_does_not_pollute_replay_cache() { async fn invalid_tls_probe_does_not_pollute_replay_cache() {
let config = test_config_with_secret_hex("11111111111111111111111111111111"); let config = test_config_with_secret_hex("11111111111111111111111111111111");
@ -387,6 +491,158 @@ async fn mixed_secret_lengths_keep_valid_user_authenticating() {
assert!(matches!(result, HandshakeResult::Success(_))); assert!(matches!(result, HandshakeResult::Success(_)));
} }
#[tokio::test]
async fn alpn_enforce_rejects_unsupported_client_alpn() {
let secret = [0x33u8; 16];
let mut config = test_config_with_secret_hex("33333333333333333333333333333333");
config.censorship.alpn_enforce = true;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "127.0.0.1:44327".parse().unwrap();
let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
let result = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
#[tokio::test]
async fn alpn_enforce_accepts_h2() {
let secret = [0x44u8; 16];
let mut config = test_config_with_secret_hex("44444444444444444444444444444444");
config.censorship.alpn_enforce = true;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "127.0.0.1:44328".parse().unwrap();
let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h2", b"h3"]);
let result = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::Success(_)));
}
#[tokio::test]
async fn malformed_tls_classes_complete_within_bounded_time() {
let secret = [0x55u8; 16];
let mut config = test_config_with_secret_hex("55555555555555555555555555555555");
config.censorship.alpn_enforce = true;
let replay_checker = ReplayChecker::new(512, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "127.0.0.1:44329".parse().unwrap();
let too_short = vec![0x16, 0x03, 0x01];
let mut bad_hmac = make_valid_tls_handshake(&secret, 0);
bad_hmac[tls::TLS_DIGEST_POS] ^= 0x01;
let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
for probe in [too_short, bad_hmac, alpn_mismatch] {
let result = tokio::time::timeout(
Duration::from_millis(200),
handle_tls_handshake(
&probe,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
),
)
.await
.expect("Malformed TLS classes must be rejected within bounded time");
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
}
#[tokio::test]
async fn malformed_tls_classes_share_close_latency_buckets() {
const ITER: usize = 24;
const BUCKET_MS: u128 = 10;
let secret = [0x99u8; 16];
let mut config = test_config_with_secret_hex("99999999999999999999999999999999");
config.censorship.alpn_enforce = true;
let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "127.0.0.1:44330".parse().unwrap();
let too_short = vec![0x16, 0x03, 0x01];
let mut bad_hmac = make_valid_tls_handshake(&secret, 0);
bad_hmac[tls::TLS_DIGEST_POS + 1] ^= 0x01;
let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
let mut class_means_ms = Vec::new();
for probe in [too_short, bad_hmac, alpn_mismatch] {
let mut sum_micros: u128 = 0;
for _ in 0..ITER {
let started = Instant::now();
let result = handle_tls_handshake(
&probe,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
let elapsed = started.elapsed();
assert!(matches!(result, HandshakeResult::BadClient { .. }));
sum_micros += elapsed.as_micros();
}
class_means_ms.push(sum_micros / ITER as u128 / 1_000);
}
let min_bucket = class_means_ms
.iter()
.map(|ms| ms / BUCKET_MS)
.min()
.unwrap();
let max_bucket = class_means_ms
.iter()
.map(|ms| ms / BUCKET_MS)
.max()
.unwrap();
assert!(
max_bucket <= min_bucket + 1,
"Malformed TLS classes diverged across latency buckets: means_ms={:?}",
class_means_ms
);
}
#[test] #[test]
fn secure_tag_requires_tls_mode_on_tls_transport() { fn secure_tag_requires_tls_mode_on_tls_transport() {
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();

View File

@ -14,12 +14,41 @@ use crate::network::dns_overrides::resolve_socket_addr;
use crate::stats::beobachten::BeobachtenStore; use crate::stats::beobachten::BeobachtenStore;
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
#[cfg(not(test))]
const MASK_TIMEOUT: Duration = Duration::from_secs(5); const MASK_TIMEOUT: Duration = Duration::from_secs(5);
#[cfg(test)]
const MASK_TIMEOUT: Duration = Duration::from_millis(50);
/// Maximum duration for the entire masking relay. /// Maximum duration for the entire masking relay.
/// Limits resource consumption from slow-loris attacks and port scanners. /// Limits resource consumption from slow-loris attacks and port scanners.
#[cfg(not(test))]
const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60); const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60);
#[cfg(test)]
const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200);
const MASK_BUFFER_SIZE: usize = 8192; const MASK_BUFFER_SIZE: usize = 8192;
async fn write_proxy_header_with_timeout<W>(mask_write: &mut W, header: &[u8]) -> bool
where
W: AsyncWrite + Unpin,
{
match timeout(MASK_TIMEOUT, mask_write.write_all(header)).await {
Ok(Ok(())) => true,
Ok(Err(_)) => false,
Err(_) => {
debug!("Timeout writing proxy protocol header to mask backend");
false
}
}
}
async fn consume_client_data_with_timeout<R>(reader: R)
where
R: AsyncRead + Unpin,
{
if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader)).await.is_err() {
debug!("Timed out while consuming client data on masking fallback path");
}
}
/// Detect client type based on initial data /// Detect client type based on initial data
fn detect_client_type(data: &[u8]) -> &'static str { fn detect_client_type(data: &[u8]) -> &'static str {
// Check for HTTP request // Check for HTTP request
@ -71,7 +100,7 @@ where
if !config.censorship.mask { if !config.censorship.mask {
// Masking disabled, just consume data // Masking disabled, just consume data
consume_client_data(reader).await; consume_client_data_with_timeout(reader).await;
return; return;
} }
@ -107,7 +136,7 @@ where
} }
}; };
if let Some(header) = proxy_header { if let Some(header) = proxy_header {
if mask_write.write_all(&header).await.is_err() { if !write_proxy_header_with_timeout(&mut mask_write, &header).await {
return; return;
} }
} }
@ -117,11 +146,11 @@ where
} }
Ok(Err(e)) => { Ok(Err(e)) => {
debug!(error = %e, "Failed to connect to mask unix socket"); debug!(error = %e, "Failed to connect to mask unix socket");
consume_client_data(reader).await; consume_client_data_with_timeout(reader).await;
} }
Err(_) => { Err(_) => {
debug!("Timeout connecting to mask unix socket"); debug!("Timeout connecting to mask unix socket");
consume_client_data(reader).await; consume_client_data_with_timeout(reader).await;
} }
} }
return; return;
@ -166,7 +195,7 @@ where
let (mask_read, mut mask_write) = stream.into_split(); let (mask_read, mut mask_write) = stream.into_split();
if let Some(header) = proxy_header { if let Some(header) = proxy_header {
if mask_write.write_all(&header).await.is_err() { if !write_proxy_header_with_timeout(&mut mask_write, &header).await {
return; return;
} }
} }
@ -176,11 +205,11 @@ where
} }
Ok(Err(e)) => { Ok(Err(e)) => {
debug!(error = %e, "Failed to connect to mask host"); debug!(error = %e, "Failed to connect to mask host");
consume_client_data(reader).await; consume_client_data_with_timeout(reader).await;
} }
Err(_) => { Err(_) => {
debug!("Timeout connecting to mask host"); debug!("Timeout connecting to mask host");
consume_client_data(reader).await; consume_client_data_with_timeout(reader).await;
} }
} }
} }

View File

@ -1,5 +1,7 @@
use super::*; use super::*;
use crate::config::ProxyConfig; use crate::config::ProxyConfig;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{duplex, AsyncBufReadExt, BufReader}; use tokio::io::{duplex, AsyncBufReadExt, BufReader};
use tokio::net::TcpListener; use tokio::net::TcpListener;
#[cfg(unix)] #[cfg(unix)]
@ -484,3 +486,59 @@ async fn unix_socket_mask_path_forwards_probe_and_response() {
accept_task.await.unwrap(); accept_task.await.unwrap();
let _ = std::fs::remove_file(sock_path); let _ = std::fs::remove_file(sock_path);
} }
#[tokio::test]
async fn mask_disabled_slowloris_connection_is_closed_by_consume_timeout() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = false;
let peer: SocketAddr = "198.51.100.33:45455".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (_client_reader_side, client_reader) = duplex(256);
let (_client_visible_reader, client_visible_writer) = duplex(256);
let beobachten = BeobachtenStore::new();
let task = tokio::spawn(async move {
handle_bad_client(
client_reader,
client_visible_writer,
b"slowloris",
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
timeout(Duration::from_secs(1), task).await.unwrap().unwrap();
}
struct PendingWriter;
impl tokio::io::AsyncWrite for PendingWriter {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
}
#[tokio::test]
async fn proxy_header_write_timeout_returns_false() {
let mut writer = PendingWriter;
let ok = write_proxy_header_with_timeout(&mut writer, b"PROXY UNKNOWN\r\n").await;
assert!(!ok, "Proxy header writes that never complete must time out");
}