use super::*; use crate::config::{UpstreamConfig, UpstreamType}; use crate::crypto::sha256_hmac; use crate::protocol::constants::{HANDSHAKE_LEN, MAX_TLS_CHUNK_SIZE, TLS_VERSION}; use crate::protocol::tls; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpListener; struct PipelineHarness { config: Arc, stats: Arc, upstream_manager: Arc, replay_checker: Arc, buffer_pool: Arc, rng: Arc, route_runtime: Arc, ip_tracker: Arc, beobachten: Arc, } fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness { let mut cfg = ProxyConfig::default(); cfg.general.beobachten = false; cfg.censorship.mask = true; cfg.censorship.mask_unix_sock = None; cfg.censorship.mask_host = Some("127.0.0.1".to_string()); cfg.censorship.mask_port = mask_port; cfg.censorship.mask_proxy_protocol = 0; cfg.access.ignore_time_skew = true; cfg.access .users .insert("user".to_string(), secret_hex.to_string()); let config = Arc::new(cfg); let stats = Arc::new(Stats::new()); let upstream_manager = Arc::new(UpstreamManager::new( vec![UpstreamConfig { upstream_type: UpstreamType::Direct { interface: None, bind_addresses: None, }, weight: 1, enabled: true, scopes: String::new(), selected_scope: String::new(), }], 1, 1, 1, 1, false, stats.clone(), )); PipelineHarness { config, stats, upstream_manager, replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), buffer_pool: Arc::new(BufferPool::new()), rng: Arc::new(SecureRandom::new()), route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), ip_tracker: Arc::new(UserIpTracker::new()), beobachten: Arc::new(BeobachtenStore::new()), } } fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); let total_len = 5 + tls_len; let mut handshake = vec![fill; total_len]; handshake[0] = 0x16; handshake[1] = 0x03; handshake[2] = 0x01; handshake[3..5].copy_from_slice(&(tls_len as 
u16).to_be_bytes()); let session_id_len: usize = 32; handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); let computed = sha256_hmac(secret, &handshake); let mut digest = computed; let ts = timestamp.to_le_bytes(); for i in 0..4 { digest[28 + i] ^= ts[i]; } handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] .copy_from_slice(&digest); handshake } fn wrap_tls_application_data(payload: &[u8]) -> Vec { let mut record = Vec::with_capacity(5 + payload.len()); record.push(0x17); record.extend_from_slice(&TLS_VERSION); record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); record.extend_from_slice(payload); record } fn wrap_tls_record(record_type: u8, payload: &[u8]) -> Vec { let mut record = Vec::with_capacity(5 + payload.len()); record.push(record_type); record.extend_from_slice(&TLS_VERSION); record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); record.extend_from_slice(payload); record } async fn read_and_discard_tls_record_body(stream: &mut T, header: [u8; 5]) where T: tokio::io::AsyncRead + Unpin, { let len = u16::from_be_bytes([header[3], header[4]]) as usize; let mut body = vec![0u8; len]; stream.read_exact(&mut body).await.unwrap(); } #[tokio::test] async fn tls_bad_mtproto_fallback_preserves_wire_and_backend_response() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let backend_addr = listener.local_addr().unwrap(); let secret = [0x81u8; 16]; let client_hello = make_valid_tls_client_hello(&secret, 0, 600, 0x42); let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); let trailing_payload = b"masked-trailing-record".to_vec(); let trailing_record = wrap_tls_application_data(&trailing_payload); let backend_response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); let expected_client_hello = client_hello.clone(); let 
expected_trailing_record = trailing_record.clone(); let expected_response = backend_response.clone(); let accept_task = tokio::spawn(async move { let (mut stream, _) = listener.accept().await.unwrap(); let mut got_hello = vec![0u8; expected_client_hello.len()]; stream.read_exact(&mut got_hello).await.unwrap(); assert_eq!(got_hello, expected_client_hello); let mut got_trailing = vec![0u8; expected_trailing_record.len()]; stream.read_exact(&mut got_trailing).await.unwrap(); assert_eq!(got_trailing, expected_trailing_record); stream.write_all(&expected_response).await.unwrap(); }); let harness = build_harness("81818181818181818181818181818181", backend_addr.port()); let (server_side, mut client_side) = duplex(131072); let peer: SocketAddr = "198.51.100.181:56001".parse().unwrap(); let handler = tokio::spawn(handle_client_stream( server_side, peer, harness.config, harness.stats, harness.upstream_manager, harness.replay_checker, harness.buffer_pool, harness.rng, None, harness.route_runtime, None, harness.ip_tracker, harness.beobachten, false, )); client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; client_side.read_exact(&mut tls_response_head).await.unwrap(); assert_eq!(tls_response_head[0], 0x16); read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await .unwrap() .unwrap(); drop(client_side); let _ = tokio::time::timeout(Duration::from_secs(3), handler) .await .unwrap() .unwrap(); } #[tokio::test] async fn tls_bad_mtproto_fallback_keeps_connects_bad_accounting() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let backend_addr = listener.local_addr().unwrap(); let secret = [0x82u8; 16]; let client_hello = make_valid_tls_client_hello(&secret, 1, 600, 0x43); let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; let 
invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); let trailing_record = wrap_tls_application_data(b"x"); let expected_client_hello = client_hello.clone(); let expected_trailing_record = trailing_record.clone(); let accept_task = tokio::spawn(async move { let (mut stream, _) = listener.accept().await.unwrap(); let mut got_hello = vec![0u8; expected_client_hello.len()]; stream.read_exact(&mut got_hello).await.unwrap(); assert_eq!(got_hello, expected_client_hello); let mut got_trailing = vec![0u8; expected_trailing_record.len()]; stream.read_exact(&mut got_trailing).await.unwrap(); assert_eq!(got_trailing, expected_trailing_record); }); let harness = build_harness("82828282828282828282828282828282", backend_addr.port()); let bad_before = harness.stats.get_connects_bad(); let (server_side, mut client_side) = duplex(65536); let peer: SocketAddr = "198.51.100.182:56002".parse().unwrap(); let stats_for_assert = harness.stats.clone(); let handler = tokio::spawn(handle_client_stream( server_side, peer, harness.config, harness.stats, harness.upstream_manager, harness.replay_checker, harness.buffer_pool, harness.rng, None, harness.route_runtime, None, harness.ip_tracker, harness.beobachten, false, )); client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; client_side.read_exact(&mut tls_response_head).await.unwrap(); assert_eq!(tls_response_head[0], 0x16); client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await .unwrap() .unwrap(); drop(client_side); let _ = tokio::time::timeout(Duration::from_secs(3), handler) .await .unwrap() .unwrap(); let bad_after = stats_for_assert.get_connects_bad(); assert_eq!(bad_after, bad_before + 1, "connects_bad must increase exactly once for invalid MTProto after valid TLS"); } #[tokio::test] async fn 
tls_bad_mtproto_fallback_forwards_zero_length_tls_record_verbatim() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let backend_addr = listener.local_addr().unwrap(); let secret = [0x83u8; 16]; let client_hello = make_valid_tls_client_hello(&secret, 2, 600, 0x44); let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); let trailing_record = wrap_tls_application_data(&[]); let expected_client_hello = client_hello.clone(); let expected_trailing_record = trailing_record.clone(); let accept_task = tokio::spawn(async move { let (mut stream, _) = listener.accept().await.unwrap(); let mut got_hello = vec![0u8; expected_client_hello.len()]; stream.read_exact(&mut got_hello).await.unwrap(); assert_eq!(got_hello, expected_client_hello); let mut got_trailing = vec![0u8; expected_trailing_record.len()]; stream.read_exact(&mut got_trailing).await.unwrap(); assert_eq!(got_trailing, expected_trailing_record); }); let harness = build_harness("83838383838383838383838383838383", backend_addr.port()); let (server_side, mut client_side) = duplex(65536); let peer: SocketAddr = "198.51.100.183:56003".parse().unwrap(); let handler = tokio::spawn(handle_client_stream( server_side, peer, harness.config, harness.stats, harness.upstream_manager, harness.replay_checker, harness.buffer_pool, harness.rng, None, harness.route_runtime, None, harness.ip_tracker, harness.beobachten, false, )); client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; client_side.read_exact(&mut tls_response_head).await.unwrap(); assert_eq!(tls_response_head[0], 0x16); client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await .unwrap() .unwrap(); drop(client_side); let _ = tokio::time::timeout(Duration::from_secs(3), handler) .await .unwrap() .unwrap(); } #[tokio::test] async 
fn tls_bad_mtproto_fallback_forwards_max_tls_record_verbatim() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let backend_addr = listener.local_addr().unwrap(); let secret = [0x84u8; 16]; let client_hello = make_valid_tls_client_hello(&secret, 3, 600, 0x45); let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); let trailing_payload = vec![0xAB; MAX_TLS_CHUNK_SIZE]; let trailing_record = wrap_tls_application_data(&trailing_payload); let expected_client_hello = client_hello.clone(); let expected_trailing_record = trailing_record.clone(); let accept_task = tokio::spawn(async move { let (mut stream, _) = listener.accept().await.unwrap(); let mut got_hello = vec![0u8; expected_client_hello.len()]; stream.read_exact(&mut got_hello).await.unwrap(); assert_eq!(got_hello, expected_client_hello); let mut got_trailing = vec![0u8; expected_trailing_record.len()]; stream.read_exact(&mut got_trailing).await.unwrap(); assert_eq!(got_trailing, expected_trailing_record); }); let harness = build_harness("84848484848484848484848484848484", backend_addr.port()); let (server_side, mut client_side) = duplex(2 * 1024 * 1024); let peer: SocketAddr = "198.51.100.184:56004".parse().unwrap(); let handler = tokio::spawn(handle_client_stream( server_side, peer, harness.config, harness.stats, harness.upstream_manager, harness.replay_checker, harness.buffer_pool, harness.rng, None, harness.route_runtime, None, harness.ip_tracker, harness.beobachten, false, )); client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; client_side.read_exact(&mut tls_response_head).await.unwrap(); assert_eq!(tls_response_head[0], 0x16); client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await .unwrap() .unwrap(); drop(client_side); let _ = 
tokio::time::timeout(Duration::from_secs(3), handler) .await .unwrap() .unwrap(); } #[tokio::test] async fn tls_bad_mtproto_fallback_light_fuzz_tls_record_lengths_verbatim() { let lengths = [0usize, 1, 2, 3, 7, 15, 31, 63, 127, 255, 1024, 4096]; for (idx, payload_len) in lengths.iter().copied().enumerate() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let backend_addr = listener.local_addr().unwrap(); let secret = [0x85u8; 16]; let client_hello = make_valid_tls_client_hello(&secret, idx as u32 + 4, 600, 0x46 + idx as u8); let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); let mut payload = vec![0u8; payload_len]; for (i, b) in payload.iter_mut().enumerate() { *b = ((idx as u8).wrapping_mul(29)).wrapping_add(i as u8); } let trailing_record = wrap_tls_application_data(&payload); let expected_client_hello = client_hello.clone(); let expected_trailing_record = trailing_record.clone(); let accept_task = tokio::spawn(async move { let (mut stream, _) = listener.accept().await.unwrap(); let mut got_hello = vec![0u8; expected_client_hello.len()]; stream.read_exact(&mut got_hello).await.unwrap(); assert_eq!(got_hello, expected_client_hello); let mut got_trailing = vec![0u8; expected_trailing_record.len()]; stream.read_exact(&mut got_trailing).await.unwrap(); assert_eq!(got_trailing, expected_trailing_record); }); let harness = build_harness("85858585858585858585858585858585", backend_addr.port()); let (server_side, mut client_side) = duplex(262144); let peer: SocketAddr = format!("198.51.100.185:{}", 56010 + idx as u16) .parse() .unwrap(); let handler = tokio::spawn(handle_client_stream( server_side, peer, harness.config, harness.stats, harness.upstream_manager, harness.replay_checker, harness.buffer_pool, harness.rng, None, harness.route_runtime, None, harness.ip_tracker, harness.beobachten, false, )); client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = 
[0u8; 5]; client_side.read_exact(&mut tls_response_head).await.unwrap(); assert_eq!(tls_response_head[0], 0x16); client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await .unwrap() .unwrap(); drop(client_side); let _ = tokio::time::timeout(Duration::from_secs(3), handler) .await .unwrap() .unwrap(); } } #[tokio::test] async fn tls_bad_mtproto_fallback_concurrent_sessions_are_isolated() { let sessions = 24usize; let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let backend_addr = listener.local_addr().unwrap(); let mut expected_pairs = std::collections::HashMap::new(); let secret = [0x86u8; 16]; for idx in 0..sessions { let hello = make_valid_tls_client_hello(&secret, idx as u32 + 100, 600, 0x60 + idx as u8); let payload = vec![idx as u8; 64 + idx]; let trailing = wrap_tls_application_data(&payload); expected_pairs.insert(hello, trailing); } let accept_task = tokio::spawn(async move { let mut remaining = expected_pairs; for idx in 0..sessions { let (mut stream, _) = listener.accept().await.unwrap(); let _ = idx; let mut got_hello = vec![0u8; 605]; stream.read_exact(&mut got_hello).await.unwrap(); let expected_trailing = remaining .remove(&got_hello) .expect("unexpected client hello in concurrent isolation test"); let mut got_trailing = vec![0u8; expected_trailing.len()]; stream.read_exact(&mut got_trailing).await.unwrap(); assert_eq!(got_trailing, expected_trailing); } assert!(remaining.is_empty(), "all expected client sessions must be matched exactly once"); }); let mut client_tasks = Vec::with_capacity(sessions); for idx in 0..sessions { let harness = build_harness("86868686868686868686868686868686", backend_addr.port()); let secret = [0x86u8; 16]; let client_hello = make_valid_tls_client_hello(&secret, idx as u32 + 100, 600, 0x60 + idx as u8); let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; let invalid_mtproto_record = 
wrap_tls_application_data(&invalid_mtproto); let trailing_payload = vec![idx as u8; 64 + idx]; let trailing_record = wrap_tls_application_data(&trailing_payload); let peer: SocketAddr = format!("198.51.100.186:{}", 57000 + idx as u16) .parse() .unwrap(); client_tasks.push(tokio::spawn(async move { let (server_side, mut client_side) = duplex(262144); let handler = tokio::spawn(handle_client_stream( server_side, peer, harness.config, harness.stats, harness.upstream_manager, harness.replay_checker, harness.buffer_pool, harness.rng, None, harness.route_runtime, None, harness.ip_tracker, harness.beobachten, false, )); client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; client_side.read_exact(&mut tls_response_head).await.unwrap(); assert_eq!(tls_response_head[0], 0x16); client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap(); drop(client_side); let _ = tokio::time::timeout(Duration::from_secs(3), handler) .await .unwrap() .unwrap(); })); } for task in client_tasks { task.await.unwrap(); } tokio::time::timeout(Duration::from_secs(5), accept_task) .await .unwrap() .unwrap(); } #[tokio::test] async fn tls_bad_mtproto_fallback_forwards_fragmented_client_writes_verbatim() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let backend_addr = listener.local_addr().unwrap(); let secret = [0x87u8; 16]; let client_hello = make_valid_tls_client_hello(&secret, 9, 600, 0x57); let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); let payload = b"fragmented-writes-to-test-stream-boundary-robustness".to_vec(); let trailing_record = wrap_tls_application_data(&payload); let expected_client_hello = client_hello.clone(); let expected_trailing_record = trailing_record.clone(); let accept_task = tokio::spawn(async move { let (mut stream, _) = listener.accept().await.unwrap(); let mut got_hello = vec![0u8; 
expected_client_hello.len()]; stream.read_exact(&mut got_hello).await.unwrap(); assert_eq!(got_hello, expected_client_hello); let mut got_trailing = vec![0u8; expected_trailing_record.len()]; stream.read_exact(&mut got_trailing).await.unwrap(); assert_eq!(got_trailing, expected_trailing_record); }); let harness = build_harness("87878787878787878787878787878787", backend_addr.port()); let (server_side, mut client_side) = duplex(262144); let peer: SocketAddr = "198.51.100.187:56087".parse().unwrap(); let handler = tokio::spawn(handle_client_stream( server_side, peer, harness.config, harness.stats, harness.upstream_manager, harness.replay_checker, harness.buffer_pool, harness.rng, None, harness.route_runtime, None, harness.ip_tracker, harness.beobachten, false, )); client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; client_side.read_exact(&mut tls_response_head).await.unwrap(); assert_eq!(tls_response_head[0], 0x16); client_side.write_all(&invalid_mtproto_record).await.unwrap(); for chunk in trailing_record.chunks(3) { client_side.write_all(chunk).await.unwrap(); } tokio::time::timeout(Duration::from_secs(3), accept_task) .await .unwrap() .unwrap(); drop(client_side); let _ = tokio::time::timeout(Duration::from_secs(3), handler) .await .unwrap() .unwrap(); } #[tokio::test] async fn tls_bad_mtproto_fallback_header_fragmentation_bytewise_is_verbatim() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let backend_addr = listener.local_addr().unwrap(); let secret = [0x88u8; 16]; let client_hello = make_valid_tls_client_hello(&secret, 10, 600, 0x58); let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); let trailing_record = wrap_tls_application_data(b"bytewise-header"); let expected_hello = client_hello.clone(); let expected_trailing = trailing_record.clone(); let accept_task = tokio::spawn(async move { let (mut stream, _) = listener.accept().await.unwrap(); let mut got_hello = vec![0u8; 
expected_hello.len()]; stream.read_exact(&mut got_hello).await.unwrap(); assert_eq!(got_hello, expected_hello); let mut got_trailing = vec![0u8; expected_trailing.len()]; stream.read_exact(&mut got_trailing).await.unwrap(); assert_eq!(got_trailing, expected_trailing); }); let harness = build_harness("88888888888888888888888888888888", backend_addr.port()); let (server_side, mut client_side) = duplex(131072); let peer: SocketAddr = "198.51.100.188:56088".parse().unwrap(); let handler = tokio::spawn(handle_client_stream( server_side, peer, harness.config, harness.stats, harness.upstream_manager, harness.replay_checker, harness.buffer_pool, harness.rng, None, harness.route_runtime, None, harness.ip_tracker, harness.beobachten, false, )); client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; client_side.read_exact(&mut tls_response_head).await.unwrap(); assert_eq!(tls_response_head[0], 0x16); client_side.write_all(&invalid_mtproto_record).await.unwrap(); for b in trailing_record.iter().copied() { client_side.write_all(&[b]).await.unwrap(); } tokio::time::timeout(Duration::from_secs(3), accept_task) .await .unwrap() .unwrap(); drop(client_side); let _ = tokio::time::timeout(Duration::from_secs(3), handler) .await .unwrap() .unwrap(); } #[tokio::test] async fn tls_bad_mtproto_fallback_record_splitting_chaos_is_verbatim() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let backend_addr = listener.local_addr().unwrap(); let secret = [0x89u8; 16]; let client_hello = make_valid_tls_client_hello(&secret, 11, 600, 0x59); let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); let mut payload = vec![0u8; 2048]; for (i, b) in payload.iter_mut().enumerate() { *b = (i as u8).wrapping_mul(17).wrapping_add(3); } let trailing_record = wrap_tls_application_data(&payload); let expected_hello = client_hello.clone(); let expected_trailing = trailing_record.clone(); let accept_task = tokio::spawn(async move 
{ let (mut stream, _) = listener.accept().await.unwrap(); let mut got_hello = vec![0u8; expected_hello.len()]; stream.read_exact(&mut got_hello).await.unwrap(); assert_eq!(got_hello, expected_hello); let mut got_trailing = vec![0u8; expected_trailing.len()]; stream.read_exact(&mut got_trailing).await.unwrap(); assert_eq!(got_trailing, expected_trailing); }); let harness = build_harness("89898989898989898989898989898989", backend_addr.port()); let (server_side, mut client_side) = duplex(262144); let peer: SocketAddr = "198.51.100.189:56089".parse().unwrap(); let handler = tokio::spawn(handle_client_stream( server_side, peer, harness.config, harness.stats, harness.upstream_manager, harness.replay_checker, harness.buffer_pool, harness.rng, None, harness.route_runtime, None, harness.ip_tracker, harness.beobachten, false, )); client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; client_side.read_exact(&mut tls_response_head).await.unwrap(); assert_eq!(tls_response_head[0], 0x16); client_side.write_all(&invalid_mtproto_record).await.unwrap(); let chaos = [7usize, 1, 19, 3, 5, 31, 2, 11, 13, 17]; let mut pos = 0usize; let mut idx = 0usize; while pos < trailing_record.len() { let step = chaos[idx % chaos.len()]; let end = (pos + step).min(trailing_record.len()); client_side.write_all(&trailing_record[pos..end]).await.unwrap(); pos = end; idx += 1; } tokio::time::timeout(Duration::from_secs(3), accept_task) .await .unwrap() .unwrap(); drop(client_side); let _ = tokio::time::timeout(Duration::from_secs(3), handler) .await .unwrap() .unwrap(); } #[tokio::test] async fn tls_bad_mtproto_fallback_multiple_tls_records_are_forwarded_in_order() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let backend_addr = listener.local_addr().unwrap(); let secret = [0x8Au8; 16]; let client_hello = make_valid_tls_client_hello(&secret, 12, 600, 0x5A); let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); let r1 
= wrap_tls_application_data(b"alpha"); let r2 = wrap_tls_application_data(b"beta-beta"); let r3 = wrap_tls_application_data(b"gamma-gamma-gamma"); let expected = [r1.clone(), r2.clone(), r3.clone()].concat(); let expected_hello = client_hello.clone(); let expected_concat = expected.clone(); let accept_task = tokio::spawn(async move { let (mut stream, _) = listener.accept().await.unwrap(); let mut got_hello = vec![0u8; expected_hello.len()]; stream.read_exact(&mut got_hello).await.unwrap(); assert_eq!(got_hello, expected_hello); let mut got = vec![0u8; expected_concat.len()]; stream.read_exact(&mut got).await.unwrap(); assert_eq!(got, expected_concat); }); let harness = build_harness("8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a", backend_addr.port()); let (server_side, mut client_side) = duplex(131072); let peer: SocketAddr = "198.51.100.190:56090".parse().unwrap(); let handler = tokio::spawn(handle_client_stream( server_side, peer, harness.config, harness.stats, harness.upstream_manager, harness.replay_checker, harness.buffer_pool, harness.rng, None, harness.route_runtime, None, harness.ip_tracker, harness.beobachten, false, )); client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; client_side.read_exact(&mut tls_response_head).await.unwrap(); assert_eq!(tls_response_head[0], 0x16); client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side.write_all(&r1).await.unwrap(); client_side.write_all(&r2).await.unwrap(); client_side.write_all(&r3).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await .unwrap() .unwrap(); drop(client_side); let _ = tokio::time::timeout(Duration::from_secs(3), handler) .await .unwrap() .unwrap(); } #[tokio::test] async fn tls_bad_mtproto_fallback_client_half_close_propagates_eof_to_backend() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let backend_addr = listener.local_addr().unwrap(); let secret = [0x8Bu8; 16]; let client_hello = 
make_valid_tls_client_hello(&secret, 13, 600, 0x5B); let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); let trailing_record = wrap_tls_application_data(b"half-close-probe"); let expected_hello = client_hello.clone(); let expected_trailing = trailing_record.clone(); let accept_task = tokio::spawn(async move { let (mut stream, _) = listener.accept().await.unwrap(); let mut got_hello = vec![0u8; expected_hello.len()]; stream.read_exact(&mut got_hello).await.unwrap(); assert_eq!(got_hello, expected_hello); let mut got_trailing = vec![0u8; expected_trailing.len()]; stream.read_exact(&mut got_trailing).await.unwrap(); assert_eq!(got_trailing, expected_trailing); let mut tail = [0u8; 1]; let n = stream.read(&mut tail).await.unwrap(); assert_eq!(n, 0, "backend must observe EOF after client write half-close"); }); let harness = build_harness("8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b", backend_addr.port()); let (server_side, mut client_side) = duplex(131072); let peer: SocketAddr = "198.51.100.191:56091".parse().unwrap(); let handler = tokio::spawn(handle_client_stream( server_side, peer, harness.config, harness.stats, harness.upstream_manager, harness.replay_checker, harness.buffer_pool, harness.rng, None, harness.route_runtime, None, harness.ip_tracker, harness.beobachten, false, )); client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; client_side.read_exact(&mut tls_response_head).await.unwrap(); assert_eq!(tls_response_head[0], 0x16); client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap(); client_side.shutdown().await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await .unwrap() .unwrap(); let _ = tokio::time::timeout(Duration::from_secs(3), handler) .await .unwrap() .unwrap(); } #[tokio::test] async fn tls_bad_mtproto_fallback_backend_half_close_after_response_is_tolerated() { let listener = 
TcpListener::bind("127.0.0.1:0").await.unwrap(); let backend_addr = listener.local_addr().unwrap(); let secret = [0x8Cu8; 16]; let client_hello = make_valid_tls_client_hello(&secret, 14, 600, 0x5C); let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); let trailing_record = wrap_tls_application_data(b"backend-half-close"); let backend_response = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); let expected_hello = client_hello.clone(); let expected_trailing = trailing_record.clone(); let response = backend_response.clone(); let accept_task = tokio::spawn(async move { let (mut stream, _) = listener.accept().await.unwrap(); let mut got_hello = vec![0u8; expected_hello.len()]; stream.read_exact(&mut got_hello).await.unwrap(); assert_eq!(got_hello, expected_hello); let mut got_trailing = vec![0u8; expected_trailing.len()]; stream.read_exact(&mut got_trailing).await.unwrap(); assert_eq!(got_trailing, expected_trailing); stream.write_all(&response).await.unwrap(); stream.shutdown().await.unwrap(); }); let harness = build_harness("8c8c8c8c8c8c8c8c8c8c8c8c8c8c8c8c", backend_addr.port()); let (server_side, mut client_side) = duplex(131072); let peer: SocketAddr = "198.51.100.192:56092".parse().unwrap(); let handler = tokio::spawn(handle_client_stream( server_side, peer, harness.config, harness.stats, harness.upstream_manager, harness.replay_checker, harness.buffer_pool, harness.rng, None, harness.route_runtime, None, harness.ip_tracker, harness.beobachten, false, )); client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; client_side.read_exact(&mut tls_response_head).await.unwrap(); assert_eq!(tls_response_head[0], 0x16); read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side.write_all(&trailing_record).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await .unwrap() 
.unwrap(); drop(client_side); let _ = tokio::time::timeout(Duration::from_secs(3), handler) .await .unwrap() .unwrap(); } #[tokio::test] async fn tls_bad_mtproto_fallback_backend_reset_after_clienthello_is_handled() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let backend_addr = listener.local_addr().unwrap(); let secret = [0x8Du8; 16]; let client_hello = make_valid_tls_client_hello(&secret, 15, 600, 0x5D); let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); let trailing_record = wrap_tls_application_data(b"backend-reset"); let expected_hello = client_hello.clone(); let accept_task = tokio::spawn(async move { let (mut stream, _) = listener.accept().await.unwrap(); let mut got_hello = vec![0u8; expected_hello.len()]; stream.read_exact(&mut got_hello).await.unwrap(); assert_eq!(got_hello, expected_hello); drop(stream); }); let harness = build_harness("8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d", backend_addr.port()); let (server_side, mut client_side) = duplex(131072); let peer: SocketAddr = "198.51.100.193:56093".parse().unwrap(); let handler = tokio::spawn(handle_client_stream( server_side, peer, harness.config, harness.stats, harness.upstream_manager, harness.replay_checker, harness.buffer_pool, harness.rng, None, harness.route_runtime, None, harness.ip_tracker, harness.beobachten, false, )); client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; client_side.read_exact(&mut tls_response_head).await.unwrap(); assert_eq!(tls_response_head[0], 0x16); client_side.write_all(&invalid_mtproto_record).await.unwrap(); let write_res = client_side.write_all(&trailing_record).await; assert!( write_res.is_ok() || write_res.is_err(), "write completion is environment dependent under backend reset" ); tokio::time::timeout(Duration::from_secs(3), accept_task) .await .unwrap() .unwrap(); drop(client_side); let _ = tokio::time::timeout(Duration::from_secs(3), handler) .await .unwrap() .unwrap(); } 
/// The backend drains the forwarded trailing record through many small reads
/// (each sized `offset % 97`, at least 1 byte) with a 1 ms pause after each; the
/// assembled bytes must still equal the record the client sent.
#[tokio::test]
async fn tls_bad_mtproto_fallback_backend_slow_reader_preserves_byte_identity() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();
    let secret = [0x8Eu8; 16];
    let client_hello = make_valid_tls_client_hello(&secret, 16, 600, 0x5E);
    let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
    let payload = vec![0xEC; 8192];
    let trailing_record = wrap_tls_application_data(&payload);

    let expected_hello = client_hello.clone();
    let expected_trailing = trailing_record.clone();
    let accept_task = tokio::spawn(async move {
        let (mut stream, _) = listener.accept().await.unwrap();

        let mut got_hello = vec![0u8; expected_hello.len()];
        stream.read_exact(&mut got_hello).await.unwrap();
        assert_eq!(got_hello, expected_hello);

        let mut got_trailing = vec![0u8; expected_trailing.len()];
        let mut offset = 0usize;
        while offset < got_trailing.len() {
            let remaining = got_trailing.len() - offset;
            // `remaining >= 1` inside the loop, so the clamp bounds are well-ordered.
            let step = (offset % 97).clamp(1, remaining);
            stream
                .read_exact(&mut got_trailing[offset..offset + step])
                .await
                .unwrap();
            offset += step;
            tokio::time::sleep(Duration::from_millis(1)).await;
        }
        assert_eq!(got_trailing, expected_trailing);
    });

    let harness = build_harness("8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e", backend_addr.port());
    let (server_side, mut client_side) = duplex(262144);
    let peer: SocketAddr = "198.51.100.194:56094".parse().unwrap();
    let handler = tokio::spawn(handle_client_stream(
        server_side,
        peer,
        harness.config,
        harness.stats,
        harness.upstream_manager,
        harness.replay_checker,
        harness.buffer_pool,
        harness.rng,
        None,
        harness.route_runtime,
        None,
        harness.ip_tracker,
        harness.beobachten,
        false,
    ));

    client_side.write_all(&client_hello).await.unwrap();
    let mut tls_response_head = [0u8; 5];
    client_side.read_exact(&mut tls_response_head).await.unwrap();
    assert_eq!(tls_response_head[0], 0x16);
    client_side.write_all(&invalid_mtproto_record).await.unwrap();
    client_side.write_all(&trailing_record).await.unwrap();

    // Longer budget than usual: the backend deliberately sleeps between reads.
    tokio::time::timeout(Duration::from_secs(5), accept_task)
        .await
        .unwrap()
        .unwrap();
    drop(client_side);
    let _ = tokio::time::timeout(Duration::from_secs(3), handler)
        .await
        .unwrap()
        .unwrap();
}

