use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_RECORD_APPLICATION, TLS_VERSION};
use crate::protocol::tls;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, Instant};

/// Shared dependencies handed to `handle_client_stream` by the tests below.
///
/// Every field is an `Arc` so a single harness can be cloned field-by-field
/// into several sessions (see the replay test).
struct Harness {
    config: Arc<ProxyConfig>,
    stats: Arc<Stats>,
    upstream_manager: Arc<UpstreamManager>,
    replay_checker: Arc<ReplayChecker>,
    buffer_pool: Arc<BufferPool>,
    rng: Arc<SecureRandom>,
    route_runtime: Arc<RouteRuntimeController>,
    ip_tracker: Arc<UserIpTracker>,
    beobachten: Arc<BeobachtenStore>,
}

/// Builds an `UpstreamManager` with a single enabled `Direct` upstream of
/// weight 1, no bound interface/addresses and empty scopes.
///
/// NOTE(review): the positional `1, 1, 1, 1, false` arguments are passed
/// through verbatim; their meaning is defined by `UpstreamManager::new`.
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
    Arc::new(UpstreamManager::new(
        vec![UpstreamConfig {
            upstream_type: UpstreamType::Direct {
                interface: None,
                bind_addresses: None,
            },
            weight: 1,
            enabled: true,
            scopes: String::new(),
            selected_scope: String::new(),
        }],
        1,
        1,
        1,
        1,
        false,
        stats,
    ))
}

/// Creates a [`Harness`] whose config masks towards `127.0.0.1:mask_port`
/// (TCP, no unix socket, proxy-protocol value 0), ignores time skew, and
/// knows a single user named `"user"` with the secret `secret_hex`.
fn build_harness(secret_hex: &str, mask_port: u16) -> Harness {
    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.censorship.mask = true;
    cfg.censorship.mask_unix_sock = None;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = mask_port;
    cfg.censorship.mask_proxy_protocol = 0;
    cfg.access.ignore_time_skew = true;
    cfg.access
        .users
        .insert("user".to_string(), secret_hex.to_string());

    let config = Arc::new(cfg);
    let stats = Arc::new(Stats::new());

    Harness {
        config,
        stats: stats.clone(),
        upstream_manager: new_upstream_manager(stats),
        replay_checker: Arc::new(ReplayChecker::new(512, Duration::from_secs(60))),
        buffer_pool: Arc::new(BufferPool::new()),
        rng: Arc::new(SecureRandom::new()),
        route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
        ip_tracker: Arc::new(UserIpTracker::new()),
        beobachten: Arc::new(BeobachtenStore::new()),
    }
}

/// Builds a `5 + tls_len` byte record that starts with the header
/// `16 03 01 <len>` and carries a digest derived from `secret`.
///
/// The buffer is pre-filled with `fill`. The digest is the HMAC-SHA256 of the
/// whole buffer computed while the digest field is zeroed; its last four
/// bytes are then XORed with `timestamp` (little-endian) before the digest is
/// written into the digest field.
///
/// # Panics
///
/// Panics if `tls_len` does not fit into the 16-bit record length, or if the
/// buffer is too short to hold the digest field and the session-id length byte.
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
    assert!(
        tls_len <= u16::MAX as usize,
        "TLS length must fit into record header"
    );

    let total_len = 5 + tls_len;
    let mut handshake = vec![fill; total_len];

    // Record header: handshake content type, version 0x0301, big-endian length.
    handshake[0] = 0x16;
    handshake[1] = 0x03;
    handshake[2] = 0x01;
    handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());

    // The byte right after the digest field holds the session-id length.
    let session_id_len: usize = 32;
    handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;

    // The HMAC is computed over the message with the digest field zeroed.
    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
    let mut digest = sha256_hmac(secret, &handshake);

    // Fold the timestamp into the last four digest bytes.
    let ts = timestamp.to_le_bytes();
    for (byte, t) in digest[28..32].iter_mut().zip(ts.iter()) {
        *byte ^= *t;
    }

    handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
        .copy_from_slice(&digest);
    handshake
}

