From 5a16e68487525374164679d7a0cca0ac75552ef2 Mon Sep 17 00:00:00 2001 From: David Osipov Date: Mon, 16 Mar 2026 20:43:49 +0400 Subject: [PATCH] Enhance TLS record handling and security tests - Enforce TLS record length constraints in client handling to comply with RFC 8446, rejecting records outside the range of 512 to 16,384 bytes. - Update security tests to validate behavior for oversized and undersized TLS records, ensuring they are correctly masked or rejected. - Introduce new tests to verify the handling of TLS records in both generic and client handler pipelines. - Refactor handshake logic to enforce mode restrictions based on transport type, preventing misuse of secure tags. - Add tests for nonce generation and encryption consistency, ensuring correct behavior for different configurations. - Improve masking tests to ensure proper logging and detection of client types, including SSH and unknown probes. --- src/proxy/client.rs | 15 +- src/proxy/client_security_tests.rs | 640 +++++++++++++++++++++++++- src/proxy/handshake.rs | 41 +- src/proxy/handshake_security_tests.rs | 152 +++++- src/proxy/masking_security_tests.rs | 233 +++++++++- 5 files changed, 1060 insertions(+), 21 deletions(-) diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 0ef2cc6..ec99a47 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -151,8 +151,13 @@ where if is_tls { let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize; - if tls_len < 512 { - debug!(peer = %real_peer, tls_len = tls_len, "TLS handshake too short"); +// RFC 8446 §5.1 mandates that TLSPlaintext records must not exceed 2^14 + // bytes (16_384). A client claiming a larger record is non-compliant and + // may be an active probe attempting to force large allocations. + // + // Also enforce a minimum record size to avoid trivial/garbage probes. + if !(512..=MAX_TLS_RECORD_SIZE).contains(&tls_len) { + debug!(peer = %real_peer, tls_len = tls_len, max_tls_len = MAX_TLS_RECORD_SIZE, "TLS handshake length out of bounds"); stats.increment_connects_bad(); let (reader, writer) = tokio::io::split(stream); handle_bad_client( @@ -525,8 +530,10 @@ impl RunningClientHandler { debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake"); - if tls_len < 512 { - debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short"); + // See RFC 8446 §5.1: TLSPlaintext records must not exceed 16_384 bytes. + // Treat too-small or too-large lengths as active probes and mask them. + if !(512..=MAX_TLS_RECORD_SIZE).contains(&tls_len) { + debug!(peer = %peer, tls_len = tls_len, max_tls_len = MAX_TLS_RECORD_SIZE, "TLS handshake length out of bounds"); self.stats.increment_connects_bad(); let (reader, writer) = self.stream.into_split(); handle_bad_client( diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs index 70930ea..46eba11 100644 --- a/src/proxy/client_security_tests.rs +++ b/src/proxy/client_security_tests.rs @@ -92,8 +92,9 @@ async fn short_tls_probe_is_masked_through_client_pipeline() { accept_task.await.unwrap(); } -fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32) -> Vec { - let tls_len: usize = 600; +fn make_valid_tls_client_hello_with_len(secret: &[u8], timestamp: u32, tls_len: usize) -> Vec { + assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); + let total_len = 5 + tls_len; let mut handshake = vec![0x42u8; total_len]; @@ -117,6 +118,10 @@ fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32) -> Vec { handshake } +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32) -> Vec { + make_valid_tls_client_hello_with_len(secret, timestamp, 600) +} + fn wrap_tls_application_data(payload: &[u8]) -> Vec { let mut record = Vec::with_capacity(5 + payload.len()); record.push(0x17); @@ -629,3 +634,634 @@ async fn concurrent_limit_rejections_from_mixed_ips_leave_no_ip_footprint() { "No rollback should occur under concurrent rejection storms" ); } + +#[tokio::test] +async fn oversized_tls_record_is_masked_in_generic_stream_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = [ + 0x16, + 0x03, + 0x01, + (((MAX_TLS_RECORD_SIZE + 1) >> 8) & 0xff) as u8, + ((MAX_TLS_RECORD_SIZE + 1) & 0xff) as u8, + ]; + let backend_reply = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let bad_before = stats.get_connects_bad(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.123:55123".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); + + assert_eq!( + stats.get_connects_bad(), + bad_before + 1, + "Oversized TLS probe must be classified as bad" + ); +} + +#[tokio::test] +async fn oversized_tls_record_is_masked_in_client_handler_pipeline() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let probe = [ + 0x16, + 0x03, + 0x01, + (((MAX_TLS_RECORD_SIZE + 1) >> 8) & 0xff) as u8, + ((MAX_TLS_RECORD_SIZE + 1) & 0xff) as u8, + ]; + let backend_reply = b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let mask_accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&probe).await.unwrap(); + + let mut observed = vec![0u8; backend_reply.len()]; + client.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + tokio::time::timeout(Duration::from_secs(3), mask_accept_task) + .await + .unwrap() + .unwrap(); + + drop(client); + + let _ = tokio::time::timeout(Duration::from_secs(3), server_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_record_len_511_is_rejected_in_generic_stream_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = [0x16, 0x03, 0x01, 0x01, 0xff]; + let backend_reply = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let bad_before = stats.get_connects_bad(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let peer: SocketAddr = "203.0.113.130:55130".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&probe).await.unwrap(); + let mut observed = vec![0u8; backend_reply.len()]; + client_side.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + accept_task.await.unwrap(); + + assert_eq!( + stats.get_connects_bad(), + bad_before + 1, + "TLS record length 511 must be rejected" + ); +} + +#[tokio::test] +async fn tls_record_len_511_is_rejected_in_client_handler_pipeline() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let probe = [0x16, 0x03, 0x01, 0x01, 0xff]; + let backend_reply = b"HTTP/1.1 403 Forbidden\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let mask_accept_task = tokio::spawn({ + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = mask_listener.accept().await.unwrap(); + let mut got = [0u8; 5]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&probe).await.unwrap(); + + let mut observed = vec![0u8; backend_reply.len()]; + client.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + tokio::time::timeout(Duration::from_secs(3), mask_accept_task) + .await + .unwrap() + .unwrap(); + + drop(client); + + let _ = tokio::time::timeout(Duration::from_secs(3), server_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_record_len_16384_is_accepted_in_generic_stream_pipeline() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x55u8; 16]; + let client_hello = make_valid_tls_client_hello_with_len(&secret, 0, MAX_TLS_RECORD_SIZE); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "55555555555555555555555555555555".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let bad_before = stats.get_connects_bad(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.55:56055".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + config, + stats.clone(), + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut record_header = [0u8; 5]; + client_side.read_exact(&mut record_header).await.unwrap(); + assert_eq!(record_header[0], 0x16, "Valid max-length ClientHello must be accepted"); + + drop(client_side); + let handler_result = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + assert!(handler_result.is_err()); + + let no_mask_connect = tokio::time::timeout(Duration::from_millis(250), listener.accept()).await; + assert!( + no_mask_connect.is_err(), + "Valid max-length ClientHello must not trigger mask fallback" + ); + + assert_eq!( + bad_before, + stats.get_connects_bad(), + "Valid max-length ClientHello must not increment bad counter" + ); +} + +#[tokio::test] +async fn tls_record_len_16384_is_accepted_in_client_handler_pipeline() { + let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = mask_listener.local_addr().unwrap(); + + let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let front_addr = front_listener.local_addr().unwrap(); + + let secret = [0x66u8; 16]; + let client_hello = make_valid_tls_client_hello_with_len(&secret, 0, MAX_TLS_RECORD_SIZE); + + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = backend_addr.port(); + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), "66666666666666666666666666666666".to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let bad_before = stats.get_connects_bad(); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60))); + let buffer_pool = Arc::new(BufferPool::new()); + let rng = Arc::new(SecureRandom::new()); + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + let ip_tracker = Arc::new(UserIpTracker::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let server_task = { + let config = config.clone(); + let stats = stats.clone(); + let upstream_manager = upstream_manager.clone(); + let replay_checker = replay_checker.clone(); + let buffer_pool = buffer_pool.clone(); + let rng = rng.clone(); + let route_runtime = route_runtime.clone(); + let ip_tracker = ip_tracker.clone(); + let beobachten = beobachten.clone(); + + tokio::spawn(async move { + let (stream, peer) = front_listener.accept().await.unwrap(); + let real_peer_report = Arc::new(std::sync::Mutex::new(None)); + ClientHandler::new( + stream, + peer, + config, + stats, + upstream_manager, + replay_checker, + buffer_pool, + rng, + None, + route_runtime, + None, + ip_tracker, + beobachten, + false, + real_peer_report, + ) + .run() + .await + }) + }; + + let mut client = TcpStream::connect(front_addr).await.unwrap(); + client.write_all(&client_hello).await.unwrap(); + + let mut record_header = [0u8; 5]; + client.read_exact(&mut record_header).await.unwrap(); + assert_eq!(record_header[0], 0x16, "Valid max-length ClientHello must be accepted"); + + drop(client); + + let _ = tokio::time::timeout(Duration::from_secs(3), server_task) + .await + .unwrap() + .unwrap(); + + let no_mask_connect = tokio::time::timeout(Duration::from_millis(250), mask_listener.accept()).await; + assert!( + no_mask_connect.is_err(), + "Valid max-length ClientHello must not trigger mask fallback in ClientHandler path" + ); + + assert_eq!( + bad_before, + stats.get_connects_bad(), + "Valid max-length ClientHello must not increment bad counter" + ); +} diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 4e7b371..e7e4751 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -21,6 +21,28 @@ use crate::tls_front::{TlsFrontCache, emulator}; const ACCESS_SECRET_BYTES: usize = 16; +// Decide whether a client-supplied proto tag is allowed given the configured +// proxy modes and the transport that carried the handshake. +// +// A common mistake is to treat `modes.tls` and `modes.secure` as interchangeable +// even though they correspond to different transport profiles: `modes.tls` is +// for the TLS-fronted (EE-TLS) path, while `modes.secure` is for direct MTProto +// over TCP (DD). Enforcing this separation prevents an attacker from using a +// TLS-capable client to bypass the operator intent for the direct MTProto mode, +// and vice versa. +fn mode_enabled_for_proto(config: &ProxyConfig, proto_tag: ProtoTag, is_tls: bool) -> bool { + match proto_tag { + ProtoTag::Secure => { + if is_tls { + config.general.modes.tls + } else { + config.general.modes.secure + } + } + ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic, + } +} + fn decode_user_secrets( config: &ProxyConfig, preferred_user: Option<&str>, @@ -292,16 +314,7 @@ where None => continue, }; - let mode_ok = match proto_tag { - ProtoTag::Secure => { - if is_tls { - config.general.modes.tls || config.general.modes.secure - } else { - config.general.modes.secure || config.general.modes.tls - } - } - ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic, - }; + let mode_ok = mode_enabled_for_proto(config, proto_tag, is_tls); if !mode_ok { debug!(peer = %peer, user = %user, proto = ?proto_tag, "Mode not enabled"); @@ -324,8 +337,12 @@ where let encryptor = AesCtr::new(&enc_key, enc_iv); - // Apply replay tracking only after successful authentication to prevent - // unauthenticated probes from evicting legitimate replay-cache entries. +// Apply replay tracking only after successful authentication. + // + // This ordering prevents an attacker from producing invalid handshakes that + // still collide with a valid handshake's replay slot and thus evict a valid + // entry from the cache. We accept the cost of performing the full + // authentication check first to avoid poisoning the replay cache. if replay_checker.check_and_add_handshake(dec_prekey_iv) { warn!(peer = %peer, user = %user, "MTProto replay attack detected"); return HandshakeResult::BadClient { reader, writer }; diff --git a/src/proxy/handshake_security_tests.rs b/src/proxy/handshake_security_tests.rs index 58178d9..c4a5ba6 100644 --- a/src/proxy/handshake_security_tests.rs +++ b/src/proxy/handshake_security_tests.rs @@ -84,7 +84,7 @@ fn test_encrypt_tg_nonce() { } #[test] -fn test_handshake_success_zeroize_on_drop() { +fn test_handshake_success_drop_does_not_panic() { let success = HandshakeSuccess { user: "test".to_string(), dc_idx: 2, @@ -103,6 +103,118 @@ fn test_handshake_success_zeroize_on_drop() { drop(success); } +#[test] +fn test_generate_tg_nonce_enc_dec_material_is_consistent() { + let client_dec_key = [0x12u8; 32]; + let client_dec_iv = 0x11223344556677889900aabbccddeeffu128; + let client_enc_key = [0x34u8; 32]; + let client_enc_iv = 0xffeeddccbbaa00998877665544332211u128; + let rng = SecureRandom::new(); + + let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce( + ProtoTag::Secure, + 7, + &client_dec_key, + client_dec_iv, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; + let dec_key_iv: Vec = enc_key_iv.iter().rev().copied().collect(); + + let mut expected_tg_enc_key = [0u8; 32]; + expected_tg_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut expected_tg_enc_iv_arr = [0u8; IV_LEN]; + expected_tg_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let expected_tg_enc_iv = u128::from_be_bytes(expected_tg_enc_iv_arr); + + let mut expected_tg_dec_key = [0u8; 32]; + expected_tg_dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]); + let mut expected_tg_dec_iv_arr = [0u8; IV_LEN]; + expected_tg_dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]); + let expected_tg_dec_iv = u128::from_be_bytes(expected_tg_dec_iv_arr); + + assert_eq!(tg_enc_key, expected_tg_enc_key); + assert_eq!(tg_enc_iv, expected_tg_enc_iv); + assert_eq!(tg_dec_key, expected_tg_dec_key); + assert_eq!(tg_dec_iv, expected_tg_dec_iv); + assert_eq!( + i16::from_le_bytes([nonce[DC_IDX_POS], nonce[DC_IDX_POS + 1]]), + 7, + "Generated nonce must keep target dc index in protocol slot" + ); +} + +#[test] +fn test_generate_tg_nonce_fast_mode_embeds_reversed_client_enc_material() { + let client_dec_key = [0x22u8; 32]; + let client_dec_iv = 0x0102030405060708090a0b0c0d0e0f10u128; + let client_enc_key = [0xABu8; 32]; + let client_enc_iv = 0x11223344556677889900aabbccddeeffu128; + let rng = SecureRandom::new(); + + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 9, + &client_dec_key, + client_dec_iv, + &client_enc_key, + client_enc_iv, + &rng, + true, + ); + + let mut expected = Vec::with_capacity(KEY_LEN + IV_LEN); + expected.extend_from_slice(&client_enc_key); + expected.extend_from_slice(&client_enc_iv.to_be_bytes()); + expected.reverse(); + + assert_eq!(&nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN], expected.as_slice()); +} + +#[test] +fn test_encrypt_tg_nonce_with_ciphers_matches_manual_suffix_encryption() { + let client_dec_key = [0x42u8; 32]; + let client_dec_iv = 12345u128; + let client_enc_key = [0x24u8; 32]; + let client_enc_iv = 54321u128; + + let rng = SecureRandom::new(); + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_dec_key, + client_dec_iv, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + let (encrypted, _, _) = encrypt_tg_nonce_with_ciphers(&nonce); + + let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; + let mut expected_enc_key = [0u8; 32]; + expected_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut expected_enc_iv_arr = [0u8; IV_LEN]; + expected_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let expected_enc_iv = u128::from_be_bytes(expected_enc_iv_arr); + + let mut manual_encryptor = AesCtr::new(&expected_enc_key, expected_enc_iv); + let manual = manual_encryptor.encrypt(&nonce); + + assert_eq!(encrypted.len(), HANDSHAKE_LEN); + assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]); + assert_eq!( + &encrypted[PROTO_TAG_POS..], + &manual[PROTO_TAG_POS..], + "Encrypted nonce suffix must match AES-CTR output with derived enc key/iv" + ); +} + #[tokio::test] async fn tls_replay_second_identical_handshake_is_rejected() { let secret = [0x11u8; 16]; @@ -274,3 +386,41 @@ async fn mixed_secret_lengths_keep_valid_user_authenticating() { assert!(matches!(result, HandshakeResult::Success(_))); } + +#[test] +fn secure_tag_requires_tls_mode_on_tls_transport() { + let mut config = ProxyConfig::default(); + config.general.modes.classic = false; + config.general.modes.secure = true; + config.general.modes.tls = false; + + assert!( + !mode_enabled_for_proto(&config, ProtoTag::Secure, true), + "Secure tag over TLS must be rejected when tls mode is disabled" + ); + + config.general.modes.tls = true; + assert!( + mode_enabled_for_proto(&config, ProtoTag::Secure, true), + "Secure tag over TLS must be accepted when tls mode is enabled" + ); +} + +#[test] +fn secure_tag_requires_secure_mode_on_direct_transport() { + let mut config = ProxyConfig::default(); + config.general.modes.classic = false; + config.general.modes.secure = false; + config.general.modes.tls = true; + + assert!( + !mode_enabled_for_proto(&config, ProtoTag::Secure, false), + "Secure tag without TLS must be rejected when secure mode is disabled" + ); + + config.general.modes.secure = true; + assert!( + mode_enabled_for_proto(&config, ProtoTag::Secure, false), + "Secure tag without TLS must be accepted when secure mode is enabled" + ); +} diff --git a/src/proxy/masking_security_tests.rs b/src/proxy/masking_security_tests.rs index 50ea8ed..8e5e003 100644 --- a/src/proxy/masking_security_tests.rs +++ b/src/proxy/masking_security_tests.rs @@ -2,6 +2,8 @@ use super::*; use crate::config::ProxyConfig; use tokio::io::{duplex, AsyncBufReadExt, BufReader}; use tokio::net::TcpListener; +#[cfg(unix)] +use tokio::net::UnixListener; use tokio::time::{timeout, Duration}; #[tokio::test] @@ -75,7 +77,8 @@ async fn tls_scanner_probe_keeps_http_like_fallback_surface() { }); let mut config = ProxyConfig::default(); - config.general.beobachten = false; + config.general.beobachten = true; + config.general.beobachten_minutes = 1; config.censorship.mask = true; config.censorship.mask_host = Some("127.0.0.1".to_string()); config.censorship.mask_port = backend_addr.port(); @@ -103,10 +106,58 @@ async fn tls_scanner_probe_keeps_http_like_fallback_surface() { let mut observed = vec![0u8; backend_reply.len()]; client_visible_reader.read_exact(&mut observed).await.unwrap(); assert_eq!(observed, backend_reply); - assert!(observed.starts_with(b"HTTP/")); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[TLS-scanner]")); + assert!(snapshot.contains("198.51.100.44-1")); accept_task.await.unwrap(); } +#[test] +fn detect_client_type_covers_ssh_port_scanner_and_unknown() { + assert_eq!(detect_client_type(b"SSH-2.0-OpenSSH_9.7"), "SSH"); + assert_eq!(detect_client_type(b"\x01\x02\x03"), "port-scanner"); + assert_eq!(detect_client_type(b"random-binary-payload"), "unknown"); +} + +#[tokio::test] +async fn beobachten_records_scanner_class_when_mask_is_disabled() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + config.censorship.mask = false; + + let peer: SocketAddr = "203.0.113.99:41234".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + let initial = b"SSH-2.0-probe"; + + let (mut client_reader_side, client_reader) = duplex(256); + let (_client_visible_reader, client_visible_writer) = duplex(256); + let beobachten = BeobachtenStore::new(); + + let task = tokio::spawn(async move { + handle_bad_client( + client_reader, + client_visible_writer, + initial, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + beobachten + }); + + client_reader_side.write_all(b"noise").await.unwrap(); + drop(client_reader_side); + + let beobachten = timeout(Duration::from_secs(3), task).await.unwrap().unwrap(); + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[SSH]")); + assert!(snapshot.contains("203.0.113.99-1")); +} + #[tokio::test] async fn backend_unavailable_falls_back_to_silent_consume() { let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); @@ -255,3 +306,181 @@ async fn proxy_protocol_v1_header_is_sent_before_probe() { assert_eq!(observed, backend_reply); accept_task.await.unwrap(); } + +#[tokio::test] +async fn proxy_protocol_v2_header_is_sent_before_probe() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut sig = [0u8; 12]; + stream.read_exact(&mut sig).await.unwrap(); + assert_eq!(&sig, b"\r\n\r\n\0\r\nQUIT\n"); + + let mut fixed = [0u8; 4]; + stream.read_exact(&mut fixed).await.unwrap(); + let addr_len = u16::from_be_bytes([fixed[2], fixed[3]]) as usize; + + let mut addr_block = vec![0u8; addr_len]; + stream.read_exact(&mut addr_block).await.unwrap(); + + let mut received_probe = vec![0u8; probe.len()]; + stream.read_exact(&mut received_probe).await.unwrap(); + assert_eq!(received_probe, probe); + + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 2; + + let peer: SocketAddr = "203.0.113.18:50004".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + accept_task.await.unwrap(); +} + +#[tokio::test] +async fn proxy_protocol_v1_mixed_family_falls_back_to_unknown_header() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + let probe = b"GET /mix HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (stream, _) = listener.accept().await.unwrap(); + let mut reader = BufReader::new(stream); + + let mut header_line = Vec::new(); + reader.read_until(b'\n', &mut header_line).await.unwrap(); + let header_text = String::from_utf8(header_line).unwrap(); + assert_eq!(header_text, "PROXY UNKNOWN\r\n"); + + let mut received_probe = vec![0u8; probe.len()]; + reader.read_exact(&mut received_probe).await.unwrap(); + assert_eq!(received_probe, probe); + + let mut stream = reader.into_inner(); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_host = Some("127.0.0.1".to_string()); + config.censorship.mask_port = backend_addr.port(); + config.censorship.mask_unix_sock = None; + config.censorship.mask_proxy_protocol = 1; + + let peer: SocketAddr = "203.0.113.20:50006".parse().unwrap(); + let local_addr: SocketAddr = "[::1]:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + accept_task.await.unwrap(); +} + +#[cfg(unix)] +#[tokio::test] +async fn unix_socket_mask_path_forwards_probe_and_response() { + let sock_path = format!("/tmp/telemt-mask-test-{}-{}.sock", std::process::id(), rand::random::()); + let _ = std::fs::remove_file(&sock_path); + + let listener = UnixListener::bind(&sock_path).unwrap(); + let probe = b"GET /unix HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec(); + let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let accept_task = tokio::spawn({ + let probe = probe.clone(); + let backend_reply = backend_reply.clone(); + async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut received = vec![0u8; probe.len()]; + stream.read_exact(&mut received).await.unwrap(); + assert_eq!(received, probe); + stream.write_all(&backend_reply).await.unwrap(); + } + }); + + let mut config = ProxyConfig::default(); + config.general.beobachten = false; + config.censorship.mask = true; + config.censorship.mask_unix_sock = Some(sock_path.clone()); + config.censorship.mask_proxy_protocol = 0; + + let peer: SocketAddr = "203.0.113.30:50010".parse().unwrap(); + let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap(); + + let (client_reader, _client_writer) = duplex(256); + let (mut client_visible_reader, client_visible_writer) = duplex(2048); + + let beobachten = BeobachtenStore::new(); + handle_bad_client( + client_reader, + client_visible_writer, + &probe, + peer, + local_addr, + &config, + &beobachten, + ) + .await; + + let mut observed = vec![0u8; backend_reply.len()]; + client_visible_reader.read_exact(&mut observed).await.unwrap(); + assert_eq!(observed, backend_reply); + + accept_task.await.unwrap(); + let _ = std::fs::remove_file(sock_path); +}