/// Sends the same ClientHello twice through one harness (shared replay checker).
/// Session 1 completes TLS and sends invalid MTProto plus a trailing record.
/// Session 2 replays the hello and must not yield any byte to the client within
/// 300 ms; a timeout or a read error are both accepted. The backend must receive
/// the hello intact on both connections, and `connects_bad` must grow by at least 2.
#[tokio::test]
async fn tls_bad_mtproto_fallback_replay_pressure_masks_replay_without_serverhello() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();
    let secret = [0x8Fu8; 16];
    let replayed_hello = make_valid_tls_client_hello(&secret, 17, 600, 0x5F);
    let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
    let trailing_record = wrap_tls_application_data(b"first-session");

    let expected_first = replayed_hello.clone();
    let expected_second = replayed_hello.clone();
    let expected_trailing = trailing_record.clone();
    let accept_task = tokio::spawn(async move {
        let (mut s1, _) = listener.accept().await.unwrap();
        let mut got1 = vec![0u8; expected_first.len()];
        s1.read_exact(&mut got1).await.unwrap();
        assert_eq!(got1, expected_first);
        let mut got1_tail = vec![0u8; expected_trailing.len()];
        s1.read_exact(&mut got1_tail).await.unwrap();
        assert_eq!(got1_tail, expected_trailing);
        drop(s1);

        let (mut s2, _) = listener.accept().await.unwrap();
        let mut got2 = vec![0u8; expected_second.len()];
        s2.read_exact(&mut got2).await.unwrap();
        assert_eq!(got2, expected_second);
    });

    let harness = build_harness("8f8f8f8f8f8f8f8f8f8f8f8f8f8f8f8f", backend_addr.port());
    let stats_for_assert = harness.stats.clone();
    let bad_before = stats_for_assert.get_connects_bad();

    // Each call clones the shared harness state, so both sessions see the same
    // replay checker and stats.
    let run_session = |hello: Vec<u8>, send_mtproto: bool| {
        let (server_side, mut client_side) = duplex(131072);
        let config = harness.config.clone();
        let stats = harness.stats.clone();
        let upstream = harness.upstream_manager.clone();
        let replay = harness.replay_checker.clone();
        let pool = harness.buffer_pool.clone();
        let rng = harness.rng.clone();
        let route = harness.route_runtime.clone();
        let ipt = harness.ip_tracker.clone();
        let beob = harness.beobachten.clone();
        let invalid_mtproto_record = invalid_mtproto_record.clone();
        let trailing_record = trailing_record.clone();
        async move {
            let handler = tokio::spawn(handle_client_stream(
                server_side,
                "198.51.100.195:56095".parse().unwrap(),
                config,
                stats,
                upstream,
                replay,
                pool,
                rng,
                None,
                route,
                None,
                ipt,
                beob,
                false,
            ));

            client_side.write_all(&hello).await.unwrap();
            if send_mtproto {
                let mut head = [0u8; 5];
                client_side.read_exact(&mut head).await.unwrap();
                assert_eq!(head[0], 0x16);
                client_side.write_all(&invalid_mtproto_record).await.unwrap();
                client_side.write_all(&trailing_record).await.unwrap();
            } else {
                let mut one = [0u8; 1];
                let no_server_hello = tokio::time::timeout(
                    Duration::from_millis(300),
                    client_side.read_exact(&mut one),
                )
                .await;
                // Pass on timeout (`Err`) or read failure (`Ok(Err)`); fail only if a
                // byte was actually delivered (`Ok(Ok)`).
                assert!(
                    !matches!(no_server_hello, Ok(Ok(_))),
                    "replayed TLS hello must not receive authenticated TLS ServerHello"
                );
            }

            drop(client_side);
            let _ = tokio::time::timeout(Duration::from_secs(3), handler)
                .await
                .unwrap()
                .unwrap();
        }
    };

    run_session(replayed_hello.clone(), true).await;
    run_session(replayed_hello.clone(), false).await;

    tokio::time::timeout(Duration::from_secs(5), accept_task)
        .await
        .unwrap()
        .unwrap();
    let bad_after = stats_for_assert.get_connects_bad();
    assert!(
        bad_after >= bad_before + 2,
        "both invalid-mtproto and replayed-tls paths must increment bad connection accounting"
    );
}

