diff --git a/src/proxy/client.rs b/src/proxy/client.rs index bd02ac8..984c7b4 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -295,8 +295,16 @@ where ).await { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { + // MTProto failed after TLS ServerHello was already sent. + // Switch fallback relay back to raw transport so the mask + // backend receives valid TLS records (not unwrapped payload). + let reader = reader.into_inner(); + let writer = writer.into_inner(); stats.increment_connects_bad(); - debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); + debug!( + peer = %peer, + "Authenticated TLS session failed MTProto validation; engaging masking fallback" + ); handle_bad_client( reader, writer, @@ -708,8 +716,16 @@ impl RunningClientHandler { { HandshakeResult::Success(result) => result, HandshakeResult::BadClient { reader, writer } => { + // MTProto failed after TLS ServerHello was already sent. + // Switch fallback relay back to raw transport so the mask + // backend receives valid TLS records (not unwrapped payload). + let reader = reader.into_inner(); + let writer = writer.into_inner(); stats.increment_connects_bad(); - debug!(peer = %peer, "Valid TLS but invalid MTProto handshake"); + debug!( + peer = %peer, + "Authenticated TLS session failed MTProto validation; engaging masking fallback" + ); handle_bad_client( reader, writer, @@ -1044,3 +1060,7 @@ mod security_tests; #[cfg(test)] #[path = "client_adversarial_tests.rs"] mod adversarial_tests; + +#[cfg(test)] +#[path = "client_tls_mtproto_fallback_security_tests.rs"] +mod tls_mtproto_fallback_security_tests; diff --git a/src/proxy/client_adversarial_tests.rs b/src/proxy/client_adversarial_tests.rs index 80d65f2..37bc53d 100644 --- a/src/proxy/client_adversarial_tests.rs +++ b/src/proxy/client_adversarial_tests.rs @@ -4,6 +4,7 @@ use crate::stats::Stats; use crate::ip_tracker::UserIpTracker; use crate::error::ProxyError; use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; // ------------------------------------------------------------------ @@ -107,3 +108,357 @@ async fn client_ip_tracker_race_condition_stress() { assert_eq!(ip_tracker.get_active_ip_count(user).await, 0, "IP count must be zero after balanced add/remove burst"); } + +#[tokio::test] +async fn client_limit_burst_peak_never_exceeds_cap() { + let user = "peak-cap-user"; + let limit = 32; + let attempts = 256; + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), limit); + + let peak = Arc::new(AtomicU64::new(0)); + let mut tasks = Vec::with_capacity(attempts); + + for i in 0..attempts { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + let peak = Arc::clone(&peak); + + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(203, 0, 113, (i % 250 + 1) as u8)), + 20000 + i as u16, + ); + + let acquired = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker, + ) + .await; + + if let Ok(reservation) = acquired { + let now = stats.get_user_curr_connects(user); + loop { + let prev = peak.load(Ordering::Relaxed); + if now <= prev { + break; + } + if peak + .compare_exchange(prev, now, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + break; + } + } + tokio::time::sleep(Duration::from_millis(2)).await; + drop(reservation); + } + })); + } + + futures::future::join_all(tasks).await; + ip_tracker.drain_cleanup_queue().await; + + assert!( + peak.load(Ordering::Relaxed) <= limit as u64, + "peak concurrent reservations must not exceed configured cap" + ); + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_quota_rejection_never_mutates_live_counters() { + let user = "quota-reject-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut config = ProxyConfig::default(); + config.access.user_data_quota.insert(user.to_string(), 0); + + let peer: SocketAddr = "198.51.100.201:31111".parse().unwrap(); + let res = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + + assert!(matches!(res, Err(ProxyError::DataQuotaExceeded { .. }))); + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_expiration_rejection_never_mutates_live_counters() { + let user = "expired-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + + let mut config = ProxyConfig::default(); + config + .access + .user_expirations + .insert(user.to_string(), chrono::Utc::now() - chrono::Duration::seconds(1)); + + let peer: SocketAddr = "198.51.100.202:31112".parse().unwrap(); + let res = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + + assert!(matches!(res, Err(ProxyError::UserExpired { .. }))); + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_ip_limit_failure_rolls_back_counter_exactly() { + let user = "ip-limit-rollback-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 16); + + let first_peer: SocketAddr = "198.51.100.203:31113".parse().unwrap(); + let first = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + first_peer, + ip_tracker.clone(), + ) + .await + .unwrap(); + + let second_peer: SocketAddr = "198.51.100.204:31114".parse().unwrap(); + let second = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + second_peer, + ip_tracker.clone(), + ) + .await; + + assert!(matches!(second, Err(ProxyError::ConnectionLimitExceeded { .. }))); + assert_eq!(stats.get_user_curr_connects(user), 1); + + drop(first); + ip_tracker.drain_cleanup_queue().await; + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_parallel_limit_checks_success_path_leaves_no_residue() { + let user = "parallel-check-success-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 128).await; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 128); + + let mut tasks = Vec::new(); + for i in 0..128u16 { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(10, 10, (i / 255) as u8, (i % 255 + 1) as u8)), + 32000 + i, + ); + RunningClientHandler::check_user_limits_static(user, &config, &stats, peer, &ip_tracker) + .await + })); + } + + for result in futures::future::join_all(tasks).await { + assert!(result.unwrap().is_ok()); + } + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_parallel_limit_checks_failure_path_leaves_no_residue() { + let user = "parallel-check-failure-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 0).await; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 512); + + let mut tasks = Vec::new(); + for i in 0..64u16 { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(172, 16, 0, (i % 250 + 1) as u8)), 33000 + i); + RunningClientHandler::check_user_limits_static(user, &config, &stats, peer, &ip_tracker) + .await + })); + } + + let mut _denied = 0usize; + for result in futures::future::join_all(tasks).await { + match result.unwrap() { + Ok(()) => {} + Err(ProxyError::ConnectionLimitExceeded { .. }) => _denied += 1, + Err(other) => panic!("unexpected error: {other}"), + } + } + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_churn_mixed_success_failure_converges_to_zero_state() { + let user = "mixed-churn-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 4).await; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 8); + + let mut tasks = Vec::new(); + for i in 0..200u16 { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + + tasks.push(tokio::spawn(async move { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(192, 0, 2, (i % 16 + 1) as u8)), + 34000 + (i % 32), + ); + let maybe_res = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await; + + if let Ok(reservation) = maybe_res { + tokio::time::sleep(Duration::from_millis((i % 3) as u64)).await; + drop(reservation); + } + })); + } + + futures::future::join_all(tasks).await; + ip_tracker.drain_cleanup_queue().await; + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_same_ip_parallel_attempts_allow_at_most_one_when_limit_is_one() { + let user = "same-ip-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 1); + + let peer: SocketAddr = "203.0.113.44:35555".parse().unwrap(); + let mut tasks = Vec::new(); + + for _ in 0..64 { + let stats = Arc::clone(&stats); + let ip_tracker = Arc::clone(&ip_tracker); + let config = config.clone(); + tasks.push(tokio::spawn(async move { + RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await + })); + } + + let mut granted = 0usize; + let mut reservations = Vec::new(); + for result in futures::future::join_all(tasks).await { + match result.unwrap() { + Ok(reservation) => { + granted += 1; + reservations.push(reservation); + } + Err(ProxyError::ConnectionLimitExceeded { .. }) => {} + Err(other) => panic!("unexpected error: {other}"), + } + } + + assert_eq!(granted, 1, "only one reservation may be granted for same IP with limit=1"); + drop(reservations); + ip_tracker.drain_cleanup_queue().await; + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn client_repeat_acquire_release_cycles_never_accumulate_state() { + let user = "repeat-cycle-user"; + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 32).await; + + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 32); + + for i in 0..500u16 { + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(198, 18, (i / 250) as u8, (i % 250 + 1) as u8)), + 36000 + (i % 128), + ); + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .unwrap(); + drop(reservation); + } + + ip_tracker.drain_cleanup_queue().await; + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} diff --git a/src/proxy/client_security_tests.rs b/src/proxy/client_security_tests.rs index d3c411e..74eeba2 100644 --- a/src/proxy/client_security_tests.rs +++ b/src/proxy/client_security_tests.rs @@ -1322,13 +1322,20 @@ async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() { let client_hello = make_valid_tls_client_hello(&secret, 0); let invalid_mtproto = vec![0u8; crate::protocol::constants::HANDSHAKE_LEN]; let tls_app_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_tls_payload = b"still-tls-after-fallback".to_vec(); + let trailing_tls_record = wrap_tls_application_data(&trailing_tls_payload); let expected_fallback = client_hello.clone(); + let expected_trailing_tls_record = trailing_tls_record.clone(); let accept_task = tokio::spawn(async move { let (mut stream, _) = listener.accept().await.unwrap(); let mut got = vec![0u8; expected_fallback.len()]; stream.read_exact(&mut got).await.unwrap(); assert_eq!(got, expected_fallback); + + let mut trailing = vec![0u8; expected_trailing_tls_record.len()]; + stream.read_exact(&mut trailing).await.unwrap(); + assert_eq!(trailing, expected_trailing_tls_record); }); let mut cfg = ProxyConfig::default(); @@ -1396,6 +1403,7 @@ async fn valid_tls_with_invalid_mtproto_falls_back_to_mask_backend() { assert_eq!(tls_response_head[0], 0x16); client_side.write_all(&tls_app_record).await.unwrap(); + client_side.write_all(&trailing_tls_record).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), accept_task) .await @@ -1421,13 +1429,20 @@ async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() { let client_hello = make_valid_tls_client_hello(&secret, 0); let invalid_mtproto = vec![0u8; crate::protocol::constants::HANDSHAKE_LEN]; let tls_app_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_tls_payload = b"second-tls-record".to_vec(); + let trailing_tls_record = wrap_tls_application_data(&trailing_tls_payload); let expected_fallback = client_hello.clone(); + let expected_trailing_tls_record = trailing_tls_record.clone(); let mask_accept_task = tokio::spawn(async move { let (mut stream, _) = mask_listener.accept().await.unwrap(); let mut got = vec![0u8; expected_fallback.len()]; stream.read_exact(&mut got).await.unwrap(); assert_eq!(got, expected_fallback); + + let mut trailing = vec![0u8; expected_trailing_tls_record.len()]; + stream.read_exact(&mut trailing).await.unwrap(); + assert_eq!(trailing, expected_trailing_tls_record); }); let mut cfg = ProxyConfig::default(); @@ -1513,6 +1528,7 @@ async fn client_handler_tls_bad_mtproto_is_forwarded_to_mask_backend() { assert_eq!(tls_response_head[0], 0x16); client.write_all(&tls_app_record).await.unwrap(); + client.write_all(&trailing_tls_record).await.unwrap(); tokio::time::timeout(Duration::from_secs(3), mask_accept_task) .await diff --git a/src/proxy/client_tls_mtproto_fallback_security_tests.rs b/src/proxy/client_tls_mtproto_fallback_security_tests.rs new file mode 100644 index 0000000..80393bb --- /dev/null +++ b/src/proxy/client_tls_mtproto_fallback_security_tests.rs @@ -0,0 +1,1503 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType}; +use crate::crypto::sha256_hmac; +use crate::protocol::constants::{HANDSHAKE_LEN, MAX_TLS_CHUNK_SIZE, TLS_VERSION}; +use crate::protocol::tls; +use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener; + +struct PipelineHarness { + config: Arc, + stats: Arc, + upstream_manager: Arc, + replay_checker: Arc, + buffer_pool: Arc, + rng: Arc, + route_runtime: Arc, + ip_tracker: Arc, + beobachten: Arc, +} + +fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.censorship.mask = true; + cfg.censorship.mask_unix_sock = None; + cfg.censorship.mask_host = Some("127.0.0.1".to_string()); + cfg.censorship.mask_port = mask_port; + cfg.censorship.mask_proxy_protocol = 0; + cfg.access.ignore_time_skew = true; + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let upstream_manager = Arc::new(UpstreamManager::new( + vec![UpstreamConfig { + upstream_type: UpstreamType::Direct { + interface: None, + bind_addresses: None, + }, + weight: 1, + enabled: true, + scopes: String::new(), + selected_scope: String::new(), + }], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )); + + PipelineHarness { + config, + stats, + upstream_manager, + replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))), + buffer_pool: Arc::new(BufferPool::new()), + rng: Arc::new(SecureRandom::new()), + route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + ip_tracker: Arc::new(UserIpTracker::new()), + beobachten: Arc::new(BeobachtenStore::new()), + } +} + +fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec { + assert!(tls_len <= u16::MAX as usize, "TLS length must fit into record header"); + + let total_len = 5 + tls_len; + let mut handshake = vec![fill; total_len]; + + handshake[0] = 0x16; + handshake[1] = 0x03; + handshake[2] = 0x01; + handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes()); + + let session_id_len: usize = 32; + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn wrap_tls_application_data(payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(0x17); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +fn wrap_tls_record(record_type: u8, payload: &[u8]) -> Vec { + let mut record = Vec::with_capacity(5 + payload.len()); + record.push(record_type); + record.extend_from_slice(&TLS_VERSION); + record.extend_from_slice(&(payload.len() as u16).to_be_bytes()); + record.extend_from_slice(payload); + record +} + +async fn read_and_discard_tls_record_body(stream: &mut T, header: [u8; 5]) +where + T: tokio::io::AsyncRead + Unpin, +{ + let len = u16::from_be_bytes([header[3], header[4]]) as usize; + let mut body = vec![0u8; len]; + stream.read_exact(&mut body).await.unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_preserves_wire_and_backend_response() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x81u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 0, 600, 0x42); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_payload = b"masked-trailing-record".to_vec(); + let trailing_record = wrap_tls_application_data(&trailing_payload); + let backend_response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec(); + + let expected_client_hello = client_hello.clone(); + let expected_trailing_record = trailing_record.clone(); + let expected_response = backend_response.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_hello = vec![0u8; expected_client_hello.len()]; + stream.read_exact(&mut got_hello).await.unwrap(); + assert_eq!(got_hello, expected_client_hello); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + + stream.write_all(&expected_response).await.unwrap(); + }); + + let harness = build_harness("81818181818181818181818181818181", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.181:56001".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_keeps_connects_bad_accounting() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x82u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 1, 600, 0x43); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_record = wrap_tls_application_data(b"x"); + + let expected_client_hello = client_hello.clone(); + let expected_trailing_record = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_hello = vec![0u8; expected_client_hello.len()]; + stream.read_exact(&mut got_hello).await.unwrap(); + assert_eq!(got_hello, expected_client_hello); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + }); + + let harness = build_harness("82828282828282828282828282828282", backend_addr.port()); + let bad_before = harness.stats.get_connects_bad(); + + let (server_side, mut client_side) = duplex(65536); + let peer: SocketAddr = "198.51.100.182:56002".parse().unwrap(); + let stats_for_assert = harness.stats.clone(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + + let bad_after = stats_for_assert.get_connects_bad(); + assert_eq!(bad_after, bad_before + 1, "connects_bad must increase exactly once for invalid MTProto after valid TLS"); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_forwards_zero_length_tls_record_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x83u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 2, 600, 0x44); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_record = wrap_tls_application_data(&[]); + + let expected_client_hello = client_hello.clone(); + let expected_trailing_record = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_hello = vec![0u8; expected_client_hello.len()]; + stream.read_exact(&mut got_hello).await.unwrap(); + assert_eq!(got_hello, expected_client_hello); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + }); + + let harness = build_harness("83838383838383838383838383838383", backend_addr.port()); + let (server_side, mut client_side) = duplex(65536); + let peer: SocketAddr = "198.51.100.183:56003".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_forwards_max_tls_record_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x84u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 3, 600, 0x45); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_payload = vec![0xAB; MAX_TLS_CHUNK_SIZE]; + let trailing_record = wrap_tls_application_data(&trailing_payload); + + let expected_client_hello = client_hello.clone(); + let expected_trailing_record = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_hello = vec![0u8; expected_client_hello.len()]; + stream.read_exact(&mut got_hello).await.unwrap(); + assert_eq!(got_hello, expected_client_hello); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + }); + + let harness = build_harness("84848484848484848484848484848484", backend_addr.port()); + let (server_side, mut client_side) = duplex(2 * 1024 * 1024); + let peer: SocketAddr = "198.51.100.184:56004".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_light_fuzz_tls_record_lengths_verbatim() { + let lengths = [0usize, 1, 2, 3, 7, 15, 31, 63, 127, 255, 1024, 4096]; + + for (idx, payload_len) in lengths.iter().copied().enumerate() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x85u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, idx as u32 + 4, 600, 0x46 + idx as u8); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + + let mut payload = vec![0u8; payload_len]; + for (i, b) in payload.iter_mut().enumerate() { + *b = ((idx as u8).wrapping_mul(29)).wrapping_add(i as u8); + } + let trailing_record = wrap_tls_application_data(&payload); + + let expected_client_hello = client_hello.clone(); + let expected_trailing_record = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_hello = vec![0u8; expected_client_hello.len()]; + stream.read_exact(&mut got_hello).await.unwrap(); + assert_eq!(got_hello, expected_client_hello); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + }); + + let harness = build_harness("85858585858585858585858585858585", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + let peer: SocketAddr = format!("198.51.100.185:{}", 56010 + idx as u16) + .parse() + .unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + } +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_concurrent_sessions_are_isolated() { + let sessions = 24usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected_pairs = std::collections::HashMap::new(); + let secret = [0x86u8; 16]; + for idx in 0..sessions { + let hello = make_valid_tls_client_hello(&secret, idx as u32 + 100, 600, 0x60 + idx as u8); + let payload = vec![idx as u8; 64 + idx]; + let trailing = wrap_tls_application_data(&payload); + expected_pairs.insert(hello, trailing); + } + + let accept_task = tokio::spawn(async move { + let mut remaining = expected_pairs; + for idx in 0..sessions { + let (mut stream, _) = listener.accept().await.unwrap(); + + let _ = idx; + let mut got_hello = vec![0u8; 605]; + stream.read_exact(&mut got_hello).await.unwrap(); + let expected_trailing = remaining + .remove(&got_hello) + .expect("unexpected client hello in concurrent isolation test"); + + let mut got_trailing = vec![0u8; expected_trailing.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing); + } + + assert!(remaining.is_empty(), "all expected client sessions must be matched exactly once"); + }); + + let mut client_tasks = Vec::with_capacity(sessions); + + for idx in 0..sessions { + let harness = build_harness("86868686868686868686868686868686", backend_addr.port()); + let secret = [0x86u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, idx as u32 + 100, 600, 0x60 + idx as u8); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let trailing_payload = vec![idx as u8; 64 + idx]; + let trailing_record = wrap_tls_application_data(&trailing_payload); + + let peer: SocketAddr = format!("198.51.100.186:{}", 57000 + idx as u16) + .parse() + .unwrap(); + + client_tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(262144); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + })); + } + + for task in client_tasks { + task.await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_forwards_fragmented_client_writes_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x87u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 9, 600, 0x57); + let invalid_mtproto = vec![0u8; HANDSHAKE_LEN]; + let invalid_mtproto_record = wrap_tls_application_data(&invalid_mtproto); + let payload = b"fragmented-writes-to-test-stream-boundary-robustness".to_vec(); + let trailing_record = wrap_tls_application_data(&payload); + + let expected_client_hello = client_hello.clone(); + let expected_trailing_record = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_hello = vec![0u8; expected_client_hello.len()]; + stream.read_exact(&mut got_hello).await.unwrap(); + assert_eq!(got_hello, expected_client_hello); + + let mut got_trailing = vec![0u8; expected_trailing_record.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing_record); + }); + + let harness = build_harness("87878787878787878787878787878787", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + let peer: SocketAddr = "198.51.100.187:56087".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + + for chunk in trailing_record.chunks(3) { + client_side.write_all(chunk).await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_header_fragmentation_bytewise_is_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x88u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 10, 600, 0x58); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing_record = wrap_tls_application_data(b"bytewise-header"); + + let expected_hello = client_hello.clone(); + let expected_trailing = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got_hello = vec![0u8; expected_hello.len()]; + stream.read_exact(&mut got_hello).await.unwrap(); + assert_eq!(got_hello, expected_hello); + + let mut got_trailing = vec![0u8; expected_trailing.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing); + }); + + let harness = build_harness("88888888888888888888888888888888", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.188:56088".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + for b in trailing_record.iter().copied() { + client_side.write_all(&[b]).await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_record_splitting_chaos_is_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x89u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 11, 600, 0x59); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let mut payload = vec![0u8; 2048]; + for (i, b) in payload.iter_mut().enumerate() { + *b = (i as u8).wrapping_mul(17).wrapping_add(3); + } + let trailing_record = wrap_tls_application_data(&payload); + + let expected_hello = client_hello.clone(); + let expected_trailing = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got_hello = vec![0u8; expected_hello.len()]; + stream.read_exact(&mut got_hello).await.unwrap(); + assert_eq!(got_hello, expected_hello); + + let mut got_trailing = vec![0u8; expected_trailing.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing); + }); + + let harness = build_harness("89898989898989898989898989898989", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + let peer: SocketAddr = "198.51.100.189:56089".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + + let chaos = [7usize, 1, 19, 3, 5, 31, 2, 11, 13, 17]; + let mut pos = 0usize; + let mut idx = 0usize; + while pos < trailing_record.len() { + let step = chaos[idx % chaos.len()]; + let end = (pos + step).min(trailing_record.len()); + client_side.write_all(&trailing_record[pos..end]).await.unwrap(); + pos = end; + idx += 1; + } + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_multiple_tls_records_are_forwarded_in_order() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x8Au8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 12, 600, 0x5A); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let r1 = wrap_tls_application_data(b"alpha"); + let r2 = wrap_tls_application_data(b"beta-beta"); + let r3 = wrap_tls_application_data(b"gamma-gamma-gamma"); + let expected = [r1.clone(), r2.clone(), r3.clone()].concat(); + + let expected_hello = client_hello.clone(); + let expected_concat = expected.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got_hello = vec![0u8; expected_hello.len()]; + stream.read_exact(&mut got_hello).await.unwrap(); + assert_eq!(got_hello, expected_hello); + + let mut got = vec![0u8; expected_concat.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected_concat); + }); + + let harness = build_harness("8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a8a", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.190:56090".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&r1).await.unwrap(); + client_side.write_all(&r2).await.unwrap(); + client_side.write_all(&r3).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_client_half_close_propagates_eof_to_backend() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x8Bu8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 13, 600, 0x5B); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing_record = wrap_tls_application_data(b"half-close-probe"); + + let expected_hello = client_hello.clone(); + let expected_trailing = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_hello = vec![0u8; expected_hello.len()]; + stream.read_exact(&mut got_hello).await.unwrap(); + assert_eq!(got_hello, expected_hello); + + let mut got_trailing = vec![0u8; expected_trailing.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing); + + let mut tail = [0u8; 1]; + let n = stream.read(&mut tail).await.unwrap(); + assert_eq!(n, 0, "backend must observe EOF after client write half-close"); + }); + + let harness = build_harness("8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b8b", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.191:56091".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + client_side.shutdown().await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_backend_half_close_after_response_is_tolerated() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x8Cu8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 14, 600, 0x5C); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing_record = wrap_tls_application_data(b"backend-half-close"); + let backend_response = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec(); + + let expected_hello = client_hello.clone(); + let expected_trailing = trailing_record.clone(); + let response = backend_response.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got_hello = vec![0u8; expected_hello.len()]; + stream.read_exact(&mut got_hello).await.unwrap(); + assert_eq!(got_hello, expected_hello); + + let mut got_trailing = vec![0u8; expected_trailing.len()]; + stream.read_exact(&mut got_trailing).await.unwrap(); + assert_eq!(got_trailing, expected_trailing); + + stream.write_all(&response).await.unwrap(); + stream.shutdown().await.unwrap(); + }); + + let harness = build_harness("8c8c8c8c8c8c8c8c8c8c8c8c8c8c8c8c", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.192:56092".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + read_and_discard_tls_record_body(&mut client_side, tls_response_head).await; + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_backend_reset_after_clienthello_is_handled() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x8Du8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 15, 600, 0x5D); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing_record = wrap_tls_application_data(b"backend-reset"); + + let expected_hello = client_hello.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got_hello = vec![0u8; expected_hello.len()]; + stream.read_exact(&mut got_hello).await.unwrap(); + assert_eq!(got_hello, expected_hello); + drop(stream); + }); + + let harness = build_harness("8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d8d", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.193:56093".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + let write_res = client_side.write_all(&trailing_record).await; + assert!( + write_res.is_ok() || write_res.is_err(), + "write completion is environment dependent under backend reset" + ); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_backend_slow_reader_preserves_byte_identity() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x8Eu8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 16, 600, 0x5E); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let payload = vec![0xEC; 8192]; + let trailing_record = wrap_tls_application_data(&payload); + + let expected_hello = client_hello.clone(); + let expected_trailing = trailing_record.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + + let mut got_hello = vec![0u8; expected_hello.len()]; + stream.read_exact(&mut got_hello).await.unwrap(); + assert_eq!(got_hello, expected_hello); + + let mut got_trailing = vec![0u8; expected_trailing.len()]; + let mut offset = 0usize; + while offset < got_trailing.len() { + let step = (offset % 97).max(1).min(got_trailing.len() - offset); + stream + .read_exact(&mut got_trailing[offset..offset + step]) + .await + .unwrap(); + offset += step; + tokio::time::sleep(Duration::from_millis(1)).await; + } + assert_eq!(got_trailing, expected_trailing); + }); + + let harness = build_harness("8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + let peer: SocketAddr = "198.51.100.194:56094".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_replay_pressure_masks_replay_without_serverhello() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x8Fu8; 16]; + let replayed_hello = make_valid_tls_client_hello(&secret, 17, 600, 0x5F); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let trailing_record = wrap_tls_application_data(b"first-session"); + + let expected_first = replayed_hello.clone(); + let expected_second = replayed_hello.clone(); + let expected_trailing = trailing_record.clone(); + + let accept_task = tokio::spawn(async move { + let (mut s1, _) = listener.accept().await.unwrap(); + let mut got1 = vec![0u8; expected_first.len()]; + s1.read_exact(&mut got1).await.unwrap(); + assert_eq!(got1, expected_first); + + let mut got1_tail = vec![0u8; expected_trailing.len()]; + s1.read_exact(&mut got1_tail).await.unwrap(); + assert_eq!(got1_tail, expected_trailing); + drop(s1); + + let (mut s2, _) = listener.accept().await.unwrap(); + let mut got2 = vec![0u8; expected_second.len()]; + s2.read_exact(&mut got2).await.unwrap(); + assert_eq!(got2, expected_second); + }); + + let harness = build_harness("8f8f8f8f8f8f8f8f8f8f8f8f8f8f8f8f", backend_addr.port()); + let stats_for_assert = harness.stats.clone(); + let bad_before = stats_for_assert.get_connects_bad(); + + let run_session = |hello: Vec, send_mtproto: bool| { + let (server_side, mut client_side) = duplex(131072); + let config = harness.config.clone(); + let stats = harness.stats.clone(); + let upstream = harness.upstream_manager.clone(); + let replay = harness.replay_checker.clone(); + let pool = harness.buffer_pool.clone(); + let rng = harness.rng.clone(); + let route = harness.route_runtime.clone(); + let ipt = harness.ip_tracker.clone(); + let beob = harness.beobachten.clone(); + let invalid_mtproto_record = invalid_mtproto_record.clone(); + let trailing_record = trailing_record.clone(); + async move { + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.195:56095".parse().unwrap(), + config, + stats, + upstream, + replay, + pool, + rng, + None, + route, + None, + ipt, + beob, + false, + )); + + client_side.write_all(&hello).await.unwrap(); + if send_mtproto { + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&trailing_record).await.unwrap(); + } else { + let mut one = [0u8; 1]; + let no_server_hello = tokio::time::timeout( + Duration::from_millis(300), + client_side.read_exact(&mut one), + ) + .await; + assert!( + no_server_hello.is_err() || no_server_hello.unwrap().is_err(), + "replayed TLS hello must not receive authenticated TLS ServerHello" + ); + } + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + } + }; + + run_session(replayed_hello.clone(), true).await; + run_session(replayed_hello.clone(), false).await; + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); + + let bad_after = stats_for_assert.get_connects_bad(); + assert!( + bad_after >= bad_before + 2, + "both invalid-mtproto and replayed-tls paths must increment bad connection accounting" + ); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_large_multi_record_chaos_under_backpressure() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x90u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 18, 600, 0x60); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let a = wrap_tls_application_data(&vec![0xA1; 2048]); + let b = wrap_tls_application_data(&vec![0xB2; 3072]); + let c = wrap_tls_application_data(&vec![0xC3; 1536]); + let expected = [a.clone(), b.clone(), c.clone()].concat(); + + let expected_hello = client_hello.clone(); + let expected_payload = expected.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got_hello = vec![0u8; expected_hello.len()]; + stream.read_exact(&mut got_hello).await.unwrap(); + assert_eq!(got_hello, expected_hello); + + let mut got = vec![0u8; expected_payload.len()]; + let mut pos = 0usize; + while pos < got.len() { + let step = (pos % 257).max(1).min(got.len() - pos); + stream.read_exact(&mut got[pos..pos + step]).await.unwrap(); + pos += step; + tokio::time::sleep(Duration::from_millis(1)).await; + } + assert_eq!(got, expected_payload); + }); + + let harness = build_harness("90909090909090909090909090909090", backend_addr.port()); + let (server_side, mut client_side) = duplex(262144); + let peer: SocketAddr = "198.51.100.196:56096".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + + let chaos = [5usize, 23, 11, 47, 3, 19, 29, 13, 7, 31]; + for record in [&a, &b, &c] { + let mut pos = 0usize; + let mut idx = 0usize; + while pos < record.len() { + let step = chaos[idx % chaos.len()]; + let end = (pos + step).min(record.len()); + client_side.write_all(&record[pos..end]).await.unwrap(); + pos = end; + idx += 1; + } + } + + tokio::time::timeout(Duration::from_secs(5), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_interleaved_control_and_application_records_verbatim() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let secret = [0x91u8; 16]; + let client_hello = make_valid_tls_client_hello(&secret, 19, 600, 0x61); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + + let ccs = wrap_tls_record(0x14, &[0x01]); + let app = wrap_tls_application_data(b"opaque"); + let alert = wrap_tls_record(0x15, &[0x01, 0x00]); + let expected = [ccs.clone(), app.clone(), alert.clone()].concat(); + + let expected_hello = client_hello.clone(); + let expected_records = expected.clone(); + let accept_task = tokio::spawn(async move { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut got_hello = vec![0u8; expected_hello.len()]; + stream.read_exact(&mut got_hello).await.unwrap(); + assert_eq!(got_hello, expected_hello); + + let mut got = vec![0u8; expected_records.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected_records); + }); + + let harness = build_harness("91919191919191919191919191919191", backend_addr.port()); + let (server_side, mut client_side) = duplex(131072); + let peer: SocketAddr = "198.51.100.197:56097".parse().unwrap(); + + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut tls_response_head = [0u8; 5]; + client_side.read_exact(&mut tls_response_head).await.unwrap(); + assert_eq!(tls_response_head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side.write_all(&ccs).await.unwrap(); + client_side.write_all(&app).await.unwrap(); + client_side.write_all(&alert).await.unwrap(); + + tokio::time::timeout(Duration::from_secs(3), accept_task) + .await + .unwrap() + .unwrap(); + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); +} + +#[tokio::test] +async fn tls_bad_mtproto_fallback_many_short_sessions_with_chaos_no_cross_leak() { + let sessions = 40usize; + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let backend_addr = listener.local_addr().unwrap(); + + let mut expected_pairs = std::collections::HashMap::new(); + let secret = [0x92u8; 16]; + for idx in 0..sessions { + let hello = make_valid_tls_client_hello(&secret, idx as u32 + 200, 600, 0x70 + idx as u8); + let payload = vec![idx as u8; 33 + (idx % 17)]; + let record = wrap_tls_application_data(&payload); + expected_pairs.insert(hello, record); + } + + let accept_task = tokio::spawn(async move { + let mut remaining = expected_pairs; + for idx in 0..sessions { + let (mut stream, _) = listener.accept().await.unwrap(); + + let _ = idx; + let mut got_hello = vec![0u8; 605]; + stream.read_exact(&mut got_hello).await.unwrap(); + let expected_record = remaining + .remove(&got_hello) + .expect("unexpected client hello in short-session chaos test"); + + let mut got = vec![0u8; expected_record.len()]; + stream.read_exact(&mut got).await.unwrap(); + assert_eq!(got, expected_record); + } + + assert!(remaining.is_empty(), "all expected sessions must be consumed exactly once"); + }); + + let mut tasks = Vec::with_capacity(sessions); + for idx in 0..sessions { + let harness = build_harness("92929292929292929292929292929292", backend_addr.port()); + let secret = [0x92u8; 16]; + let client_hello = + make_valid_tls_client_hello(&secret, idx as u32 + 200, 600, 0x70 + idx as u8); + let invalid_mtproto_record = wrap_tls_application_data(&vec![0u8; HANDSHAKE_LEN]); + let payload = vec![idx as u8; 33 + (idx % 17)]; + let record = wrap_tls_application_data(&payload); + + let peer: SocketAddr = format!("198.51.100.198:{}", 58000 + idx as u16) + .parse() + .unwrap(); + + tasks.push(tokio::spawn(async move { + let (server_side, mut client_side) = duplex(131072); + let handler = tokio::spawn(handle_client_stream( + server_side, + peer, + harness.config, + harness.stats, + harness.upstream_manager, + harness.replay_checker, + harness.buffer_pool, + harness.rng, + None, + harness.route_runtime, + None, + harness.ip_tracker, + harness.beobachten, + false, + )); + + client_side.write_all(&client_hello).await.unwrap(); + let mut head = [0u8; 5]; + client_side.read_exact(&mut head).await.unwrap(); + assert_eq!(head[0], 0x16); + + client_side.write_all(&invalid_mtproto_record).await.unwrap(); + for chunk in record.chunks((idx % 9) + 1) { + client_side.write_all(chunk).await.unwrap(); + } + + drop(client_side); + let _ = tokio::time::timeout(Duration::from_secs(3), handler) + .await + .unwrap() + .unwrap(); + })); + } + + for task in tasks { + task.await.unwrap(); + } + + tokio::time::timeout(Duration::from_secs(6), accept_task) + .await + .unwrap() + .unwrap(); +}