From e4a50f9286a8a5faba7725c90d5bbb6fad4e4234 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Mon, 16 Mar 2026 21:37:59 +0400 Subject: [PATCH] feat(tls): add boot time timestamp constant and validation for SNI hostnames - Introduced `BOOT_TIME_MAX_SECS` constant to define the maximum accepted boot-time timestamp. - Updated `validate_tls_handshake_at_time` to utilize the new boot time constant for timestamp validation. - Enhanced `extract_sni_from_client_hello` to validate SNI hostnames against specified criteria, rejecting invalid hostnames. - Added tests to ensure proper handling of boot time timestamps and SNI validation. feat(handshake): improve user secret decoding and ALPN enforcement - Refactored user secret decoding to provide better error handling and logging for invalid secrets. - Added tests for concurrent identical handshakes to ensure replay protection works as expected. - Implemented ALPN enforcement in handshake processing, rejecting unsupported protocols and allowing valid ones. fix(masking): implement timeout handling for masking operations - Added timeout handling for writing proxy headers and consuming client data in masking. - Adjusted timeout durations for testing to ensure faster feedback during unit tests. - Introduced tests to verify behavior when masking is disabled and when proxy header writes exceed the timeout. test(masking): add tests for slowloris connections and proxy header timeouts - Created tests to validate that slowloris connections are closed by consume timeout when masking is disabled. - Added a test for proxy header write timeout to ensure it returns false when the write operation does not complete. --- src/protocol/tls.rs | 37 ++- src/protocol/tls_security_tests.rs | 39 ++- src/proxy/client_security_tests.rs | 423 ++++++++++++++++++++++++++ src/proxy/handshake.rs | 62 +++- src/proxy/handshake_security_tests.rs | 258 +++++++++++++++- src/proxy/masking.rs | 43 ++- src/proxy/masking_security_tests.rs | 58 ++++ 7 files changed, 895 insertions(+), 25 deletions(-) diff --git a/src/protocol/tls.rs b/src/protocol/tls.rs index 33d28c4..5a5ef21 100644 --- a/src/protocol/tls.rs +++ b/src/protocol/tls.rs @@ -29,6 +29,8 @@ pub const TLS_DIGEST_HALF_LEN: usize = 16; /// Time skew limits for anti-replay (in seconds) pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after +/// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced. +pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60; // ============= Private Constants ============= @@ -364,7 +366,7 @@ fn validate_tls_handshake_at_time( if !ignore_time_skew { // Allow very small timestamps (boot time instead of unix time) // This is a quirk in some clients that use uptime instead of real time - let is_boot_time = timestamp < 60 * 60 * 24 * 1000; // < ~2.7 years in seconds + let is_boot_time = timestamp < BOOT_TIME_MAX_SECS; if !is_boot_time { let time_diff = now - i64::from(timestamp); if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) { @@ -563,7 +565,9 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { if name_type == 0 && name_len > 0 && let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len]) { - return Some(host.to_string()); + if is_valid_sni_hostname(host) { + return Some(host.to_string()); + } } sn_pos += name_len; } @@ -574,6 +578,35 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option { None } +fn is_valid_sni_hostname(host: &str) -> bool { + if host.is_empty() || host.len() > 253 { + return false; + } + if host.starts_with('.') || host.ends_with('.') { + return false; + } + if host.parse::().is_ok() { + return false; + } + + for label in host.split('.') { + if label.is_empty() || label.len() > 63 { + return false; + } + if label.starts_with('-') || label.ends_with('-') { + return false; + } + if !label + .bytes() + .all(|b| b.is_ascii_alphanumeric() || b == b'-') + { + return false; + } + } + + true +} + /// Extract ALPN protocol list from ClientHello, return in offered order. pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec> { let mut pos = 5; // after record header diff --git a/src/protocol/tls_security_tests.rs b/src/protocol/tls_security_tests.rs index 476f24a..4372af8 100644 --- a/src/protocol/tls_security_tests.rs +++ b/src/protocol/tls_security_tests.rs @@ -286,8 +286,8 @@ fn boot_time_timestamp_accepted_without_ignore_flag() { // Timestamps below the boot-time threshold are treated as client uptime, // not real wall-clock time. The proxy allows them regardless of skew. let secret = b"boot_time_test"; - // 86_400_000 / 2 is well below the boot-time threshold (~2.74 years worth of seconds). - let boot_ts: u32 = 86_400_000 / 2; + // Keep this safely below BOOT_TIME_MAX_SECS to assert bypass behavior. + let boot_ts: u32 = BOOT_TIME_MAX_SECS / 2; let handshake = make_valid_tls_handshake(secret, boot_ts); let secrets = vec![("u".to_string(), secret.to_vec())]; assert!( @@ -611,13 +611,13 @@ fn zero_length_session_id_accepted() { // Boot-time threshold — exact boundary precision // ------------------------------------------------------------------ -/// timestamp = 86_399_999 is the last value inside the boot-time window. +/// timestamp = BOOT_TIME_MAX_SECS - 1 is the last value inside the boot-time window. /// is_boot_time = true → skew check is skipped entirely → accepted even /// when `now` is far from the timestamp. #[test] fn timestamp_one_below_boot_threshold_bypasses_skew_check() { let secret = b"boot_last_value_test"; - let ts: u32 = 86_400_000 - 1; + let ts: u32 = BOOT_TIME_MAX_SECS - 1; let h = make_valid_tls_handshake(secret, ts); let secrets = vec![("u".to_string(), secret.to_vec())]; @@ -625,17 +625,17 @@ fn timestamp_one_below_boot_threshold_bypasses_skew_check() { // Boot-time bypass must prevent the skew check from running. assert!( validate_tls_handshake_at_time(&h, &secrets, false, 0).is_some(), - "ts=86_399_999 must bypass skew check regardless of now" + "ts=BOOT_TIME_MAX_SECS-1 must bypass skew check regardless of now" ); } -/// timestamp = 86_400_000 is the first value outside the boot-time window. +/// timestamp = BOOT_TIME_MAX_SECS is the first value outside the boot-time window. /// is_boot_time = false → skew check IS applied. Two sub-cases confirm this: /// once with now chosen so the skew passes (accepted) and once where it fails. #[test] fn timestamp_at_boot_threshold_triggers_skew_check() { let secret = b"boot_exact_value_test"; - let ts: u32 = 86_400_000; + let ts: u32 = BOOT_TIME_MAX_SECS; let h = make_valid_tls_handshake(secret, ts); let secrets = vec![("u".to_string(), secret.to_vec())]; @@ -643,14 +643,14 @@ fn timestamp_at_boot_threshold_triggers_skew_check() { let now_valid: i64 = ts as i64 + 50; assert!( validate_tls_handshake_at_time(&h, &secrets, false, now_valid).is_some(), - "ts=86_400_000 within skew window must be accepted via skew check" + "ts=BOOT_TIME_MAX_SECS within skew window must be accepted via skew check" ); // now = 0 → time_diff = -86_400_000, outside window → rejected. // If the boot-time bypass were wrongly applied here this would pass. assert!( validate_tls_handshake_at_time(&h, &secrets, false, 0).is_none(), - "ts=86_400_000 far from now must be rejected — no boot-time bypass" + "ts=BOOT_TIME_MAX_SECS far from now must be rejected — no boot-time bypass" ); } @@ -675,7 +675,7 @@ fn u32_max_timestamp_accepted_with_ignore_time_skew() { ); } -/// u32::MAX > 86_400_000 so the skew check runs. With any realistic `now` +/// u32::MAX > BOOT_TIME_MAX_SECS so the skew check runs. With any realistic `now` /// (~1.7 billion), time_diff = now - u32::MAX is deeply negative — far outside /// [-1200, 600] — so the handshake must be rejected without overflow. #[test] @@ -1109,6 +1109,25 @@ fn extract_sni_rejects_zero_length_host_name() { assert!(extract_sni_from_client_hello(&ch).is_none()); } +#[test] +fn extract_sni_rejects_raw_ipv4_literals() { + let ch = build_client_hello_with_exts(Vec::new(), "203.0.113.10"); + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_sni_rejects_invalid_label_characters() { + let ch = build_client_hello_with_exts(Vec::new(), "exa_mple.com"); + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + +#[test] +fn extract_sni_rejects_oversized_label() { + let oversized = format!("{}.example.com", "a".repeat(64)); + let ch = build_client_hello_with_exts(Vec::new(), &oversized); + assert!(extract_sni_from_client_hello(&ch).is_none()); +} + #[test] fn extract_sni_rejects_when_extension_block_is_truncated() { let mut ext_blob = Vec::new(); diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs index 46eba11..100763a 100644 --- a/src/proxy/client_security_tests.rs +++ b/src/proxy/client_security_tests.rs @@ -92,6 +92,71 @@ async fn short_tls_probe_is_masked_through_client_pipeline() { accept_task.await.unwrap(); } +#[tokio::test] +async fn partial_tls_header_stall_triggers_handshake_timeout() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "198.51.100.170:55201".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side + .write_all(&[0x16, 0x03, 0x01, 0x02, 0x00]) + .await + .unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(matches!(result, Err(ProxyError::TgHandshakeTimeout))); +} + fn make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len: usize) -> Vec { assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); @@ -122,6 +187,66 @@ fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32) -> Vec { make_valid_tls_client_hello_with_len(secret, timestamp, 600) } +fn make_valid_tls_client_hello_with_alpn( + secret: &[u8], + timestamp: u32, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(0x16); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + record +} + fn wrap_tls_application_data(payload: &[u8]) -> Vec { let mut record = Vec::with_capacity(5 + payload.len()); record.push(0x17); @@ -439,6 +564,304 @@ async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() { .unwrap(); } +#[tokio::test] +async fn alpn_mismatch_tls_probe_is_masked_through_client_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x66u8; 16]; + let probe = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.censorship.alpn_enforce = true; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "66666666666666666666666666666666".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(8192); + let peer: SocketAddr = "198.51.100.66:55211".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn invalid_hmac_tls_probe_is_masked_through_client_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x77u8; 16]; + let mut probe = make_valid_tls_client_hello(&secret, 0); + probe[tls::TLS_DIGEST_POS] ^= 0x01; + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "77777777777777777777777777777777".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(8192); + let peer: SocketAddr = "198.51.100.77:55212".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn burst_invalid_tls_probes_are_masked_verbatim() { + const N: usize = 12; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x88u8; 16]; + let mut probe = make_valid_tls_client_hello(&secret, 0); + probe[tls::TLS_DIGEST_POS + 1] ^= 0x01; + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + async move { + for _ in 0..N { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = vec![0u8; probe.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + } + } + }); + + let mut handlers = Vec::with_capacity(N); + for i in 0..N { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "88888888888888888888888888888888".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(8192); + let peer: SocketAddr = format!("198.51.100.{}:{}", 100 + i, 56000 + i) + .parse() + .unwrap(); + let probe_bytes = probe.clone(); + + let h = tokio::spawn(async move { + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe_bytes).await.unwrap(); + drop(client_side); + + tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap() + .unwrap(); + }); + handlers.push(h); + } + + for h in handlers { + tokio::time::timeout(Duration::from_secs(5), h) + .await + .unwrap() + .unwrap(); + } + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); +} + #[test] fn unexpected_eof_is_classified_without_string_matching() { let beobachten = BeobachtenStore::new(); diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index e7e4751..a97657d 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -3,7 +3,9 @@ #![allow(dead_code)] use std::net::SocketAddr; +use std::collections::HashSet; use std::sync::Arc; +use std::sync::{Mutex, OnceLock}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, warn, trace}; @@ -20,6 +22,56 @@ use crate::config::ProxyConfig; use crate::tls_front::{TlsFrontCache, emulator}; const ACCESS_SECRET_BYTES: usize = 16; +static INVALID_SECRET_WARNED: OnceLock>> = OnceLock::new(); + +fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option) { + let key = format!("{}:{}", name, reason); + let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new())); + let should_warn = match warned.lock() { + Ok(mut guard) => guard.insert(key), + Err(_) => true, + }; + + if !should_warn { + return; + } + + match got { + Some(actual) => { + warn!( + user = %name, + expected = expected, + got = actual, + "Skipping user: access secret has unexpected length" + ); + } + None => { + warn!( + user = %name, + "Skipping user: access secret is not valid hex" + ); + } + } +} + +fn decode_user_secret(name: &str, secret_hex: &str) -> Option> { + match hex::decode(secret_hex) { + Ok(bytes) if bytes.len() == ACCESS_SECRET_BYTES => Some(bytes), + Ok(bytes) => { + warn_invalid_secret_once( + name, + "invalid_length", + ACCESS_SECRET_BYTES, + Some(bytes.len()), + ); + None + } + Err(_) => { + warn_invalid_secret_once(name, "invalid_hex", ACCESS_SECRET_BYTES, None); + None + } + } +} // Decide whether a client-supplied proto tag is allowed given the configured // proxy modes and the transport that carried the handshake. @@ -51,8 +103,7 @@ fn decode_user_secrets( if let Some(preferred) = preferred_user && let Some(secret_hex) = config.access.users.get(preferred) - && let Ok(bytes) = hex::decode(secret_hex) - && bytes.len() == ACCESS_SECRET_BYTES + && let Some(bytes) = decode_user_secret(preferred, secret_hex) { secrets.push((preferred.to_string(), bytes)); } @@ -61,9 +112,7 @@ fn decode_user_secrets( if preferred_user.is_some_and(|preferred| preferred == name.as_str()) { continue; } - if let Ok(bytes) = hex::decode(secret_hex) - && bytes.len() == ACCESS_SECRET_BYTES - { + if let Some(bytes) = decode_user_secret(name, secret_hex) { secrets.push((name.clone(), bytes)); } } @@ -193,6 +242,9 @@ where Some(b"h2".to_vec()) } else if alpn_list.iter().any(|p| p == b"http/1.1") { Some(b"http/1.1".to_vec()) + } else if !alpn_list.is_empty() { + debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback"); + return HandshakeResult::BadClient { reader, writer }; } else { None } diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs index c4a5ba6..da4aa26 100644 --- a/src/proxy/handshake_security_tests.rs +++ b/src/proxy/handshake_security_tests.rs @@ -1,6 +1,7 @@ use super::*; use crate::crypto::sha256_hmac; -use std::time::Duration; +use std::sync::Arc; +use std::time::{Duration, Instant}; fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { let session_id_len: usize = 32; @@ -22,6 +23,64 @@ fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { handshake } +fn make_valid_tls_client_hello_with_alpn( + secret: &[u8], + timestamp: u32, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + + record +} + fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { let mut cfg = ProxyConfig::default(); cfg.access.users.clear(); @@ -251,6 +310,51 @@ async fn tls_replay_second_identical_handshake_is_rejected() { assert!(matches!(second, HandshakeResult::BadClient { .. })); } +#[tokio::test] +async fn tls_replay_concurrent_identical_handshake_allows_exactly_one_success() { + let secret = [0x77u8; 16]; + let config = Arc::new(test_config_with_secret_hex("77777777777777777777777777777777")); + let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); + let rng = Arc::new(SecureRandom::new()); + let handshake = Arc::new(make_valid_tls_handshake(&secret, 0)); + + let mut tasks = Vec::new(); + for _ in 0..50 { + let config = config.clone(); + let replay_checker = replay_checker.clone(); + let rng = rng.clone(); + let handshake = handshake.clone(); + tasks.push(tokio::spawn(async move { + handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + "127.0.0.1:45000".parse().unwrap(), + &config, + &replay_checker, + &rng, + None, + ) + .await + })); + } + + let mut success_count = 0usize; + for task in tasks { + let result = task.await.unwrap(); + if matches!(result, HandshakeResult::Success(_)) { + success_count += 1; + } else { + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } + } + + assert_eq!( + success_count, 1, + "Concurrent replay attempts must allow exactly one successful handshake" + ); +} + #[tokio::test] async fn invalid_tls_probe_does_not_pollute_replay_cache() { let config = test_config_with_secret_hex("11111111111111111111111111111111"); @@ -387,6 +491,158 @@ async fn mixed_secret_lengths_keep_valid_user_authenticating() { assert!(matches!(result, HandshakeResult::Success(_))); } +#[tokio::test] +async fn alpn_enforce_rejects_unsupported_client_alpn() { + let secret = [0x33u8; 16]; + let mut config = test_config_with_secret_hex("33333333333333333333333333333333"); + config.censorship.alpn_enforce = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "127.0.0.1:44327".parse().unwrap(); + let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); +} + +#[tokio::test] +async fn alpn_enforce_accepts_h2() { + let secret = [0x44u8; 16]; + let mut config = test_config_with_secret_hex("44444444444444444444444444444444"); + config.censorship.alpn_enforce = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "127.0.0.1:44328".parse().unwrap(); + let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h2", b"h3"]); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::Success(_))); +} + +#[tokio::test] +async fn malformed_tls_classes_complete_within_bounded_time() { + let secret = [0x55u8; 16]; + let mut config = test_config_with_secret_hex("55555555555555555555555555555555"); + config.censorship.alpn_enforce = true; + + let replay_checker = ReplayChecker::new(512, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "127.0.0.1:44329".parse().unwrap(); + + let too_short = vec![0x16, 0x03, 0x01]; + + let mut bad_hmac = make_valid_tls_handshake(&secret, 0); + bad_hmac[tls::TLS_DIGEST_POS] ^= 0x01; + + let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + + for probe in [too_short, bad_hmac, alpn_mismatch] { + let result = tokio::time::timeout( + Duration::from_millis(200), + handle_tls_handshake( + &probe, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ), + ) + .await + .expect("Malformed TLS classes must be rejected within bounded time"); + + assert!(matches!(result, HandshakeResult::BadClient { .. })); + } +} + +#[tokio::test] +async fn malformed_tls_classes_share_close_latency_buckets() { + const ITER: usize = 24; + const BUCKET_MS: u128 = 10; + + let secret = [0x99u8; 16]; + let mut config = test_config_with_secret_hex("99999999999999999999999999999999"); + config.censorship.alpn_enforce = true; + + let replay_checker = ReplayChecker::new(4096, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "127.0.0.1:44330".parse().unwrap(); + + let too_short = vec![0x16, 0x03, 0x01]; + + let mut bad_hmac = make_valid_tls_handshake(&secret, 0); + bad_hmac[tls::TLS_DIGEST_POS + 1] ^= 0x01; + + let alpn_mismatch = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + + let mut class_means_ms = Vec::new(); + for probe in [too_short, bad_hmac, alpn_mismatch] { + let mut sum_micros: u128 = 0; + for _ in 0..ITER { + let started = Instant::now(); + let result = handle_tls_handshake( + &probe, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let elapsed = started.elapsed(); + assert!(matches!(result, HandshakeResult::BadClient { .. })); + sum_micros += elapsed.as_micros(); + } + + class_means_ms.push(sum_micros / ITER as u128 / 1_000); + } + + let min_bucket = class_means_ms + .iter() + .map(|ms| ms / BUCKET_MS) + .min() + .unwrap(); + let max_bucket = class_means_ms + .iter() + .map(|ms| ms / BUCKET_MS) + .max() + .unwrap(); + + assert!( + max_bucket <= min_bucket + 1, + "Malformed TLS classes diverged across latency buckets: means_ms={:?}", + class_means_ms + ); +} + #[test] fn secure_tag_requires_tls_mode_on_tls_transport() { let mut config = ProxyConfig::default(); diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index fd0b404..d7eaef8 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -14,12 +14,41 @@ use crate::network::dns_overrides::resolve_socket_addr; use crate::stats::beobachten::BeobachtenStore; use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder}; +#[cfg(not(test))] const MASK_TIMEOUT: Duration = Duration::from_secs(5); +#[cfg(test)] +const MASK_TIMEOUT: Duration = Duration::from_millis(50); /// Maximum duration for the entire masking relay. /// Limits resource consumption from slow-loris attacks and port scanners. +#[cfg(not(test))] const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60); +#[cfg(test)] +const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200); const MASK_BUFFER_SIZE: usize = 8192; +async fn write_proxy_header_with_timeout(mask_write: &mut W, header: &[u8]) -> bool +where + W: AsyncWrite + Unpin, +{ + match timeout(MASK_TIMEOUT, mask_write.write_all(header)).await { + Ok(Ok(())) => true, + Ok(Err(_)) => false, + Err(_) => { + debug!("Timeout writing proxy protocol header to mask backend"); + false + } + } +} + +async fn consume_client_data_with_timeout(reader: R) +where + R: AsyncRead + Unpin, +{ + if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader)).await.is_err() { + debug!("Timed out while consuming client data on masking fallback path"); + } +} + /// Detect client type based on initial data fn detect_client_type(data: &[u8]) -> &'static str { // Check for HTTP request @@ -71,7 +100,7 @@ where if !config.censorship.mask { // Masking disabled, just consume data - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; return; } @@ -107,7 +136,7 @@ where } }; if let Some(header) = proxy_header { - if mask_write.write_all(&header).await.is_err() { + if !write_proxy_header_with_timeout(&mut mask_write, &header).await { return; } } @@ -117,11 +146,11 @@ where } Ok(Err(e)) => { debug!(error = %e, "Failed to connect to mask unix socket"); - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; } Err(_) => { debug!("Timeout connecting to mask unix socket"); - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; } } return; @@ -166,7 +195,7 @@ where let (mask_read, mut mask_write) = stream.into_split(); if let Some(header) = proxy_header { - if mask_write.write_all(&header).await.is_err() { + if !write_proxy_header_with_timeout(&mut mask_write, &header).await { return; } } @@ -176,11 +205,11 @@ where } Ok(Err(e)) => { debug!(error = %e, "Failed to connect to mask host"); - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; } Err(_) => { debug!("Timeout connecting to mask host"); - consume_client_data(reader).await; + consume_client_data_with_timeout(reader).await; } } } diff --git a/src/proxy/masking_security_tests.rs b/src/proxy/masking_security_tests.rs index 8e5e003..2fc6a79 100644 --- a/src/proxy/masking_security_tests.rs +++ b/src/proxy/masking_security_tests.rs @@ -1,5 +1,7 @@ use super::*; use crate::config::ProxyConfig; +use std::pin::Pin; +use std::task::{Context, Poll}; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; use tokio::net::TcpListener; #[cfg(unix)] @@ -484,3 +486,59 @@ async fn unix_socket_mask_path_forwards_probe_and_response() { accept_task.await.unwrap(); let _ = std::fs::remove_file(sock_path); } + +#[tokio::test] +async fn mask_disabled_slowloris_connection_is_closed_by_consume_timeout() { + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = false; + + let peer: SocketAddr = "198.51.100.33:45455".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (_client_reader_side, client_reader) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + b"slowloris", + peer, + local_addr, + &config, + &beobachten, + ) + .await; + }); + + timeout(Duration::from_secs(1), task).await.unwrap().unwrap(); +} + +struct PendingWriter; + +impl tokio::io::AsyncWrite for PendingWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + Poll::Pending + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn proxy_header_write_timeout_returns_false() { + let mut writer = PendingWriter; + let ok = write_proxy_header_with_timeout(&mut writer, b"PROXY UNKNOWN\r\n").await; + assert!(!ok, "Proxy header writes that never complete must time out"); +}