#[tokio::test]
async fn tls_bad_mtproto_fallback_large_multi_record_chaos_under_backpressure() {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let backend_addr = listener.local_addr().unwrap();
    let secret = [0x90u8; 16];
    let client_hello = make_valid_tls_client_hello(&secret, 18, 600, 0x60);
    let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]);
    let a = wrap_tls_application_data(&vec![0xA1; 2048]);
    let b = wrap_tls_application_data(&vec![0xB2; 3072]);
    let c = wrap_tls_application_data(&vec![0xC3; 1536]);
    let expected = [a.clone(), b.clone(),
c.clone()].concat(); let expected_hello = client_hello.clone(); let expected_payload = expected.clone(); let accept_task = tokio::spawn(async move { let (mut stream, _) = listener.accept().await.unwrap(); let mut got_hello = vec![0u8; expected_hello.len()]; stream.read_exact(&mut got_hello).await.unwrap(); assert_eq!(got_hello, expected_hello); let mut got = vec![0u8; expected_payload.len()]; let mut pos = 0usize; while pos < got.len() { let step = (pos % 257).max(1).min(got.len() - pos); stream.read_exact(&mut got[pos..pos + step]).await.unwrap(); pos += step; tokio::time::sleep(Duration::from_millis(1)).await; } assert_eq!(got, expected_payload); }); let harness = build_harness("90909090909090909090909090909090", backend_addr.port()); let (server_side, mut client_side) = duplex(262144); let peer: SocketAddr = "198.51.100.196:56096".parse().unwrap(); let handler = tokio::spawn(handle_client_stream( server_side, peer, harness.config, harness.stats, harness.upstream_manager, harness.replay_checker, harness.buffer_pool, harness.rng, None, harness.route_runtime, None, harness.ip_tracker, harness.beobachten, false, )); client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; client_side.read_exact(&mut tls_response_head).await.unwrap(); assert_eq!(tls_response_head[0], 0x16); client_side.write_all(&invalid_mtproto_record).await.unwrap(); let chaos = [5usize, 23, 11, 47, 3, 19, 29, 13, 7, 31]; for record in [&a, &b, &c] { let mut pos = 0usize; let mut idx = 0usize; while pos < record.len() { let step = chaos[idx % chaos.len()]; let end = (pos + step).min(record.len()); client_side.write_all(&record[pos..end]).await.unwrap(); pos = end; idx += 1; } } tokio::time::timeout(Duration::from_secs(5), accept_task) .await .unwrap() .unwrap(); drop(client_side); let _ = tokio::time::timeout(Duration::from_secs(3), handler) .await .unwrap() .unwrap(); } #[tokio::test] async fn 
tls_bad_mtproto_fallback_interleaved_control_and_application_records_verbatim() { let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let backend_addr = listener.local_addr().unwrap(); let secret = [0x91u8; 16]; let client_hello = make_valid_tls_client_hello(&secret, 19, 600, 0x61); let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); let ccs = wrap_tls_record(0x14, &[0x01]); let app = wrap_tls_application_data(b"opaque"); let alert = wrap_tls_record(0x15, &[0x01, 0x00]); let expected = [ccs.clone(), app.clone(), alert.clone()].concat(); let expected_hello = client_hello.clone(); let expected_records = expected.clone(); let accept_task = tokio::spawn(async move { let (mut stream, _) = listener.accept().await.unwrap(); let mut got_hello = vec![0u8; expected_hello.len()]; stream.read_exact(&mut got_hello).await.unwrap(); assert_eq!(got_hello, expected_hello); let mut got = vec![0u8; expected_records.len()]; stream.read_exact(&mut got).await.unwrap(); assert_eq!(got, expected_records); }); let harness = build_harness("91919191919191919191919191919191", backend_addr.port()); let (server_side, mut client_side) = duplex(131072); let peer: SocketAddr = "198.51.100.197:56097".parse().unwrap(); let handler = tokio::spawn(handle_client_stream( server_side, peer, harness.config, harness.stats, harness.upstream_manager, harness.replay_checker, harness.buffer_pool, harness.rng, None, harness.route_runtime, None, harness.ip_tracker, harness.beobachten, false, )); client_side.write_all(&client_hello).await.unwrap(); let mut tls_response_head = [0u8; 5]; client_side.read_exact(&mut tls_response_head).await.unwrap(); assert_eq!(tls_response_head[0], 0x16); client_side.write_all(&invalid_mtproto_record).await.unwrap(); client_side.write_all(&ccs).await.unwrap(); client_side.write_all(&app).await.unwrap(); client_side.write_all(&alert).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await .unwrap() .unwrap(); 
drop(client_side); let _ = tokio::time::timeout(Duration::from_secs(3), handler) .await .unwrap() .unwrap(); } #[tokio::test] async fn tls_bad_mtproto_fallback_many_short_sessions_with_chaos_no_cross_leak() { let sessions = 40usize; let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let backend_addr = listener.local_addr().unwrap(); let mut expected_pairs = std::collections::HashMap::new(); let secret = [0x92u8; 16]; for idx in 0..sessions { let hello = make_valid_tls_client_hello(&secret, idx as u32 + 200, 600, 0x70 + idx as u8); let payload = vec![idx as u8; 33 + (idx % 17)]; let record = wrap_tls_application_data(&payload); expected_pairs.insert(hello, record); } let accept_task = tokio::spawn(async move { let mut remaining = expected_pairs; for idx in 0..sessions { let (mut stream, _) = listener.accept().await.unwrap(); let _ = idx; let mut got_hello = vec![0u8; 605]; stream.read_exact(&mut got_hello).await.unwrap(); let expected_record = remaining .remove(&got_hello) .expect("unexpected client hello in short-session chaos test"); let mut got = vec![0u8; expected_record.len()]; stream.read_exact(&mut got).await.unwrap(); assert_eq!(got, expected_record); } assert!(remaining.is_empty(), "all expected sessions must be consumed exactly once"); }); let mut tasks = Vec::with_capacity(sessions); for idx in 0..sessions { let harness = build_harness("92929292929292929292929292929292", backend_addr.port()); let secret = [0x92u8; 16]; let client_hello = make_valid_tls_client_hello(&secret, idx as u32 + 200, 600, 0x70 + idx as u8); let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); let payload = vec![idx as u8; 33 + (idx % 17)]; let record = wrap_tls_application_data(&payload); let peer: SocketAddr = format!("198.51.100.198:{}", 58000 + idx as u16) .parse() .unwrap(); tasks.push(tokio::spawn(async move { let (server_side, mut client_side) = duplex(131072); let handler = tokio::spawn(handle_client_stream( server_side, peer, 
harness.config, harness.stats, harness.upstream_manager, harness.replay_checker, harness.buffer_pool, harness.rng, None, harness.route_runtime, None, harness.ip_tracker, harness.beobachten, false, )); client_side.write_all(&client_hello).await.unwrap(); let mut head = [0u8; 5]; client_side.read_exact(&mut head).await.unwrap(); assert_eq!(head[0], 0x16); client_side.write_all(&invalid_mtproto_record).await.unwrap(); for chunk in record.chunks((idx % 9) + 1) { client_side.write_all(chunk).await.unwrap(); } drop(client_side); let _ = tokio::time::timeout(Duration::from_secs(3), handler) .await .unwrap() .unwrap(); })); } for task in tasks { task.await.unwrap(); } tokio::time::timeout(Duration::from_secs(6), accept_task) .await .unwrap() .unwrap(); }