/// Wraps `payload` in a record: `TLS_RECORD_APPLICATION`, `TLS_VERSION`,
/// big-endian 16-bit length, then the payload bytes.
///
/// # Panics
///
/// Panics if `payload` is longer than `u16::MAX` bytes, because the length
/// would otherwise be silently truncated in the header.
fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
    assert!(
        payload.len() <= u16::MAX as usize,
        "payload length must fit into record header"
    );

    let mut record = Vec::with_capacity(5 + payload.len());
    record.push(TLS_RECORD_APPLICATION);
    record.extend_from_slice(&TLS_VERSION);
    record.extend_from_slice(&(payload.len() as u16).to_be_bytes());
    record.extend_from_slice(payload);
    record
}

/// Reads and discards the body of a record whose 5-byte `header` has already
/// been consumed; the body length is the big-endian `u16` in `header[3..5]`.
///
/// Panics if the stream ends before the full body has been read.
async fn read_tls_record_body<T>(stream: &mut T, header: [u8; 5])
where
    T: tokio::io::AsyncRead + Unpin,
{
    let len = u16::from_be_bytes([header[3], header[4]]) as usize;
    let mut body = vec![0u8; len];
    stream.read_exact(&mut body).await.unwrap();
}

/// Drives one session in which the ClientHello is accepted (a `0x16` record
/// comes back) but the following application-data record carries an all-zero
/// payload of `HANDSHAKE_LEN` bytes, then sends `trailing_records`.
///
/// A local TCP listener plays the mask backend. The function asserts that the
/// backend receives exactly the concatenation of `trailing_records` and
/// returns those received bytes.
async fn run_tls_success_mtproto_fail_capture(
    secret_hex: &str,
    secret: [u8; 16],
    timestamp: u32,
    trailing_records: Vec<Vec<u8>>,
) -> Vec<u8> {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();

    let expected_len = trailing_records.iter().map(Vec::len).sum::<usize>();
    let expected_concat = trailing_records.concat();

    let accept_task = tokio::spawn(async move {
        let (mut stream, _) = listener.accept().await.unwrap();
        let mut got = vec![0u8; expected_len];
        stream.read_exact(&mut got).await.unwrap();
        got
    });

    let harness = build_harness(secret_hex, backend_addr.port());
    let client_hello = make_valid_tls_client_hello(&secret, timestamp, 600, 0x42);
    let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);

    let (server_side, mut client_side) = duplex(262144);
    let handler = tokio::spawn(handle_client_stream(
        server_side,
        "198.51.100.210:56010".parse().unwrap(),
        harness.config,
        harness.stats,
        harness.upstream_manager,
        harness.replay_checker,
        harness.buffer_pool,
        harness.rng,
        None,
        harness.route_runtime,
        None,
        harness.ip_tracker,
        harness.beobachten,
        false,
    ));

    client_side.write_all(&client_hello).await.unwrap();

    // The first record coming back must be a handshake record; drain its body.
    let mut tls_response_head = [0u8; 5];
    client_side
        .read_exact(&mut tls_response_head)
        .await
        .unwrap();
    assert_eq!(tls_response_head[0], 0x16);
    read_tls_record_body(&mut client_side, tls_response_head).await;

    client_side
        .write_all(&invalid_mtproto_record)
        .await
        .unwrap();
    for record in trailing_records {
        client_side.write_all(&record).await.unwrap();
    }

    let got = tokio::time::timeout(Duration::from_secs(3), accept_task)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(got, expected_concat);

    drop(client_side);
    let _ = tokio::time::timeout(Duration::from_secs(3), handler)
        .await
        .unwrap()
        .unwrap();

    got
}

/// With `client_handshake = 0` and a fixed 720 ms server-hello delay, an
/// invalid 605-byte hello must still reach the backend, the handler must
/// return `Ok`, no handshake timeout may be counted, and at least 680 ms must
/// have elapsed since the hello was sent.
#[tokio::test]
async fn masking_budget_survives_zero_handshake_timeout_with_delay() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();

    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.censorship.mask = true;
    cfg.censorship.mask_unix_sock = None;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = backend_addr.port();
    cfg.timeouts.client_handshake = 0;
    cfg.censorship.server_hello_delay_min_ms = 720;
    cfg.censorship.server_hello_delay_max_ms = 720;

    let config = Arc::new(cfg);
    let stats = Arc::new(Stats::new());

    let accept_task = tokio::spawn(async move {
        let (mut stream, _) = listener.accept().await.unwrap();
        let mut got = vec![0u8; 605];
        stream.read_exact(&mut got).await.unwrap();
        got
    });

    let (server_side, mut client_side) = duplex(65536);
    let handler = tokio::spawn(handle_client_stream(
        server_side,
        "198.51.100.211:56011".parse().unwrap(),
        config,
        stats.clone(),
        new_upstream_manager(stats.clone()),
        Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
        Arc::new(BufferPool::new()),
        Arc::new(SecureRandom::new()),
        None,
        Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
        None,
        Arc::new(UserIpTracker::new()),
        Arc::new(BeobachtenStore::new()),
        false,
    ));

    // 5-byte header announcing 600 bytes, followed by 600 bytes of 0xA1.
    let mut invalid_hello = vec![0u8; 605];
    invalid_hello[0] = 0x16;
    invalid_hello[1] = 0x03;
    invalid_hello[2] = 0x01;
    invalid_hello[3..5].copy_from_slice(&600u16.to_be_bytes());
    invalid_hello[5..].fill(0xA1);

    let started = Instant::now();
    client_side.write_all(&invalid_hello).await.unwrap();

    let _ = tokio::time::timeout(Duration::from_secs(3), accept_task)
        .await
        .unwrap()
        .unwrap();

    client_side.shutdown().await.unwrap();
    let result = tokio::time::timeout(Duration::from_secs(3), handler)
        .await
        .unwrap()
        .unwrap();

    assert!(result.is_ok());
    assert_eq!(stats.get_handshake_timeouts(), 0);
    assert!(started.elapsed() >= Duration::from_millis(680));
}

/// After the all-zero record is rejected, the backend must receive the
/// trailing record and nothing before it.
#[tokio::test]
async fn tls_mtproto_fail_forwards_only_trailing_record() {
    let tail = wrap_tls_application_data(b"tail-only");
    let got = run_tls_success_mtproto_fail_capture(
        "c1c1c1c1c1c1c1c1c1c1c1c1c1c1c1c1",
        [0xC1; 16],
        1,
        vec![tail.clone()],
    )
    .await;
    assert_eq!(got, tail);
}

/// Sends the same hello twice through one shared harness.
///
/// Session 1 gets a `0x16` record back, then sends the all-zero record plus a
/// tail; the backend's first connection must start with that tail. Session 2
/// replays the hello and must not read a single byte within 300 ms; the
/// backend's second connection must start with the hello itself.
#[tokio::test]
async fn replayed_tls_hello_gets_no_serverhello_and_is_masked() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();

    let harness = build_harness("c2c2c2c2c2c2c2c2c2c2c2c2c2c2c2c2", backend_addr.port());
    let secret = [0xC2u8; 16];
    let hello = make_valid_tls_client_hello(&secret, 2, 600, 0x41);
    let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
    let first_tail = wrap_tls_application_data(b"seed");

    let expected_hello = hello.clone();
    let expected_tail = first_tail.clone();

    let accept_task = tokio::spawn(async move {
        let (mut s1, _) = listener.accept().await.unwrap();
        let mut got_tail = vec![0u8; expected_tail.len()];
        s1.read_exact(&mut got_tail).await.unwrap();
        assert_eq!(got_tail, expected_tail);
        drop(s1);

        let (mut s2, _) = listener.accept().await.unwrap();
        let mut got_hello = vec![0u8; expected_hello.len()];
        s2.read_exact(&mut got_hello).await.unwrap();
        assert_eq!(got_hello, expected_hello);
    });

    // Each call clones the shared state so both sessions hit the same replay checker.
    let run_session = |send_mtproto: bool| {
        let (server_side, mut client_side) = duplex(131072);
        let config = harness.config.clone();
        let stats = harness.stats.clone();
        let upstream = harness.upstream_manager.clone();
        let replay = harness.replay_checker.clone();
        let pool = harness.buffer_pool.clone();
        let rng = harness.rng.clone();
        let route = harness.route_runtime.clone();
        let ipt = harness.ip_tracker.clone();
        let beob = harness.beobachten.clone();
        let hello = hello.clone();
        let invalid_mtproto_record = invalid_mtproto_record.clone();
        let first_tail = first_tail.clone();

        async move {
            let handler = tokio::spawn(handle_client_stream(
                server_side,
                "198.51.100.212:56012".parse().unwrap(),
                config,
                stats,
                upstream,
                replay,
                pool,
                rng,
                None,
                route,
                None,
                ipt,
                beob,
                false,
            ));

            client_side.write_all(&hello).await.unwrap();

            if send_mtproto {
                let mut head = [0u8; 5];
                client_side.read_exact(&mut head).await.unwrap();
                assert_eq!(head[0], 0x16);
                read_tls_record_body(&mut client_side, head).await;
                client_side
                    .write_all(&invalid_mtproto_record)
                    .await
                    .unwrap();
                client_side.write_all(&first_tail).await.unwrap();
            } else {
                // Either the read times out or it fails; both mean no byte arrived.
                let mut one = [0u8; 1];
                let no_server_hello = tokio::time::timeout(
                    Duration::from_millis(300),
                    client_side.read_exact(&mut one),
                )
                .await;
                assert!(no_server_hello.is_err() || no_server_hello.unwrap().is_err());
            }

            client_side.shutdown().await.unwrap();
            let _ = tokio::time::timeout(Duration::from_secs(3), handler)
                .await
                .unwrap()
                .unwrap();
        }
    };

    run_session(true).await;
    run_session(false).await;

    tokio::time::timeout(Duration::from_secs(5), accept_task)
        .await
        .unwrap()
        .unwrap();
}

/// One session with an accepted hello followed by the all-zero record must
/// raise `connects_bad` by exactly one, and the tail must reach the backend.
#[tokio::test]
async fn connects_bad_increments_once_per_invalid_mtproto() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();

    let harness = build_harness("c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3c3", backend_addr.port());
    let stats = harness.stats.clone();
    let bad_before = stats.get_connects_bad();

    let tail = wrap_tls_application_data(b"accounting");
    let expected_tail = tail.clone();

    let accept_task = tokio::spawn(async move {
        let (mut stream, _) = listener.accept().await.unwrap();
        let mut got = vec![0u8; expected_tail.len()];
        stream.read_exact(&mut got).await.unwrap();
        assert_eq!(got, expected_tail);
    });

    let hello = make_valid_tls_client_hello(&[0xC3; 16], 3, 600, 0x42);
    let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);

    let (server_side, mut client_side) = duplex(131072);
    let handler = tokio::spawn(handle_client_stream(
        server_side,
        "198.51.100.213:56013".parse().unwrap(),
        harness.config,
        harness.stats,
        harness.upstream_manager,
        harness.replay_checker,
        harness.buffer_pool,
        harness.rng,
        None,
        harness.route_runtime,
        None,
        harness.ip_tracker,
        harness.beobachten,
        false,
    ));

    client_side.write_all(&hello).await.unwrap();

    let mut head = [0u8; 5];
    client_side.read_exact(&mut head).await.unwrap();
    read_tls_record_body(&mut client_side, head).await;

    client_side
        .write_all(&invalid_mtproto_record)
        .await
        .unwrap();
    client_side.write_all(&tail).await.unwrap();

    tokio::time::timeout(Duration::from_secs(3), accept_task)
        .await
        .unwrap()
        .unwrap();

    client_side.shutdown().await.unwrap();
    let _ = tokio::time::timeout(Duration::from_secs(3), handler)
        .await
        .unwrap()
        .unwrap();

    assert_eq!(stats.get_connects_bad(), bad_before + 1);
}

/// A hello whose header announces 600 bytes but which ends after 17 body
/// bytes must be forwarded to the backend as exactly those 22 bytes.
#[tokio::test]
async fn truncated_clienthello_forwards_only_seen_prefix() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();

    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.censorship.mask = true;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = backend_addr.port();
    cfg.censorship.mask_unix_sock = None;

    let config = Arc::new(cfg);
    let stats = Arc::new(Stats::new());

    let expected_prefix_len = 5 + 17;
    let accept_task = tokio::spawn(async move {
        let (mut stream, _) = listener.accept().await.unwrap();
        let mut got = vec![0u8; expected_prefix_len];
        stream.read_exact(&mut got).await.unwrap();
        got
    });

    let (server_side, mut client_side) = duplex(65536);
    let handler = tokio::spawn(handle_client_stream(
        server_side,
        "198.51.100.214:56014".parse().unwrap(),
        config,
        stats,
        new_upstream_manager(Arc::new(Stats::new())),
        Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
        Arc::new(BufferPool::new()),
        Arc::new(SecureRandom::new()),
        None,
        Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
        None,
        Arc::new(UserIpTracker::new()),
        Arc::new(BeobachtenStore::new()),
        false,
    ));

    let mut hello = vec![0u8; 5 + 17];
    hello[0] = 0x16;
    hello[1] = 0x03;
    hello[2] = 0x01;
    hello[3..5].copy_from_slice(&600u16.to_be_bytes());
    hello[5..].fill(0x55);

    client_side.write_all(&hello).await.unwrap();
    client_side.shutdown().await.unwrap();

    let got = tokio::time::timeout(Duration::from_secs(3), accept_task)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(got, hello);

    let _ = tokio::time::timeout(Duration::from_secs(3), handler)
        .await
        .unwrap()
        .unwrap();
}

/// A lone 5-byte header with length bytes `0x42 0x69` and no body must be
/// forwarded to the backend unchanged.
#[tokio::test]
async fn out_of_bounds_tls_len_forwards_header_only() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();

    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.censorship.mask = true;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = backend_addr.port();
    cfg.censorship.mask_unix_sock = None;

    let config = Arc::new(cfg);

    let accept_task = tokio::spawn(async move {
        let (mut stream, _) = listener.accept().await.unwrap();
        let mut got = [0u8; 5];
        stream.read_exact(&mut got).await.unwrap();
        got
    });

    let (server_side, mut client_side) = duplex(8192);
    let handler = tokio::spawn(handle_client_stream(
        server_side,
        "198.51.100.215:56015".parse().unwrap(),
        config,
        Arc::new(Stats::new()),
        new_upstream_manager(Arc::new(Stats::new())),
        Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
        Arc::new(BufferPool::new()),
        Arc::new(SecureRandom::new()),
        None,
        Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
        None,
        Arc::new(UserIpTracker::new()),
        Arc::new(BeobachtenStore::new()),
        false,
    ));

    let hdr = [0x16, 0x03, 0x01, 0x42, 0x69];
    client_side.write_all(&hdr).await.unwrap();
    client_side.shutdown().await.unwrap();

    let got = tokio::time::timeout(Duration::from_secs(3), accept_task)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(got, hdr);

    let _ = tokio::time::timeout(Duration::from_secs(3), handler)
        .await
        .unwrap()
        .unwrap();
}

/// With the `classic` and `secure` modes switched off, a 5-byte probe that
/// does not start with `0x16` must be forwarded to the backend unchanged.
#[tokio::test]
async fn non_tls_with_modes_disabled_is_masked() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();

    let mut cfg = ProxyConfig::default();
    cfg.general.beobachten = false;
    cfg.censorship.mask = true;
    cfg.censorship.mask_host = Some("127.0.0.1".to_string());
    cfg.censorship.mask_port = backend_addr.port();
    cfg.censorship.mask_unix_sock = None;
    cfg.general.modes.classic = false;
    cfg.general.modes.secure = false;

    let config = Arc::new(cfg);

    let accept_task = tokio::spawn(async move {
        let (mut stream, _) = listener.accept().await.unwrap();
        let mut got = [0u8; 5];
        stream.read_exact(&mut got).await.unwrap();
        got
    });

    let (server_side, mut client_side) = duplex(8192);
    let handler = tokio::spawn(handle_client_stream(
        server_side,
        "198.51.100.216:56016".parse().unwrap(),
        config,
        Arc::new(Stats::new()),
        new_upstream_manager(Arc::new(Stats::new())),
        Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
        Arc::new(BufferPool::new()),
        Arc::new(SecureRandom::new()),
        None,
        Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
        None,
        Arc::new(UserIpTracker::new()),
        Arc::new(BeobachtenStore::new()),
        false,
    ));

    let probe = *b"HELLO";
    client_side.write_all(&probe).await.unwrap();
    client_side.shutdown().await.unwrap();

    let got = tokio::time::timeout(Duration::from_secs(3), accept_task)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(got, probe);

    let _ = tokio::time::timeout(Duration::from_secs(3), handler)
        .await
        .unwrap()
        .unwrap();
}

/// Runs twelve sessions concurrently, each with its own harness, hello and a
/// uniquely sized trailing record. The backend must see every expected
/// trailing record exactly once, each as the first record of a connection.
#[tokio::test]
async fn concurrent_tls_mtproto_fail_sessions_are_isolated() {
    let sessions = 12usize;
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();

    let mut expected = std::collections::HashSet::new();
    for idx in 0..sessions {
        let payload = vec![idx as u8; 32 + idx];
        expected.insert(wrap_tls_application_data(&payload));
    }

    let accept_task = tokio::spawn(async move {
        let mut remaining = expected;
        for _ in 0..sessions {
            let (mut stream, _) = listener.accept().await.unwrap();
            let mut header = [0u8; 5];
            stream.read_exact(&mut header).await.unwrap();
            assert_eq!(header[0], TLS_RECORD_APPLICATION);

            let len = u16::from_be_bytes([header[3], header[4]]) as usize;
            let mut record = vec![0u8; 5 + len];
            record[..5].copy_from_slice(&header);
            stream.read_exact(&mut record[5..]).await.unwrap();

            // `remove` returns false for an unknown or duplicated record.
            assert!(remaining.remove(&record));
        }
        assert!(remaining.is_empty());
    });

    let mut tasks = Vec::with_capacity(sessions);
    for idx in 0..sessions {
        let secret_hex = "c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4c4";
        let harness = build_harness(secret_hex, backend_addr.port());
        let hello =
            make_valid_tls_client_hello(&[0xC4; 16], 20 + idx as u32, 600, 0x40 + idx as u8);
        let invalid_mtproto = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
        let trailing = wrap_tls_application_data(&vec![idx as u8; 32 + idx]);
        let peer: SocketAddr = format!("198.51.100.217:{}", 56100 + idx as u16)
            .parse()
            .unwrap();

        tasks.push(tokio::spawn(async move {
            let (server_side, mut client_side) = duplex(131072);
            let handler = tokio::spawn(handle_client_stream(
                server_side,
                peer,
                harness.config,
                harness.stats,
                harness.upstream_manager,
                harness.replay_checker,
                harness.buffer_pool,
                harness.rng,
                None,
                harness.route_runtime,
                None,
                harness.ip_tracker,
                harness.beobachten,
                false,
            ));

            client_side.write_all(&hello).await.unwrap();
            let mut head = [0u8; 5];
            client_side.read_exact(&mut head).await.unwrap();
            read_tls_record_body(&mut client_side, head).await;

            client_side.write_all(&invalid_mtproto).await.unwrap();
            client_side.write_all(&trailing).await.unwrap();
            client_side.shutdown().await.unwrap();

            let _ = tokio::time::timeout(Duration::from_secs(3), handler)
                .await
                .unwrap()
                .unwrap();
        }));
    }

    for task in tasks {
        task.await.unwrap();
    }

    tokio::time::timeout(Duration::from_secs(6), accept_task)
        .await
        .unwrap()
        .unwrap();
}

// Generates one test per trailing-record payload length. The payload byte at
// index `i` is `(i as u8) * 17 + 5` (wrapping), and the test asserts that the
// wrapped record arrives at the backend byte-for-byte.
macro_rules! tail_length_case {
    ($name:ident, $hex:expr, $secret:expr, $ts:expr, $len:expr) => {
        #[tokio::test]
        async fn $name() {
            let mut payload = vec![0u8; $len];
            for (i, b) in payload.iter_mut().enumerate() {
                *b = (i as u8).wrapping_mul(17).wrapping_add(5);
            }
            let record = wrap_tls_application_data(&payload);
            let got =
                run_tls_success_mtproto_fail_capture($hex, $secret, $ts, vec![record.clone()])
                    .await;
            assert_eq!(got, record);
        }
    };
}

tail_length_case!(
    tail_len_1_preserved,
    "d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1d1",
    [0xD1; 16],
    30,
    1
);
tail_length_case!(
    tail_len_2_preserved,
    "d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2d2",
    [0xD2; 16],
    31,
    2
);
tail_length_case!(
    tail_len_3_preserved,
    "d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3",
    [0xD3; 16],
    32,
    3
);
tail_length_case!(
    tail_len_7_preserved,
    "d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4d4",
    [0xD4; 16],
    33,
    7
);
tail_length_case!(
    tail_len_31_preserved,
    "d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5d5",
    [0xD5; 16],
    34,
    31
);
tail_length_case!(
    tail_len_127_preserved,
    "d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6d6",
    [0xD6; 16],
    35,
    127
);
tail_length_case!(
    tail_len_511_preserved,
    "d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7",
    [0xD7; 16],
    36,
    511
);
tail_length_case!(
    tail_len_1023_preserved,
    "d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8",
    [0xD8; 16],
    37,
    1023
);