mirror of https://github.com/telemt/telemt.git
feat(tls): add boot time timestamp constant and validation for SNI hostnames
- Introduced `BOOT_TIME_MAX_SECS` constant to define the maximum accepted boot-time timestamp. - Updated `validate_tls_handshake_at_time` to utilize the new boot time constant for timestamp validation. - Enhanced `extract_sni_from_client_hello` to validate SNI hostnames against specified criteria, rejecting invalid hostnames. - Added tests to ensure proper handling of boot time timestamps and SNI validation. feat(handshake): improve user secret decoding and ALPN enforcement - Refactored user secret decoding to provide better error handling and logging for invalid secrets. - Added tests for concurrent identical handshakes to ensure replay protection works as expected. - Implemented ALPN enforcement in handshake processing, rejecting unsupported protocols and allowing valid ones. fix(masking): implement timeout handling for masking operations - Added timeout handling for writing proxy headers and consuming client data in masking. - Adjusted timeout durations for testing to ensure faster feedback during unit tests. - Introduced tests to verify behavior when masking is disabled and when proxy header writes exceed the timeout. test(masking): add tests for slowloris connections and proxy header timeouts - Created tests to validate that slowloris connections are closed by consume timeout when masking is disabled. - Added a test for proxy header write timeout to ensure it returns false when the write operation does not complete.
This commit is contained in:
parent
213ce4555a
commit
e4a50f9286
|
|
@ -29,6 +29,8 @@ pub const TLS_DIGEST_HALF_LEN: usize = 16;
|
|||
/// Time skew limits for anti-replay (in seconds)
|
||||
pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before
|
||||
pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after
|
||||
/// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced.
|
||||
pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60;
|
||||
|
||||
// ============= Private Constants =============
|
||||
|
||||
|
|
@ -364,7 +366,7 @@ fn validate_tls_handshake_at_time(
|
|||
if !ignore_time_skew {
|
||||
// Allow very small timestamps (boot time instead of unix time)
|
||||
// This is a quirk in some clients that use uptime instead of real time
|
||||
let is_boot_time = timestamp < 60 * 60 * 24 * 1000; // < ~2.7 years in seconds
|
||||
let is_boot_time = timestamp < BOOT_TIME_MAX_SECS;
|
||||
if !is_boot_time {
|
||||
let time_diff = now - i64::from(timestamp);
|
||||
if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) {
|
||||
|
|
@ -563,8 +565,10 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
|
|||
if name_type == 0 && name_len > 0
|
||||
&& let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len])
|
||||
{
|
||||
if is_valid_sni_hostname(host) {
|
||||
return Some(host.to_string());
|
||||
}
|
||||
}
|
||||
sn_pos += name_len;
|
||||
}
|
||||
}
|
||||
|
|
@ -574,6 +578,35 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
|
|||
None
|
||||
}
|
||||
|
||||
fn is_valid_sni_hostname(host: &str) -> bool {
|
||||
if host.is_empty() || host.len() > 253 {
|
||||
return false;
|
||||
}
|
||||
if host.starts_with('.') || host.ends_with('.') {
|
||||
return false;
|
||||
}
|
||||
if host.parse::<std::net::IpAddr>().is_ok() {
|
||||
return false;
|
||||
}
|
||||
|
||||
for label in host.split('.') {
|
||||
if label.is_empty() || label.len() > 63 {
|
||||
return false;
|
||||
}
|
||||
if label.starts_with('-') || label.ends_with('-') {
|
||||
return false;
|
||||
}
|
||||
if !label
|
||||
.bytes()
|
||||
.all(|b| b.is_ascii_alphanumeric() || b == b'-')
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Extract ALPN protocol list from ClientHello, return in offered order.
|
||||
pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec<Vec<u8>> {
|
||||
let mut pos = 5; // after record header
|
||||
|
|
|
|||
|
|
@ -286,8 +286,8 @@ fn boot_time_timestamp_accepted_without_ignore_flag() {
|
|||
// Timestamps below the boot-time threshold are treated as client uptime,
|
||||
// not real wall-clock time. The proxy allows them regardless of skew.
|
||||
let secret = b"boot_time_test";
|
||||
// 86_400_000 / 2 is well below the boot-time threshold (~2.74 years worth of seconds).
|
||||
let boot_ts: u32 = 86_400_000 / 2;
|
||||
// Keep this safely below BOOT_TIME_MAX_SECS to assert bypass behavior.
|
||||
let boot_ts: u32 = BOOT_TIME_MAX_SECS / 2;
|
||||
let handshake = make_valid_tls_handshake(secret, boot_ts);
|
||||
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||
assert!(
|
||||
|
|
@ -611,13 +611,13 @@ fn zero_length_session_id_accepted() {
|
|||
// Boot-time threshold — exact boundary precision
|
||||
// ------------------------------------------------------------------
|
||||
|
||||
/// timestamp = 86_399_999 is the last value inside the boot-time window.
|
||||
/// timestamp = BOOT_TIME_MAX_SECS - 1 is the last value inside the boot-time window.
|
||||
/// is_boot_time = true → skew check is skipped entirely → accepted even
|
||||
/// when `now` is far from the timestamp.
|
||||
#[test]
|
||||
fn timestamp_one_below_boot_threshold_bypasses_skew_check() {
|
||||
let secret = b"boot_last_value_test";
|
||||
let ts: u32 = 86_400_000 - 1;
|
||||
let ts: u32 = BOOT_TIME_MAX_SECS - 1;
|
||||
let h = make_valid_tls_handshake(secret, ts);
|
||||
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||
|
||||
|
|
@ -625,17 +625,17 @@ fn timestamp_one_below_boot_threshold_bypasses_skew_check() {
|
|||
// Boot-time bypass must prevent the skew check from running.
|
||||
assert!(
|
||||
validate_tls_handshake_at_time(&h, &secrets, false, 0).is_some(),
|
||||
"ts=86_399_999 must bypass skew check regardless of now"
|
||||
"ts=BOOT_TIME_MAX_SECS-1 must bypass skew check regardless of now"
|
||||
);
|
||||
}
|
||||
|
||||
/// timestamp = 86_400_000 is the first value outside the boot-time window.
|
||||
/// timestamp = BOOT_TIME_MAX_SECS is the first value outside the boot-time window.
|
||||
/// is_boot_time = false → skew check IS applied. Two sub-cases confirm this:
|
||||
/// once with now chosen so the skew passes (accepted) and once where it fails.
|
||||
#[test]
|
||||
fn timestamp_at_boot_threshold_triggers_skew_check() {
|
||||
let secret = b"boot_exact_value_test";
|
||||
let ts: u32 = 86_400_000;
|
||||
let ts: u32 = BOOT_TIME_MAX_SECS;
|
||||
let h = make_valid_tls_handshake(secret, ts);
|
||||
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||
|
||||
|
|
@ -643,14 +643,14 @@ fn timestamp_at_boot_threshold_triggers_skew_check() {
|
|||
let now_valid: i64 = ts as i64 + 50;
|
||||
assert!(
|
||||
validate_tls_handshake_at_time(&h, &secrets, false, now_valid).is_some(),
|
||||
"ts=86_400_000 within skew window must be accepted via skew check"
|
||||
"ts=BOOT_TIME_MAX_SECS within skew window must be accepted via skew check"
|
||||
);
|
||||
|
||||
// now = 0 → time_diff = -86_400_000, outside window → rejected.
|
||||
// If the boot-time bypass were wrongly applied here this would pass.
|
||||
assert!(
|
||||
validate_tls_handshake_at_time(&h, &secrets, false, 0).is_none(),
|
||||
"ts=86_400_000 far from now must be rejected — no boot-time bypass"
|
||||
"ts=BOOT_TIME_MAX_SECS far from now must be rejected — no boot-time bypass"
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -675,7 +675,7 @@ fn u32_max_timestamp_accepted_with_ignore_time_skew() {
|
|||
);
|
||||
}
|
||||
|
||||
/// u32::MAX > 86_400_000 so the skew check runs. With any realistic `now`
|
||||
/// u32::MAX > BOOT_TIME_MAX_SECS so the skew check runs. With any realistic `now`
|
||||
/// (~1.7 billion), time_diff = now - u32::MAX is deeply negative — far outside
|
||||
/// [-1200, 600] — so the handshake must be rejected without overflow.
|
||||
#[test]
|
||||
|
|
@ -1109,6 +1109,25 @@ fn extract_sni_rejects_zero_length_host_name() {
|
|||
assert!(extract_sni_from_client_hello(&ch).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_sni_rejects_raw_ipv4_literals() {
|
||||
let ch = build_client_hello_with_exts(Vec::new(), "203.0.113.10");
|
||||
assert!(extract_sni_from_client_hello(&ch).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_sni_rejects_invalid_label_characters() {
|
||||
let ch = build_client_hello_with_exts(Vec::new(), "exa_mple.com");
|
||||
assert!(extract_sni_from_client_hello(&ch).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_sni_rejects_oversized_label() {
|
||||
let oversized = format!("{}.example.com", "a".repeat(64));
|
||||
let ch = build_client_hello_with_exts(Vec::new(), &oversized);
|
||||
assert!(extract_sni_from_client_hello(&ch).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_sni_rejects_when_extension_block_is_truncated() {
|
||||
let mut ext_blob = Vec::new();
|
||||
|
|
|
|||
|
|
@ -92,6 +92,71 @@ async fn short_tls_probe_is_masked_through_client_pipeline() {
|
|||
accept_task.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn partial_tls_header_stall_triggers_handshake_timeout() {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.beobachten = false;
|
||||
cfg.timeouts.client_handshake = 1;
|
||||
|
||||
let config = Arc::new(cfg);
|
||||
let stats = Arc::new(Stats::new());
|
||||
let upstream_manager = Arc::new(UpstreamManager::new(
|
||||
vec![UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct {
|
||||
interface: None,
|
||||
bind_addresses: None,
|
||||
},
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
selected_scope: String::new(),
|
||||
}],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
));
|
||||
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
let beobachten = Arc::new(BeobachtenStore::new());
|
||||
|
||||
let (server_side, mut client_side) = duplex(4096);
|
||||
let peer: SocketAddr = "198.51.100.170:55201".parse().unwrap();
|
||||
|
||||
let handler = tokio::spawn(handle_client_stream(
|
||||
server_side,
|
||||
peer,
|
||||
config,
|
||||
stats,
|
||||
upstream_manager,
|
||||
replay_checker,
|
||||
buffer_pool,
|
||||
rng,
|
||||
None,
|
||||
route_runtime,
|
||||
None,
|
||||
ip_tracker,
|
||||
beobachten,
|
||||
false,
|
||||
));
|
||||
|
||||
client_side
|
||||
.write_all(&[0x16, 0x03, 0x01, 0x02, 0x00])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let result = tokio::time::timeout(Duration::from_secs(3), handler)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert!(matches!(result, Err(ProxyError::TgHandshakeTimeout)));
|
||||
}
|
||||
|
||||
fn make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len: usize) -> Vec<u8> {
|
||||
assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header");
|
||||
|
||||
|
|
@ -122,6 +187,66 @@ fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32) -> Vec<u8> {
|
|||
make_valid_tls_client_hello_with_len(secret, timestamp, 600)
|
||||
}
|
||||
|
||||
fn make_valid_tls_client_hello_with_alpn(
|
||||
secret: &[u8],
|
||||
timestamp: u32,
|
||||
alpn_protocols: &[&[u8]],
|
||||
) -> Vec<u8> {
|
||||
let mut body = Vec::new();
|
||||
body.extend_from_slice(&TLS_VERSION);
|
||||
body.extend_from_slice(&[0u8; 32]);
|
||||
body.push(32);
|
||||
body.extend_from_slice(&[0x42u8; 32]);
|
||||
body.extend_from_slice(&2u16.to_be_bytes());
|
||||
body.extend_from_slice(&[0x13, 0x01]);
|
||||
body.push(1);
|
||||
body.push(0);
|
||||
|
||||
let mut ext_blob = Vec::new();
|
||||
if !alpn_protocols.is_empty() {
|
||||
let mut alpn_list = Vec::new();
|
||||
for proto in alpn_protocols {
|
||||
alpn_list.push(proto.len() as u8);
|
||||
alpn_list.extend_from_slice(proto);
|
||||
}
|
||||
|
||||
let mut alpn_data = Vec::new();
|
||||
alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
|
||||
alpn_data.extend_from_slice(&alpn_list);
|
||||
|
||||
ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
|
||||
ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
|
||||
ext_blob.extend_from_slice(&alpn_data);
|
||||
}
|
||||
|
||||
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
|
||||
body.extend_from_slice(&ext_blob);
|
||||
|
||||
let mut handshake = Vec::new();
|
||||
handshake.push(0x01);
|
||||
let body_len = (body.len() as u32).to_be_bytes();
|
||||
handshake.extend_from_slice(&body_len[1..4]);
|
||||
handshake.extend_from_slice(&body);
|
||||
|
||||
let mut record = Vec::new();
|
||||
record.push(0x16);
|
||||
record.extend_from_slice(&[0x03, 0x01]);
|
||||
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
|
||||
record.extend_from_slice(&handshake);
|
||||
|
||||
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
|
||||
let computed = sha256_hmac(secret, &record);
|
||||
let mut digest = computed;
|
||||
let ts = timestamp.to_le_bytes();
|
||||
for i in 0..4 {
|
||||
digest[28 + i] ^= ts[i];
|
||||
}
|
||||
|
||||
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
|
||||
.copy_from_slice(&digest);
|
||||
record
|
||||
}
|
||||
|
||||
fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
|
||||
let mut record = Vec::with_capacity(5 + payload.len());
|
||||
record.push(0x17);
|
||||
|
|
@ -439,6 +564,304 @@ async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() {
|
|||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn alpn_mismatch_tls_probe_is_masked_through_client_pipeline() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let backend_addr = listener.local_addr().unwrap();
|
||||
|
||||
let secret = [0x66u8; 16];
|
||||
let probe = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
|
||||
let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
|
||||
|
||||
let accept_task = tokio::spawn({
|
||||
let probe = probe.clone();
|
||||
let backend_reply = backend_reply.clone();
|
||||
async move {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
let mut got = vec![0u8; probe.len()];
|
||||
stream.read_exact(&mut got).await.unwrap();
|
||||
assert_eq!(got, probe);
|
||||
stream.write_all(&backend_reply).await.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.beobachten = false;
|
||||
cfg.censorship.mask = true;
|
||||
cfg.censorship.mask_unix_sock = None;
|
||||
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
cfg.censorship.mask_port = backend_addr.port();
|
||||
cfg.censorship.mask_proxy_protocol = 0;
|
||||
cfg.censorship.alpn_enforce = true;
|
||||
cfg.access.ignore_time_skew = true;
|
||||
cfg.access
|
||||
.users
|
||||
.insert("user".to_string(), "66666666666666666666666666666666".to_string());
|
||||
|
||||
let config = Arc::new(cfg);
|
||||
let stats = Arc::new(Stats::new());
|
||||
let upstream_manager = Arc::new(UpstreamManager::new(
|
||||
vec![UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct {
|
||||
interface: None,
|
||||
bind_addresses: None,
|
||||
},
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
selected_scope: String::new(),
|
||||
}],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
));
|
||||
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
let beobachten = Arc::new(BeobachtenStore::new());
|
||||
|
||||
let (server_side, mut client_side) = duplex(8192);
|
||||
let peer: SocketAddr = "198.51.100.66:55211".parse().unwrap();
|
||||
|
||||
let handler = tokio::spawn(handle_client_stream(
|
||||
server_side,
|
||||
peer,
|
||||
config,
|
||||
stats,
|
||||
upstream_manager,
|
||||
replay_checker,
|
||||
buffer_pool,
|
||||
rng,
|
||||
None,
|
||||
route_runtime,
|
||||
None,
|
||||
ip_tracker,
|
||||
beobachten,
|
||||
false,
|
||||
));
|
||||
|
||||
client_side.write_all(&probe).await.unwrap();
|
||||
let mut observed = vec![0u8; backend_reply.len()];
|
||||
client_side.read_exact(&mut observed).await.unwrap();
|
||||
assert_eq!(observed, backend_reply);
|
||||
|
||||
drop(client_side);
|
||||
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
accept_task.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_hmac_tls_probe_is_masked_through_client_pipeline() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let backend_addr = listener.local_addr().unwrap();
|
||||
|
||||
let secret = [0x77u8; 16];
|
||||
let mut probe = make_valid_tls_client_hello(&secret, 0);
|
||||
probe[tls::TLS_DIGEST_POS] ^= 0x01;
|
||||
|
||||
let accept_task = tokio::spawn({
|
||||
let probe = probe.clone();
|
||||
async move {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
let mut got = vec![0u8; probe.len()];
|
||||
stream.read_exact(&mut got).await.unwrap();
|
||||
assert_eq!(got, probe);
|
||||
}
|
||||
});
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.beobachten = false;
|
||||
cfg.censorship.mask = true;
|
||||
cfg.censorship.mask_unix_sock = None;
|
||||
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
cfg.censorship.mask_port = backend_addr.port();
|
||||
cfg.censorship.mask_proxy_protocol = 0;
|
||||
cfg.access.ignore_time_skew = true;
|
||||
cfg.access
|
||||
.users
|
||||
.insert("user".to_string(), "77777777777777777777777777777777".to_string());
|
||||
|
||||
let config = Arc::new(cfg);
|
||||
let stats = Arc::new(Stats::new());
|
||||
let upstream_manager = Arc::new(UpstreamManager::new(
|
||||
vec![UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct {
|
||||
interface: None,
|
||||
bind_addresses: None,
|
||||
},
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
selected_scope: String::new(),
|
||||
}],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
));
|
||||
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
let beobachten = Arc::new(BeobachtenStore::new());
|
||||
|
||||
let (server_side, mut client_side) = duplex(8192);
|
||||
let peer: SocketAddr = "198.51.100.77:55212".parse().unwrap();
|
||||
|
||||
let handler = tokio::spawn(handle_client_stream(
|
||||
server_side,
|
||||
peer,
|
||||
config,
|
||||
stats,
|
||||
upstream_manager,
|
||||
replay_checker,
|
||||
buffer_pool,
|
||||
rng,
|
||||
None,
|
||||
route_runtime,
|
||||
None,
|
||||
ip_tracker,
|
||||
beobachten,
|
||||
false,
|
||||
));
|
||||
|
||||
client_side.write_all(&probe).await.unwrap();
|
||||
tokio::time::timeout(Duration::from_secs(3), accept_task)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
drop(client_side);
|
||||
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn burst_invalid_tls_probes_are_masked_verbatim() {
|
||||
const N: usize = 12;
|
||||
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let backend_addr = listener.local_addr().unwrap();
|
||||
|
||||
let secret = [0x88u8; 16];
|
||||
let mut probe = make_valid_tls_client_hello(&secret, 0);
|
||||
probe[tls::TLS_DIGEST_POS + 1] ^= 0x01;
|
||||
|
||||
let accept_task = tokio::spawn({
|
||||
let probe = probe.clone();
|
||||
async move {
|
||||
for _ in 0..N {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
let mut got = vec![0u8; probe.len()];
|
||||
stream.read_exact(&mut got).await.unwrap();
|
||||
assert_eq!(got, probe);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let mut handlers = Vec::with_capacity(N);
|
||||
for i in 0..N {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.beobachten = false;
|
||||
cfg.censorship.mask = true;
|
||||
cfg.censorship.mask_unix_sock = None;
|
||||
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
cfg.censorship.mask_port = backend_addr.port();
|
||||
cfg.censorship.mask_proxy_protocol = 0;
|
||||
cfg.access.ignore_time_skew = true;
|
||||
cfg.access
|
||||
.users
|
||||
.insert("user".to_string(), "88888888888888888888888888888888".to_string());
|
||||
|
||||
let config = Arc::new(cfg);
|
||||
let stats = Arc::new(Stats::new());
|
||||
let upstream_manager = Arc::new(UpstreamManager::new(
|
||||
vec![UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct {
|
||||
interface: None,
|
||||
bind_addresses: None,
|
||||
},
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
selected_scope: String::new(),
|
||||
}],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
));
|
||||
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
let beobachten = Arc::new(BeobachtenStore::new());
|
||||
|
||||
let (server_side, mut client_side) = duplex(8192);
|
||||
let peer: SocketAddr = format!("198.51.100.{}:{}", 100 + i, 56000 + i)
|
||||
.parse()
|
||||
.unwrap();
|
||||
let probe_bytes = probe.clone();
|
||||
|
||||
let h = tokio::spawn(async move {
|
||||
let handler = tokio::spawn(handle_client_stream(
|
||||
server_side,
|
||||
peer,
|
||||
config,
|
||||
stats,
|
||||
upstream_manager,
|
||||
replay_checker,
|
||||
buffer_pool,
|
||||
rng,
|
||||
None,
|
||||
route_runtime,
|
||||
None,
|
||||
ip_tracker,
|
||||
beobachten,
|
||||
false,
|
||||
));
|
||||
|
||||
client_side.write_all(&probe_bytes).await.unwrap();
|
||||
drop(client_side);
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(3), handler)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
});
|
||||
handlers.push(h);
|
||||
}
|
||||
|
||||
for h in handlers {
|
||||
tokio::time::timeout(Duration::from_secs(5), h)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(5), accept_task)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unexpected_eof_is_classified_without_string_matching() {
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@
|
|||
#![allow(dead_code)]
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tracing::{debug, warn, trace};
|
||||
|
|
@ -20,6 +22,56 @@ use crate::config::ProxyConfig;
|
|||
use crate::tls_front::{TlsFrontCache, emulator};
|
||||
|
||||
const ACCESS_SECRET_BYTES: usize = 16;
|
||||
static INVALID_SECRET_WARNED: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
|
||||
|
||||
fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option<usize>) {
|
||||
let key = format!("{}:{}", name, reason);
|
||||
let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new()));
|
||||
let should_warn = match warned.lock() {
|
||||
Ok(mut guard) => guard.insert(key),
|
||||
Err(_) => true,
|
||||
};
|
||||
|
||||
if !should_warn {
|
||||
return;
|
||||
}
|
||||
|
||||
match got {
|
||||
Some(actual) => {
|
||||
warn!(
|
||||
user = %name,
|
||||
expected = expected,
|
||||
got = actual,
|
||||
"Skipping user: access secret has unexpected length"
|
||||
);
|
||||
}
|
||||
None => {
|
||||
warn!(
|
||||
user = %name,
|
||||
"Skipping user: access secret is not valid hex"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_user_secret(name: &str, secret_hex: &str) -> Option<Vec<u8>> {
|
||||
match hex::decode(secret_hex) {
|
||||
Ok(bytes) if bytes.len() == ACCESS_SECRET_BYTES => Some(bytes),
|
||||
Ok(bytes) => {
|
||||
warn_invalid_secret_once(
|
||||
name,
|
||||
"invalid_length",
|
||||
ACCESS_SECRET_BYTES,
|
||||
Some(bytes.len()),
|
||||
);
|
||||
None
|
||||
}
|
||||
Err(_) => {
|
||||
warn_invalid_secret_once(name, "invalid_hex", ACCESS_SECRET_BYTES, None);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decide whether a client-supplied proto tag is allowed given the configured
|
||||
// proxy modes and the transport that carried the handshake.
|
||||
|
|
@ -51,8 +103,7 @@ fn decode_user_secrets(
|
|||
|
||||
if let Some(preferred) = preferred_user
|
||||
&& let Some(secret_hex) = config.access.users.get(preferred)
|
||||
&& let Ok(bytes) = hex::decode(secret_hex)
|
||||
&& bytes.len() == ACCESS_SECRET_BYTES
|
||||
&& let Some(bytes) = decode_user_secret(preferred, secret_hex)
|
||||
{
|
||||
secrets.push((preferred.to_string(), bytes));
|
||||
}
|
||||
|
|
@ -61,9 +112,7 @@ fn decode_user_secrets(
|
|||
if preferred_user.is_some_and(|preferred| preferred == name.as_str()) {
|
||||
continue;
|
||||
}
|
||||
if let Ok(bytes) = hex::decode(secret_hex)
|
||||
&& bytes.len() == ACCESS_SECRET_BYTES
|
||||
{
|
||||
if let Some(bytes) = decode_user_secret(name, secret_hex) {
|
||||
secrets.push((name.clone(), bytes));
|
||||
}
|
||||
}
|
||||
|
|
@ -193,6 +242,9 @@ where
|
|||
Some(b"h2".to_vec())
|
||||
} else if alpn_list.iter().any(|p| p == b"http/1.1") {
|
||||
Some(b"http/1.1".to_vec())
|
||||
} else if !alpn_list.is_empty() {
|
||||
debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback");
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
use super::*;
|
||||
use crate::crypto::sha256_hmac;
|
||||
use std::time::Duration;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
|
||||
let session_id_len: usize = 32;
|
||||
|
|
@ -22,6 +23,64 @@ fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
|
|||
handshake
|
||||
}
|
||||
|
||||
fn make_valid_tls_client_hello_with_alpn(
|
||||
secret: &[u8],
|
||||
timestamp: u32,
|
||||
alpn_protocols: &[&[u8]],
|
||||
) -> Vec<u8> {
|
||||
let mut body = Vec::new();
|
||||
body.extend_from_slice(&TLS_VERSION);
|
||||
body.extend_from_slice(&[0u8; 32]);
|
||||
body.push(32);
|
||||
body.extend_from_slice(&[0x42u8; 32]);
|
||||
body.extend_from_slice(&2u16.to_be_bytes());
|
||||
body.extend_from_slice(&[0x13, 0x01]);
|
||||
body.push(1);
|
||||
body.push(0);
|
||||
|
||||
let mut ext_blob = Vec::new();
|
||||
if !alpn_protocols.is_empty() {
|
||||
let mut alpn_list = Vec::new();
|
||||
for proto in alpn_protocols {
|
||||
alpn_list.push(proto.len() as u8);
|
||||
alpn_list.extend_from_slice(proto);
|
||||
}
|
||||
let mut alpn_data = Vec::new();
|
||||
alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
|
||||
alpn_data.extend_from_slice(&alpn_list);
|
||||
|
||||
ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
|
||||
ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
|
||||
ext_blob.extend_from_slice(&alpn_data);
|
||||
}
|
||||
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
|
||||
body.extend_from_slice(&ext_blob);
|
||||
|
||||
let mut handshake = Vec::new();
|
||||
handshake.push(0x01);
|
||||
let body_len = (body.len() as u32).to_be_bytes();
|
||||
handshake.extend_from_slice(&body_len[1..4]);
|
||||
handshake.extend_from_slice(&body);
|
||||
|
||||
let mut record = Vec::new();
|
||||
record.push(TLS_RECORD_HANDSHAKE);
|
||||
record.extend_from_slice(&[0x03, 0x01]);
|
||||
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
|
||||
record.extend_from_slice(&handshake);
|
||||
|
||||
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
|
||||
let computed = sha256_hmac(secret, &record);
|
||||
let mut digest = computed;
|
||||
let ts = timestamp.to_le_bytes();
|
||||
for i in 0..4 {
|
||||
digest[28 + i] ^= ts[i];
|
||||
}
|
||||
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
|
||||
.copy_from_slice(&digest);
|
||||
|
||||
record
|
||||
}
|
||||
|
||||
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.access.users.clear();
|
||||
|
|
@ -251,6 +310,51 @@ async fn tls_replay_second_identical_handshake_is_rejected() {
|
|||
assert!(matches!(second, HandshakeResult::BadClient { .. }));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_replay_concurrent_identical_handshake_allows_exactly_one_success() {
|
||||
let secret = [0x77u8; 16];
|
||||
let config = Arc::new(test_config_with_secret_hex("77777777777777777777777777777777"));
|
||||
let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60)));
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
let handshake = Arc::new(make_valid_tls_handshake(&secret, 0));
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
for _ in 0..50 {
|
||||
let config = config.clone();
|
||||
let replay_checker = replay_checker.clone();
|
||||
let rng = rng.clone();
|
||||
let handshake = handshake.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
handle_tls_handshake(
|
||||
&handshake,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
"127.0.0.1:45000".parse().unwrap(),
|
||||
&config,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}));
|
||||
}
|
||||
|
||||
let mut success_count = 0usize;
|
||||
for task in tasks {
|
||||
let result = task.await.unwrap();
|
||||
if matches!(result, HandshakeResult::Success(_)) {
|
||||
success_count += 1;
|
||||
} else {
|
||||
assert!(matches!(result, HandshakeResult::BadClient { .. }));
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
success_count, 1,
|
||||
"Concurrent replay attempts must allow exactly one successful handshake"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn invalid_tls_probe_does_not_pollute_replay_cache() {
|
||||
let config = test_config_with_secret_hex("11111111111111111111111111111111");
|
||||
|
|
@ -387,6 +491,158 @@ async fn mixed_secret_lengths_keep_valid_user_authenticating() {
|
|||
assert!(matches!(result, HandshakeResult::Success(_)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn alpn_enforce_rejects_unsupported_client_alpn() {
|
||||
let secret = [0x33u8; 16];
|
||||
let mut config = test_config_with_secret_hex("33333333333333333333333333333333");
|
||||
config.censorship.alpn_enforce = true;
|
||||
|
||||
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
|
||||
let rng = SecureRandom::new();
|
||||
let peer: SocketAddr = "127.0.0.1:44327".parse().unwrap();
|
||||
let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
|
||||
|
||||
let result = handle_tls_handshake(
|
||||
&handshake,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
peer,
|
||||
&config,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, HandshakeResult::BadClient { .. }));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn alpn_enforce_accepts_h2() {
|
||||
let secret = [0x44u8; 16];
|
||||
let mut config = test_config_with_secret_hex("44444444444444444444444444444444");
|
||||
config.censorship.alpn_enforce = true;
|
||||
|
||||
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
|
||||
let rng = SecureRandom::new();
|
||||
let peer: SocketAddr = "127.0.0.1:44328".parse().unwrap();
|
||||
let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h2", b"h3"]);
|
||||
|
||||
let result = handle_tls_handshake(
|
||||
&handshake,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
peer,
|
||||
&config,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, HandshakeResult::Success(_)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn malformed_tls_classes_complete_within_bounded_time() {
|
||||
let secret = [0x55u8; 16];
|
||||
let mut config = test_config_with_secret_hex("55555555555555555555555555555555");
|
||||
config.censorship.alpn_enforce = true;
|
||||
|
||||
let replay_checker = ReplayChecker::new(512, Duration::from_secs(60));
|
||||
let rng = SecureRandom::new();
|
||||
let peer: SocketAddr = "127.0.0.1:44329".parse().unwrap();
|
||||
|
||||
let too_short = vec![0x16, 0x03, 0x01];
|
||||
|
||||
let mut bad_hmac = make_valid_tls_handshake(&secret, 0);
|
||||
bad_hmac[tls::TLS_DIGEST_POS] ^= 0x01;
|
||||
|
||||
let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
|
||||
|
||||
for probe in [too_short, bad_hmac, alpn_mismatch] {
|
||||
let result = tokio::time::timeout(
|
||||
Duration::from_millis(200),
|
||||
handle_tls_handshake(
|
||||
&probe,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
peer,
|
||||
&config,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("Malformed TLS classes must be rejected within bounded time");
|
||||
|
||||
assert!(matches!(result, HandshakeResult::BadClient { .. }));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn malformed_tls_classes_share_close_latency_buckets() {
|
||||
const ITER: usize = 24;
|
||||
const BUCKET_MS: u128 = 10;
|
||||
|
||||
let secret = [0x99u8; 16];
|
||||
let mut config = test_config_with_secret_hex("99999999999999999999999999999999");
|
||||
config.censorship.alpn_enforce = true;
|
||||
|
||||
let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60));
|
||||
let rng = SecureRandom::new();
|
||||
let peer: SocketAddr = "127.0.0.1:44330".parse().unwrap();
|
||||
|
||||
let too_short = vec![0x16, 0x03, 0x01];
|
||||
|
||||
let mut bad_hmac = make_valid_tls_handshake(&secret, 0);
|
||||
bad_hmac[tls::TLS_DIGEST_POS + 1] ^= 0x01;
|
||||
|
||||
let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
|
||||
|
||||
let mut class_means_ms = Vec::new();
|
||||
for probe in [too_short, bad_hmac, alpn_mismatch] {
|
||||
let mut sum_micros: u128 = 0;
|
||||
for _ in 0..ITER {
|
||||
let started = Instant::now();
|
||||
let result = handle_tls_handshake(
|
||||
&probe,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
peer,
|
||||
&config,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
let elapsed = started.elapsed();
|
||||
assert!(matches!(result, HandshakeResult::BadClient { .. }));
|
||||
sum_micros += elapsed.as_micros();
|
||||
}
|
||||
|
||||
class_means_ms.push(sum_micros / ITER as u128 / 1_000);
|
||||
}
|
||||
|
||||
let min_bucket = class_means_ms
|
||||
.iter()
|
||||
.map(|ms| ms / BUCKET_MS)
|
||||
.min()
|
||||
.unwrap();
|
||||
let max_bucket = class_means_ms
|
||||
.iter()
|
||||
.map(|ms| ms / BUCKET_MS)
|
||||
.max()
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
max_bucket <= min_bucket + 1,
|
||||
"Malformed TLS classes diverged across latency buckets: means_ms={:?}",
|
||||
class_means_ms
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn secure_tag_requires_tls_mode_on_tls_transport() {
|
||||
let mut config = ProxyConfig::default();
|
||||
|
|
|
|||
|
|
@ -14,12 +14,41 @@ use crate::network::dns_overrides::resolve_socket_addr;
|
|||
use crate::stats::beobachten::BeobachtenStore;
|
||||
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
|
||||
|
||||
#[cfg(not(test))]
|
||||
const MASK_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
#[cfg(test)]
|
||||
const MASK_TIMEOUT: Duration = Duration::from_millis(50);
|
||||
/// Maximum duration for the entire masking relay.
|
||||
/// Limits resource consumption from slow-loris attacks and port scanners.
|
||||
#[cfg(not(test))]
|
||||
const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60);
|
||||
#[cfg(test)]
|
||||
const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200);
|
||||
const MASK_BUFFER_SIZE: usize = 8192;
|
||||
|
||||
async fn write_proxy_header_with_timeout<W>(mask_write: &mut W, header: &[u8]) -> bool
|
||||
where
|
||||
W: AsyncWrite + Unpin,
|
||||
{
|
||||
match timeout(MASK_TIMEOUT, mask_write.write_all(header)).await {
|
||||
Ok(Ok(())) => true,
|
||||
Ok(Err(_)) => false,
|
||||
Err(_) => {
|
||||
debug!("Timeout writing proxy protocol header to mask backend");
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn consume_client_data_with_timeout<R>(reader: R)
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader)).await.is_err() {
|
||||
debug!("Timed out while consuming client data on masking fallback path");
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect client type based on initial data
|
||||
fn detect_client_type(data: &[u8]) -> &'static str {
|
||||
// Check for HTTP request
|
||||
|
|
@ -71,7 +100,7 @@ where
|
|||
|
||||
if !config.censorship.mask {
|
||||
// Masking disabled, just consume data
|
||||
consume_client_data(reader).await;
|
||||
consume_client_data_with_timeout(reader).await;
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -107,7 +136,7 @@ where
|
|||
}
|
||||
};
|
||||
if let Some(header) = proxy_header {
|
||||
if mask_write.write_all(&header).await.is_err() {
|
||||
if !write_proxy_header_with_timeout(&mut mask_write, &header).await {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
@ -117,11 +146,11 @@ where
|
|||
}
|
||||
Ok(Err(e)) => {
|
||||
debug!(error = %e, "Failed to connect to mask unix socket");
|
||||
consume_client_data(reader).await;
|
||||
consume_client_data_with_timeout(reader).await;
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Timeout connecting to mask unix socket");
|
||||
consume_client_data(reader).await;
|
||||
consume_client_data_with_timeout(reader).await;
|
||||
}
|
||||
}
|
||||
return;
|
||||
|
|
@ -166,7 +195,7 @@ where
|
|||
|
||||
let (mask_read, mut mask_write) = stream.into_split();
|
||||
if let Some(header) = proxy_header {
|
||||
if mask_write.write_all(&header).await.is_err() {
|
||||
if !write_proxy_header_with_timeout(&mut mask_write, &header).await {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
|
@ -176,11 +205,11 @@ where
|
|||
}
|
||||
Ok(Err(e)) => {
|
||||
debug!(error = %e, "Failed to connect to mask host");
|
||||
consume_client_data(reader).await;
|
||||
consume_client_data_with_timeout(reader).await;
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Timeout connecting to mask host");
|
||||
consume_client_data(reader).await;
|
||||
consume_client_data_with_timeout(reader).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
use super::*;
|
||||
use crate::config::ProxyConfig;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{duplex, AsyncBufReadExt, BufReader};
|
||||
use tokio::net::TcpListener;
|
||||
#[cfg(unix)]
|
||||
|
|
@ -484,3 +486,59 @@ async fn unix_socket_mask_path_forwards_probe_and_response() {
|
|||
accept_task.await.unwrap();
|
||||
let _ = std::fs::remove_file(sock_path);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mask_disabled_slowloris_connection_is_closed_by_consume_timeout() {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.general.beobachten = false;
|
||||
config.censorship.mask = false;
|
||||
|
||||
let peer: SocketAddr = "198.51.100.33:45455".parse().unwrap();
|
||||
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
|
||||
let (_client_reader_side, client_reader) = duplex(256);
|
||||
let (_client_visible_reader, client_visible_writer) = duplex(256);
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
handle_bad_client(
|
||||
client_reader,
|
||||
client_visible_writer,
|
||||
b"slowloris",
|
||||
peer,
|
||||
local_addr,
|
||||
&config,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
timeout(Duration::from_secs(1), task).await.unwrap().unwrap();
|
||||
}
|
||||
|
||||
struct PendingWriter;
|
||||
|
||||
impl tokio::io::AsyncWrite for PendingWriter {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn proxy_header_write_timeout_returns_false() {
|
||||
let mut writer = PendingWriter;
|
||||
let ok = write_proxy_header_with_timeout(&mut writer, b"PROXY UNKNOWN\r\n").await;
|
||||
assert!(!ok, "Proxy header writes that never complete must time out");
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue