diff --git a/docs/CONFIG_PARAMS.en.md b/docs/CONFIG_PARAMS.en.md index 3eee3a7..f8c56a0 100644 --- a/docs/CONFIG_PARAMS.en.md +++ b/docs/CONFIG_PARAMS.en.md @@ -198,11 +198,14 @@ This document lists all configuration keys accepted by `config.toml`. | listen_tcp | `bool \| null` | `null` (auto) | — | Explicit TCP listener enable/disable override. | | proxy_protocol | `bool` | `false` | — | Enables HAProxy PROXY protocol parsing on incoming client connections. | | proxy_protocol_header_timeout_ms | `u64` | `500` | Must be `> 0`. | Timeout for PROXY protocol header read/parse (ms). | +| proxy_protocol_trusted_cidrs | `IpNetwork[]` | `[]` | — | When non-empty, only connections from these proxy source CIDRs are allowed to provide PROXY protocol headers. If empty, PROXY headers are rejected by default (security hardening). | | metrics_port | `u16 \| null` | `null` | — | Metrics endpoint port (enables metrics listener). | | metrics_listen | `String \| null` | `null` | — | Full metrics bind address (`IP:PORT`), overrides `metrics_port`. | | metrics_whitelist | `IpNetwork[]` | `["127.0.0.1/32", "::1/128"]` | — | CIDR whitelist for metrics endpoint access. | | max_connections | `u32` | `10000` | — | Max concurrent client connections (`0` = unlimited). | +Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers are parsed from the first bytes of the connection and the client source address is replaced with `src_addr` from the header. For security, the peer source IP (the direct connection address) is verified against `server.proxy_protocol_trusted_cidrs`; if this list is empty, PROXY headers are rejected and the connection is considered untrusted. + ## [server.api] | Parameter | Type | Default | Constraints / validation | Description | diff --git a/src/proxy/client.rs b/src/proxy/client.rs index 4b7f57e..1941f36 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -1328,3 +1328,15 @@ mod beobachten_ttl_bounds_security_tests; #[cfg(test)] #[path = "tests/client_tls_record_wrap_hardening_security_tests.rs"] mod tls_record_wrap_hardening_security_tests; + +#[cfg(test)] +#[path = "tests/client_clever_advanced_tests.rs"] +mod client_clever_advanced_tests; + +#[cfg(test)] +#[path = "tests/client_more_advanced_tests.rs"] +mod client_more_advanced_tests; + +#[cfg(test)] +#[path = "tests/client_deep_invariants_tests.rs"] +mod client_deep_invariants_tests; diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 5632977..7e4b62c 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -605,16 +605,6 @@ where } }; - // Replay tracking is applied only after successful authentication to avoid - // letting unauthenticated probes evict valid entries from the replay cache. - let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; - if replay_checker.check_and_add_tls_digest(digest_half) { - auth_probe_record_failure(peer.ip(), Instant::now()); - maybe_apply_server_hello_delay(config).await; - warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); - return HandshakeResult::BadClient { reader, writer }; - } - let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { Some((_, s)) => s, None => { @@ -670,6 +660,16 @@ where None }; + // Replay tracking is applied only after full policy validation (including + // ALPN checks) so rejected handshakes cannot poison replay state. + let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; + if replay_checker.check_and_add_tls_digest(digest_half) { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); + return HandshakeResult::BadClient { reader, writer }; + } + let response = if let Some((cached_entry, use_full_cert_payload)) = cached { emulator::build_emulated_server_hello( secret, @@ -979,6 +979,22 @@ mod saturation_poison_security_tests; #[path = "tests/handshake_auth_probe_hardening_adversarial_tests.rs"] mod auth_probe_hardening_adversarial_tests; +#[cfg(test)] +#[path = "tests/handshake_advanced_clever_tests.rs"] +mod advanced_clever_tests; + +#[cfg(test)] +#[path = "tests/handshake_more_clever_tests.rs"] +mod more_clever_tests; + +#[cfg(test)] +#[path = "tests/handshake_real_bug_stress_tests.rs"] +mod real_bug_stress_tests; + +#[cfg(test)] +#[path = "tests/handshake_timing_manual_bench_tests.rs"] +mod timing_manual_bench_tests; + /// Compile-time guard: HandshakeSuccess holds cryptographic key material and /// must never be Copy. A Copy impl would allow silent key duplication, /// undermining the zeroize-on-drop guarantee. diff --git a/src/proxy/tests/client_clever_advanced_tests.rs b/src/proxy/tests/client_clever_advanced_tests.rs new file mode 100644 index 0000000..da2e703 --- /dev/null +++ b/src/proxy/tests/client_clever_advanced_tests.rs @@ -0,0 +1,409 @@ +use super::*; +use crate::config::{UpstreamConfig, UpstreamType, ProxyConfig}; +use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE}; +use crate::stats::Stats; +use crate::transport::UpstreamManager; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf, duplex}; +use tokio::net::TcpListener; + +#[test] +fn edge_mask_reject_delay_min_greater_than_max_does_not_panic() { + let mut config = ProxyConfig::default(); + config.censorship.server_hello_delay_min_ms = 5000; + config.censorship.server_hello_delay_max_ms = 1000; + + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let start = std::time::Instant::now(); + maybe_apply_mask_reject_delay(&config).await; + let elapsed = start.elapsed(); + + assert!(elapsed >= Duration::from_millis(1000)); + assert!(elapsed < Duration::from_millis(1500)); + }); +} + +#[test] +fn edge_handshake_timeout_with_mask_grace_saturating_add_prevents_overflow() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = u64::MAX; + config.censorship.mask = true; + + let timeout = handshake_timeout_with_mask_grace(&config); + assert_eq!(timeout.as_secs(), u64::MAX); +} + +#[test] +fn edge_tls_clienthello_len_in_bounds_exact_boundaries() { + assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE)); + assert!(!tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE - 1)); + assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE)); + assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1)); +} + +#[test] +fn edge_synthetic_local_addr_boundaries() { + assert_eq!(synthetic_local_addr(0).port(), 0); + assert_eq!(synthetic_local_addr(80).port(), 80); + assert_eq!(synthetic_local_addr(u16::MAX).port(), u16::MAX); +} + +#[test] +fn edge_beobachten_record_handshake_failure_class_stream_error_eof() { + let beobachten = BeobachtenStore::new(); + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 1; + + let eof_err = ProxyError::Stream(crate::error::StreamError::UnexpectedEof); + let peer_ip: IpAddr = "198.51.100.100".parse().unwrap(); + + record_handshake_failure_class(&beobachten, &config, peer_ip, &eof_err); + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[expected_64_got_0]")); +} + +#[tokio::test] +async fn adversarial_tls_handshake_timeout_during_masking_delay() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = false; + cfg.timeouts.client_handshake = 1; + cfg.censorship.mask = true; + cfg.censorship.server_hello_delay_min_ms = 3000; + cfg.censorship.server_hello_delay_max_ms = 3000; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let (server_side, mut client_side) = duplex(4096); + + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.1:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side.write_all(&[0x16, 0x03, 0x01, 0xFF, 0xFF]).await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(4), handle) + .await + .unwrap() + .unwrap(); + + assert!(matches!(result, Err(ProxyError::TgHandshakeTimeout))); + assert_eq!(stats.get_handshake_timeouts(), 1); +} + +#[tokio::test] +async fn blackhat_proxy_protocol_slowloris_timeout() { + let mut cfg = ProxyConfig::default(); + cfg.server.proxy_protocol_header_timeout_ms = 200; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.2:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + true, + )); + + client_side.write_all(b"PROXY TCP4 192.").await.unwrap(); + tokio::time::sleep(Duration::from_millis(300)).await; + + let result = tokio::time::timeout(Duration::from_secs(2), handle) + .await + .unwrap() + .unwrap(); + + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[test] +fn blackhat_ipv4_mapped_ipv6_proxy_source_bypass_attempt() { + let trusted = vec!["192.0.2.0/24".parse().unwrap()]; + let peer_ip = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc000, 0x0201)); + assert!(!is_trusted_proxy_source(peer_ip, &trusted)); +} + +#[tokio::test] +async fn negative_proxy_protocol_enabled_but_client_sends_tls_hello() { + let mut cfg = ProxyConfig::default(); + cfg.server.proxy_protocol_header_timeout_ms = 500; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.3:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + true, + )); + + client_side.write_all(&[0x16, 0x03, 0x01, 0x02, 0x00]).await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(2), handle) + .await + .unwrap() + .unwrap(); + + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn edge_client_stream_exactly_4_bytes_eof() { + let config = Arc::new(ProxyConfig::default()); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.4:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(&[0x16, 0x03, 0x01, 0x00]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handle).await; + + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[expected_64_got_0]")); +} + +#[tokio::test] +async fn edge_client_stream_tls_header_valid_but_body_1_byte_short_eof() { + let config = Arc::new(ProxyConfig::default()); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.5:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side.write_all(&[0x16, 0x03, 0x01, 0x00, 100]).await.unwrap(); + client_side.write_all(&vec![0x41; 99]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handle).await; + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn integration_non_tls_modes_disabled_immediately_masks() { + let mut cfg = ProxyConfig::default(); + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + cfg.censorship.mask = true; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handle = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.6:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side.write_all(b"GET / HTTP/1.1\r\n").await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(2), handle).await; + assert_eq!(stats.get_connects_bad(), 1); +} + +struct YieldingReader { + data: Vec, + pos: usize, + yields_left: usize, +} + +impl AsyncRead for YieldingReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = self.get_mut(); + if this.yields_left > 0 { + this.yields_left -= 1; + cx.waker().wake_by_ref(); + return Poll::Pending; + } + if this.pos >= this.data.len() { + return Poll::Ready(Ok(())); + } + buf.put_slice(&this.data[this.pos..this.pos + 1]); + this.pos += 1; + this.yields_left = 2; + Poll::Ready(Ok(())) + } +} + +#[tokio::test] +async fn fuzz_read_with_progress_heavy_yielding() { + let expected_data = b"HEAVY_YIELD_TEST_DATA".to_vec(); + let mut reader = YieldingReader { + data: expected_data.clone(), + pos: 0, + yields_left: 2, + }; + + let mut buf = vec![0u8; expected_data.len()]; + let read_bytes = read_with_progress(&mut reader, &mut buf).await.unwrap(); + + assert_eq!(read_bytes, expected_data.len()); + assert_eq!(buf, expected_data); +} + +#[test] +fn edge_wrap_tls_application_record_exactly_u16_max() { + let payload = vec![0u8; 65535]; + let wrapped = wrap_tls_application_record(&payload); + assert_eq!(wrapped.len(), 65540); + assert_eq!(wrapped[0], TLS_RECORD_APPLICATION); + assert_eq!(&wrapped[3..5], &65535u16.to_be_bytes()); +} + +#[test] +fn fuzz_wrap_tls_application_record_lengths() { + let lengths = [0, 1, 65534, 65535, 65536, 131070, 131071, 131072]; + for len in lengths { + let payload = vec![0u8; len]; + let wrapped = wrap_tls_application_record(&payload); + let expected_chunks = len.div_ceil(65535).max(1); + assert_eq!(wrapped.len(), len + 5 * expected_chunks); + } +} + +#[tokio::test] +async fn stress_user_connection_reservation_concurrent_same_ip_exhaustion() { + let user = "stress-same-ip-user"; + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 5); + + let config = Arc::new(config); + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 10).await; + + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77)), 55000); + + let mut tasks = tokio::task::JoinSet::new(); + let mut reservations = Vec::new(); + + for _ in 0..10 { + let config = config.clone(); + let stats = stats.clone(); + let ip_tracker = ip_tracker.clone(); + tasks.spawn(async move { + RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await + }); + } + + let mut successes = 0; + let mut failures = 0; + + while let Some(res) = tasks.join_next().await { + match res.unwrap() { + Ok(r) => { + successes += 1; + reservations.push(r); + } + Err(_) => failures += 1, + } + } + + assert_eq!(successes, 5); + assert_eq!(failures, 5); + assert_eq!(stats.get_user_curr_connects(user), 5); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 1); + + for reservation in reservations { + reservation.release().await; + } + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} diff --git a/src/proxy/tests/client_deep_invariants_tests.rs b/src/proxy/tests/client_deep_invariants_tests.rs new file mode 100644 index 0000000..97c55c6 --- /dev/null +++ b/src/proxy/tests/client_deep_invariants_tests.rs @@ -0,0 +1,196 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE; +use crate::stats::Stats; +use crate::transport::UpstreamManager; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{AsyncWriteExt, duplex}; + +#[test] +fn invariant_wrap_tls_application_record_exact_multiples() { + let chunk_size = u16::MAX as usize; + let payload = vec![0xAA; chunk_size * 2]; + + let wrapped = wrap_tls_application_record(&payload); + + assert_eq!(wrapped.len(), 2 * (5 + chunk_size)); + assert_eq!(wrapped[0], TLS_RECORD_APPLICATION); + assert_eq!(&wrapped[3..5], &65535u16.to_be_bytes()); + + let second_header_idx = 5 + chunk_size; + assert_eq!(wrapped[second_header_idx], TLS_RECORD_APPLICATION); + assert_eq!( + &wrapped[second_header_idx + 3..second_header_idx + 5], + &65535u16.to_be_bytes() + ); +} + +#[tokio::test] +async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking() { + let config = Arc::new(ProxyConfig::default()); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.20:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + let claimed_len = MIN_TLS_CLIENT_HELLO_SIZE as u16; + let mut header = vec![0x16, 0x03, 0x01]; + header.extend_from_slice(&claimed_len.to_be_bytes()); + + client_side.write_all(&header).await.unwrap(); + client_side + .write_all(&vec![0x42; MIN_TLS_CLIENT_HELLO_SIZE - 1]) + .await + .unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap(); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn invariant_acquire_reservation_ip_limit_rollback() { + let user = "rollback-test-user"; + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 10); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let peer_a = "198.51.100.21:55000".parse().unwrap(); + let _res_a = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_a, + ip_tracker.clone(), + ) + .await + .unwrap(); + + assert_eq!(stats.get_user_curr_connects(user), 1); + + let peer_b = "203.0.113.22:55000".parse().unwrap(); + let res_b = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer_b, + ip_tracker.clone(), + ) + .await; + + assert!(matches!( + res_b, + Err(ProxyError::ConnectionLimitExceeded { .. }) + )); + assert_eq!(stats.get_user_curr_connects(user), 1); +} + +#[tokio::test] +async fn invariant_quota_exact_boundary_inclusive() { + let user = "quota-strict-user"; + let mut config = ProxyConfig::default(); + config.access.user_data_quota.insert(user.to_string(), 1000); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + let peer = "198.51.100.23:55000".parse().unwrap(); + + stats.add_user_octets_from(user, 999); + let res1 = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + assert!(res1.is_ok()); + res1.unwrap().release().await; + + stats.add_user_octets_from(user, 1); + let res2 = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await; + assert!(matches!(res2, Err(ProxyError::DataQuotaExceeded { .. }))); +} + +#[tokio::test] +async fn invariant_direct_mode_partial_header_eof_is_error_not_bad_connect() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.25:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(&[0xEF, 0xEF, 0xEF]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + + assert!(result.is_err()); + assert_eq!(stats.get_connects_bad(), 0); + let snapshot = beobachten.snapshot_text(Duration::from_secs(60)); + assert!(snapshot.contains("[expected_64_got_0]")); +} + +#[tokio::test] +async fn invariant_route_mode_snapshot_picks_up_latest_mode() { + let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)); + assert!(matches!( + route_runtime.snapshot().mode, + RelayRouteMode::Direct + )); + + route_runtime.set_mode(RelayRouteMode::Middle); + assert!(matches!( + route_runtime.snapshot().mode, + RelayRouteMode::Middle + )); +} diff --git a/src/proxy/tests/client_more_advanced_tests.rs b/src/proxy/tests/client_more_advanced_tests.rs new file mode 100644 index 0000000..021848a --- /dev/null +++ b/src/proxy/tests/client_more_advanced_tests.rs @@ -0,0 +1,257 @@ +use super::*; +use crate::config::ProxyConfig; +use crate::stats::Stats; +use crate::transport::UpstreamManager; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; + +#[tokio::test] +async fn edge_mask_delay_bypassed_if_max_is_zero() { + let mut config = ProxyConfig::default(); + config.censorship.server_hello_delay_min_ms = 10_000; + config.censorship.server_hello_delay_max_ms = 0; + + let start = std::time::Instant::now(); + maybe_apply_mask_reject_delay(&config).await; + assert!(start.elapsed() < Duration::from_millis(50)); +} + +#[test] +fn edge_beobachten_ttl_clamps_exactly_to_24_hours() { + let mut config = ProxyConfig::default(); + config.general.beobachten = true; + config.general.beobachten_minutes = 100_000; + + let ttl = beobachten_ttl(&config); + assert_eq!(ttl.as_secs(), 24 * 60 * 60); +} + +#[test] +fn edge_wrap_tls_application_record_empty_payload() { + let wrapped = wrap_tls_application_record(&[]); + assert_eq!(wrapped.len(), 5); + assert_eq!(wrapped[0], TLS_RECORD_APPLICATION); + assert_eq!(&wrapped[3..5], &[0, 0]); +} + +#[tokio::test] +async fn boundary_user_data_quota_exact_match_rejects() { + let user = "quota-boundary-user"; + let mut config = ProxyConfig::default(); + config.access.user_data_quota.insert(user.to_string(), 1024); + + let stats = Arc::new(Stats::new()); + stats.add_user_octets_from(user, 1024); + + let ip_tracker = Arc::new(UserIpTracker::new()); + let peer = "198.51.100.10:55000".parse().unwrap(); + + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await; + + assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); +} + +#[tokio::test] +async fn boundary_user_expiration_in_past_rejects() { + let user = "expired-boundary-user"; + let mut config = ProxyConfig::default(); + let expired_time = chrono::Utc::now() - chrono::Duration::milliseconds(1); + config + .access + .user_expirations + .insert(user.to_string(), expired_time); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + let peer = "198.51.100.11:55000".parse().unwrap(); + + let result = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats, + peer, + ip_tracker, + ) + .await; + + assert!(matches!(result, Err(ProxyError::UserExpired { .. }))); +} + +#[tokio::test] +async fn blackhat_proxy_protocol_massive_garbage_rejected_quickly() { + let mut cfg = ProxyConfig::default(); + cfg.server.proxy_protocol_header_timeout_ms = 300; + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.12:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + true, + )); + + client_side.write_all(&vec![b'A'; 2000]).await.unwrap(); + + let result = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap() + .unwrap(); + assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol))); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn edge_tls_body_immediate_eof_triggers_masking_and_bad_connect() { + let mut cfg = ProxyConfig::default(); + cfg.general.beobachten = true; + cfg.general.beobachten_minutes = 1; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + let beobachten = Arc::new(BeobachtenStore::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.13:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + beobachten.clone(), + false, + )); + + client_side.write_all(&[0x16, 0x03, 0x01, 0x00, 100]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap(); + + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn security_classic_mode_disabled_masks_valid_length_payload() { + let mut cfg = ProxyConfig::default(); + cfg.general.modes.classic = false; + cfg.general.modes.secure = false; + cfg.censorship.mask = true; + + let config = Arc::new(cfg); + let stats = Arc::new(Stats::new()); + + let (server_side, mut client_side) = duplex(4096); + let handler = tokio::spawn(handle_client_stream( + server_side, + "198.51.100.15:55000".parse().unwrap(), + config, + stats.clone(), + Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), + Arc::new(BufferPool::new()), + Arc::new(SecureRandom::new()), + None, + Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)), + None, + Arc::new(UserIpTracker::new()), + Arc::new(BeobachtenStore::new()), + false, + )); + + client_side.write_all(&vec![0xEF; 64]).await.unwrap(); + client_side.shutdown().await.unwrap(); + + let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap(); + assert_eq!(stats.get_connects_bad(), 1); +} + +#[tokio::test] +async fn concurrency_ip_tracker_strict_limit_one_rapid_churn() { + let user = "rapid-churn-user"; + let mut config = ProxyConfig::default(); + config.access.user_max_tcp_conns.insert(user.to_string(), 10); + + let stats = Arc::new(Stats::new()); + let ip_tracker = Arc::new(UserIpTracker::new()); + ip_tracker.set_user_limit(user, 1).await; + + let peer = "198.51.100.16:55000".parse().unwrap(); + + for _ in 0..500 { + let reservation = RunningClientHandler::acquire_user_connection_reservation_static( + user, + &config, + stats.clone(), + peer, + ip_tracker.clone(), + ) + .await + .unwrap(); + reservation.release().await; + } + + assert_eq!(stats.get_user_curr_connects(user), 0); + assert_eq!(ip_tracker.get_active_ip_count(user).await, 0); +} + +#[tokio::test] +async fn quirk_read_with_progress_zero_length_buffer_returns_zero_immediately() { + let (mut server_side, _client_side) = duplex(4096); + let mut empty_buf = &mut [][..]; + + let result = tokio::time::timeout( + Duration::from_millis(50), + read_with_progress(&mut server_side, &mut empty_buf), + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap().unwrap(), 0); +} + +#[tokio::test] +async fn stress_read_with_progress_cancellation_safety() { + let (mut server_side, mut client_side) = duplex(4096); + + client_side.write_all(b"12345").await.unwrap(); + + let mut buf = [0u8; 10]; + let result = tokio::time::timeout( + Duration::from_millis(50), + read_with_progress(&mut server_side, &mut buf), + ) + .await; + + assert!(result.is_err()); + + client_side.write_all(b"67890").await.unwrap(); + let mut buf2 = [0u8; 5]; + server_side.read_exact(&mut buf2).await.unwrap(); + assert_eq!(&buf2, b"67890"); +} diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index 6338e23..2b1fae6 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -7,6 +7,9 @@ use crate::protocol::tls; use crate::proxy::handshake::HandshakeSuccess; use crate::stream::{CryptoReader, CryptoWriter}; use crate::transport::proxy_protocol::ProxyProtocolV1Builder; +use rand::rngs::StdRng; +use rand::Rng; +use rand::SeedableRng; use std::net::Ipv4Addr; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::{TcpListener, TcpStream}; @@ -25,6 +28,202 @@ fn synthetic_local_addr_uses_configured_port_for_max() { assert_eq!(addr.port(), u16::MAX); } +#[test] +fn handshake_timeout_with_mask_grace_includes_mask_margin() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = 2; + + config.censorship.mask = false; + assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_secs(2)); + + config.censorship.mask = true; + assert_eq!( + handshake_timeout_with_mask_grace(&config), + Duration::from_millis(2750), + "mask mode extends handshake timeout by 750 ms" + ); +} + +#[tokio::test] +async fn read_with_progress_reads_partial_buffers_before_eof() { + let data = vec![0xAA, 0xBB, 0xCC]; + let mut reader = std::io::Cursor::new(data); + let mut buf = [0u8; 5]; + + let read = read_with_progress(&mut reader, &mut buf).await.unwrap(); + assert_eq!(read, 3); + assert_eq!(&buf[..3], &[0xAA, 0xBB, 0xCC]); +} + +#[test] +fn is_trusted_proxy_source_respects_cidr_list_and_empty_rejects_all() { + let peer: IpAddr = "10.10.10.10".parse().unwrap(); + assert!(!is_trusted_proxy_source(peer, &[])); + + let trusted = vec!["10.0.0.0/8".parse().unwrap()]; + assert!(is_trusted_proxy_source(peer, &trusted)); + + let not_trusted = vec!["192.0.2.0/24".parse().unwrap()]; + assert!(!is_trusted_proxy_source(peer, ¬_trusted)); +} + +#[test] +fn is_trusted_proxy_source_accepts_cidr_zero_zero_as_global_cidr() { + let peer: IpAddr = "203.0.113.42".parse().unwrap(); + let trust_all = vec!["0.0.0.0/0".parse().unwrap()]; + assert!(is_trusted_proxy_source(peer, &trust_all)); + + let peer_v6: IpAddr = "2001:db8::1".parse().unwrap(); + let trust_all_v6 = vec!["::/0".parse().unwrap()]; + assert!(is_trusted_proxy_source(peer_v6, &trust_all_v6)); +} + +struct ErrorReader; + +impl tokio::io::AsyncRead for ErrorReader { + fn poll_read( + self: std::pin::Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + _buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + std::task::Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "fake error"))) + } +} + +#[tokio::test] +async fn read_with_progress_returns_error_from_failed_reader() { + let mut reader = ErrorReader; + let mut buf = [0u8; 8]; + let err = read_with_progress(&mut reader, &mut buf).await.unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof); +} + +#[test] +fn handshake_timeout_with_mask_grace_handles_maximum_values_without_overflow() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = u64::MAX; + config.censorship.mask = true; + + let timeout = handshake_timeout_with_mask_grace(&config); + assert!(timeout >= Duration::from_secs(u64::MAX)); +} + +#[tokio::test] +async fn read_with_progress_zero_length_buffer_returns_zero() { + let data = vec![1, 2, 3]; + let mut reader = std::io::Cursor::new(data); + let mut buf = []; + + let read = read_with_progress(&mut reader, &mut buf).await.unwrap(); + assert_eq!(read, 0); +} + +#[test] +fn handshake_timeout_without_mask_is_exact_base() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = 7; + config.censorship.mask = false; + + assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_secs(7)); +} + +#[test] +fn handshake_timeout_mask_enabled_adds_750ms() { + let mut config = ProxyConfig::default(); + config.timeouts.client_handshake = 3; + config.censorship.mask = true; + + assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_millis(3750)); +} + +#[tokio::test] +async fn read_with_progress_full_then_empty_transition() { + let data = vec![0x10, 0x20]; + let mut cursor = std::io::Cursor::new(data); + let mut buf = [0u8; 2]; + + assert_eq!(read_with_progress(&mut cursor, &mut buf).await.unwrap(), 2); + assert_eq!(read_with_progress(&mut cursor, &mut buf).await.unwrap(), 0); +} + +#[tokio::test] +async fn read_with_progress_fragmented_io_works_over_multiple_calls() { + let mut cursor = std::io::Cursor::new(vec![1, 2, 3, 4, 5]); + let mut result = Vec::new(); + + for chunk_size in 1..=5 { + let mut b = vec![0u8; chunk_size]; + let n = read_with_progress(&mut cursor, &mut b).await.unwrap(); + result.extend_from_slice(&b[..n]); + if n == 0 { break; } + } + + assert_eq!(result, vec![1,2,3,4,5]); +} + +#[tokio::test] +async fn read_with_progress_stress_randomized_chunk_sizes() { + for i in 0..128 { + let mut rng = StdRng::seed_from_u64(i as u64 + 1); + let mut input: Vec = (0..(i % 41)).map(|_| rng.next_u32() as u8).collect(); + let mut cursor = std::io::Cursor::new(input.clone()); + let mut collected = Vec::new(); + + while cursor.position() < cursor.get_ref().len() as u64 { + let chunk = 1 + (rng.next_u32() as usize % 8); + let mut b = vec![0u8; chunk]; + let read = read_with_progress(&mut cursor, &mut b).await.unwrap(); + collected.extend_from_slice(&b[..read]); + if read == 0 { break; } + } + + assert_eq!(collected, input); + } +} + +#[test] +fn is_trusted_proxy_source_boundary_narrow_ipv4() { + let matching = "172.16.0.1".parse().unwrap(); + let not_matching = "172.15.255.255".parse().unwrap(); + let cidr = vec!["172.16.0.0/12".parse().unwrap()]; + assert!(is_trusted_proxy_source(matching, &cidr)); + assert!(!is_trusted_proxy_source(not_matching, &cidr)); +} + +#[test] +fn is_trusted_proxy_source_rejects_out_of_family_ipv6_v4_cidr() { + let peer = "2001:db8::1".parse().unwrap(); + let cidr = vec!["10.0.0.0/8".parse().unwrap()]; + assert!(!is_trusted_proxy_source(peer, &cidr)); +} + +#[test] +fn wrap_tls_application_record_reserved_chunks_look_reasonable() { + let payload = vec![0xAA; 1 + (u16::MAX as usize) + 2]; + let wrapped = wrap_tls_application_record(&payload); + assert!(wrapped.len() > payload.len()); + assert!(wrapped.contains(&0x17)); +} + +#[test] +fn wrap_tls_application_record_roundtrip_size_check() { + let payload_len = 3000; + let payload = vec![0x55; payload_len]; + let wrapped = wrap_tls_application_record(&payload); + + let mut idx = 0; + let mut consumed = 0; + while idx + 5 <= wrapped.len() { + assert_eq!(wrapped[idx], 0x17); + let len = u16::from_be_bytes([wrapped[idx+3], wrapped[idx+4]]) as usize; + consumed += len; + idx += 5 + len; + if idx >= wrapped.len() { break; } + } + + assert_eq!(consumed, payload_len); +} + fn make_crypto_reader(reader: R) -> CryptoReader where R: tokio::io::AsyncRead + Unpin, diff --git a/src/proxy/tests/handshake_advanced_clever_tests.rs b/src/proxy/tests/handshake_advanced_clever_tests.rs new file mode 100644 index 0000000..9b12f21 --- /dev/null +++ b/src/proxy/tests/handshake_advanced_clever_tests.rs @@ -0,0 +1,647 @@ +use super::*; +use crate::crypto::{sha256, sha256_hmac, AesCtr}; +use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +// --- Helpers --- + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg.general.modes.secure = true; + cfg.general.modes.classic = true; + cfg.general.modes.tls = true; + cfg +} + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + let session_id_len: usize = 32; + let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +fn make_valid_tls_client_hello_with_alpn( + secret: &[u8], + timestamp: u32, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + + record +} + +// --- Category 1: Edge Cases & Protocol Boundaries --- + +#[tokio::test] +async fn tls_minimum_viable_length_boundary() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x11u8; 16]; + let config = test_config_with_secret_hex("11111111111111111111111111111111"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap(); + + let min_len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1; + let mut exact_min_handshake = vec![0x42u8; min_len]; + exact_min_handshake[min_len - 1] = 0; + exact_min_handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let digest = sha256_hmac(&secret, &exact_min_handshake); + exact_min_handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + + let res = handle_tls_handshake( + &exact_min_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res, HandshakeResult::Success(_)), "Exact minimum length TLS handshake must succeed"); + + let short_handshake = vec![0x42u8; min_len - 1]; + let res_short = handle_tls_handshake( + &short_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res_short, HandshakeResult::BadClient { .. }), "Handshake 1 byte shorter than minimum must fail closed"); +} + +#[tokio::test] +async fn mtproto_extreme_dc_index_serialization() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "22222222222222222222222222222222"; + let config = test_config_with_secret_hex(secret_hex); + for (idx, extreme_dc) in [i16::MIN, i16::MAX, -1, 0].into_iter().enumerate() { + // Keep replay state independent per case so we validate dc_idx encoding, + // not duplicate-handshake rejection behavior. + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 0, 2, 2)), 12345 + idx as u16); + let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, extreme_dc); + let res = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + match res { + HandshakeResult::Success((_, _, success)) => { + assert_eq!(success.dc_idx, extreme_dc, "Extreme DC index {} must serialize/deserialize perfectly", extreme_dc); + } + _ => panic!("MTProto handshake with extreme DC index {} failed", extreme_dc), + } + } +} + +#[tokio::test] +async fn alpn_strict_case_and_padding_rejection() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x33u8; 16]; + let mut config = test_config_with_secret_hex("33333333333333333333333333333333"); + config.censorship.alpn_enforce = true; + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.3:12345".parse().unwrap(); + + let bad_alpns: &[&[u8]] = &[b"H2", b"h2\0", b" http/1.1", b"http/1.1\n"]; + + for bad_alpn in bad_alpns { + let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[*bad_alpn]); + let res = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res, HandshakeResult::BadClient { .. }), "ALPN strict enforcement must reject {:?}", bad_alpn); + } +} + +#[test] +fn ipv4_mapped_ipv6_bucketing_anomaly() { + let ipv4_mapped_1 = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc000, 0x0201)); + let ipv4_mapped_2 = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc633, 0x6402)); + + let norm_1 = normalize_auth_probe_ip(ipv4_mapped_1); + let norm_2 = normalize_auth_probe_ip(ipv4_mapped_2); + + assert_eq!(norm_1, norm_2, "IPv4-mapped IPv6 addresses must collapse into the same /64 bucket (::0)"); + assert_eq!(norm_1, IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), "The bucket must be exactly ::0"); +} + +// --- Category 2: Adversarial & Black Hat --- + +#[tokio::test] +async fn mtproto_invalid_ciphertext_does_not_poison_replay_cache() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "55555555555555555555555555555555"; + let config = test_config_with_secret_hex(secret_hex); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.5:12345".parse().unwrap(); + + let valid_handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1); + let mut invalid_handshake = valid_handshake; + invalid_handshake[SKIP_LEN + PREKEY_LEN + IV_LEN + 1] ^= 0xFF; + + let res_invalid = handle_mtproto_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(res_invalid, HandshakeResult::BadClient { .. })); + + let res_valid = handle_mtproto_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + assert!(matches!(res_valid, HandshakeResult::Success(_)), "Invalid MTProto ciphertext must not poison the replay cache"); +} + +#[tokio::test] +async fn tls_invalid_session_does_not_poison_replay_cache() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x66u8; 16]; + let config = test_config_with_secret_hex("66666666666666666666666666666666"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.6:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + let mut invalid_handshake = valid_handshake.clone(); + let session_idx = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1; + invalid_handshake[session_idx] ^= 0xFF; + + let res_invalid = handle_tls_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res_invalid, HandshakeResult::BadClient { .. })); + + let res_valid = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res_valid, HandshakeResult::Success(_)), "Invalid TLS payload must not poison the replay cache"); +} + +#[tokio::test] +async fn server_hello_delay_timing_neutrality_on_hmac_failure() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x77u8; 16]; + let mut config = test_config_with_secret_hex("77777777777777777777777777777777"); + config.censorship.server_hello_delay_min_ms = 50; + config.censorship.server_hello_delay_max_ms = 50; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.7:12345".parse().unwrap(); + + let mut invalid_handshake = make_valid_tls_handshake(&secret, 0); + invalid_handshake[tls::TLS_DIGEST_POS] ^= 0xFF; + + let start = Instant::now(); + let res = handle_tls_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let elapsed = start.elapsed(); + + assert!(matches!(res, HandshakeResult::BadClient { .. })); + assert!(elapsed >= Duration::from_millis(45), "Invalid HMAC must still incur the configured ServerHello delay to prevent timing side-channels"); +} + +#[tokio::test] +async fn server_hello_delay_inversion_resilience() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x88u8; 16]; + let mut config = test_config_with_secret_hex("88888888888888888888888888888888"); + config.censorship.server_hello_delay_min_ms = 100; + config.censorship.server_hello_delay_max_ms = 10; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.8:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + + let start = Instant::now(); + let res = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + let elapsed = start.elapsed(); + + assert!(matches!(res, HandshakeResult::Success(_))); + assert!(elapsed >= Duration::from_millis(90), "Delay logic must gracefully handle min > max inversions via max.max(min)"); +} + +#[tokio::test] +async fn mixed_valid_and_invalid_user_secrets_configuration() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + let _warn_guard = warned_secrets_test_lock().lock().unwrap(); + clear_warned_secrets_for_testing(); + + let mut config = ProxyConfig::default(); + config.access.ignore_time_skew = true; + + for i in 0..9 { + let bad_secret = if i % 2 == 0 { "badhex!" } else { "1122" }; + config.access.users.insert(format!("bad_user_{}", i), bad_secret.to_string()); + } + let valid_secret_hex = "99999999999999999999999999999999"; + config.access.users.insert("good_user".to_string(), valid_secret_hex.to_string()); + config.general.modes.secure = true; + config.general.modes.classic = true; + config.general.modes.tls = true; + + let secret = [0x99u8; 16]; + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.9:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + + let res = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::Success(_)), "Proxy must gracefully skip invalid secrets and authenticate the valid one"); +} + +#[tokio::test] +async fn tls_emulation_fallback_when_cache_missing() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0xAAu8; 16]; + let mut config = test_config_with_secret_hex("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + config.censorship.tls_emulation = true; + config.general.modes.tls = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.10:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_handshake(&secret, 0); + + let res = handle_tls_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::Success(_)), "TLS emulation must gracefully fall back to standard ServerHello if cache is missing"); +} + +#[tokio::test] +async fn classic_mode_over_tls_transport_protocol_confusion() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"; + let mut config = test_config_with_secret_hex(secret_hex); + config.general.modes.classic = true; + config.general.modes.tls = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.11:12345".parse().unwrap(); + + let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Intermediate, 1); + + let res = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + true, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::Success(_)), "Intermediate tag over TLS must succeed if classic mode is enabled, locking in cross-transport behavior"); +} + +#[test] +fn generate_tg_nonce_never_emits_reserved_bytes() { + let client_enc_key = [0xCCu8; 32]; + let client_enc_iv = 123456789u128; + let rng = SecureRandom::new(); + + for _ in 0..10_000 { + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 1, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + assert!(!RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]), "Nonce must never start with reserved bytes"); + let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]]; + assert!(!RESERVED_NONCE_BEGINNINGS.contains(&first_four), "Nonce must never match reserved 4-byte beginnings"); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn dashmap_concurrent_saturation_stress() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let ip_a: IpAddr = "192.0.2.13".parse().unwrap(); + let ip_b: IpAddr = "198.51.100.13".parse().unwrap(); + let mut tasks = Vec::new(); + + for i in 0..100 { + let target_ip = if i % 2 == 0 { ip_a } else { ip_b }; + tasks.push(tokio::spawn(async move { + for _ in 0..50 { + auth_probe_record_failure(target_ip, Instant::now()); + } + })); + } + + for task in tasks { + task.await.expect("Task panicked during concurrent DashMap stress"); + } + + assert!(auth_probe_is_throttled_for_testing(ip_a), "IP A must be throttled after concurrent stress"); + assert!(auth_probe_is_throttled_for_testing(ip_b), "IP B must be throttled after concurrent stress"); +} + +#[test] +fn prototag_invalid_bytes_fail_closed() { + let invalid_tags: [[u8; 4]; 5] = [ + [0, 0, 0, 0], + [0xFF, 0xFF, 0xFF, 0xFF], + [0xDE, 0xAD, 0xBE, 0xEF], + [0xDD, 0xDD, 0xDD, 0xDE], + [0x11, 0x22, 0x33, 0x44], + ]; + + for tag in invalid_tags { + assert_eq!(ProtoTag::from_bytes(tag), None, "Invalid ProtoTag bytes {:?} must fail closed", tag); + } +} + +#[test] +fn auth_probe_eviction_hash_collision_stress() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let state = auth_probe_state_map(); + let now = Instant::now(); + + for i in 0..10_000u32 { + let ip = IpAddr::V4(Ipv4Addr::new(10, 0, (i >> 8) as u8, (i & 0xFF) as u8)); + auth_probe_record_failure_with_state(state, ip, now); + } + + assert!(state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, "Eviction logic must successfully bound the map size under heavy insertion stress"); +} + +#[test] +fn encrypt_tg_nonce_with_ciphers_advances_counter_correctly() { + let client_enc_key = [0xDDu8; 32]; + let client_enc_iv = 987654321u128; + let rng = SecureRandom::new(); + + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + let (_, mut returned_encryptor, _) = encrypt_tg_nonce_with_ciphers(&nonce); + let zeros = [0u8; 64]; + let returned_keystream = returned_encryptor.encrypt(&zeros); + + let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN]; + let mut expected_enc_key = [0u8; 32]; + expected_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]); + let mut expected_enc_iv_arr = [0u8; IV_LEN]; + expected_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]); + let expected_enc_iv = u128::from_be_bytes(expected_enc_iv_arr); + + let mut manual_encryptor = AesCtr::new(&expected_enc_key, expected_enc_iv); + + let mut manual_input = Vec::new(); + manual_input.extend_from_slice(&nonce); + manual_input.extend_from_slice(&zeros); + let manual_output = manual_encryptor.encrypt(&manual_input); + + assert_eq!( + returned_keystream, + &manual_output[64..128], + "encrypt_tg_nonce_with_ciphers must correctly advance the AES-CTR counter by exactly the nonce length" + ); +} diff --git a/src/proxy/tests/handshake_more_clever_tests.rs b/src/proxy/tests/handshake_more_clever_tests.rs new file mode 100644 index 0000000..77df442 --- /dev/null +++ b/src/proxy/tests/handshake_more_clever_tests.rs @@ -0,0 +1,614 @@ +use super::*; +use crate::crypto::{sha256, sha256_hmac, AesCtr}; +use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES}; +use rand::{Rng, SeedableRng}; +use rand::rngs::StdRng; +use std::collections::HashSet; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::Barrier; + +// --- Helpers --- + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg.general.modes.secure = true; + cfg.general.modes.classic = true; + cfg.general.modes.tls = true; + cfg +} + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + let session_id_len: usize = 32; + let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +fn make_valid_tls_client_hello_with_sni_and_alpn( + secret: &[u8], + timestamp: u32, + sni_host: &str, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + + let host_bytes = sni_host.as_bytes(); + let mut sni_payload = Vec::new(); + sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes()); + sni_payload.push(0); + sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_payload.extend_from_slice(host_bytes); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_payload); + + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + + record +} + +// --- Category 1: Timing & Delay Invariants --- + +#[tokio::test] +async fn server_hello_delay_bypassed_if_max_is_zero_despite_high_min() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x1Au8; 16]; + let mut config = test_config_with_secret_hex("1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a"); + config.censorship.server_hello_delay_min_ms = 5000; + config.censorship.server_hello_delay_max_ms = 0; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.101:12345".parse().unwrap(); + + let mut invalid_handshake = make_valid_tls_handshake(&secret, 0); + invalid_handshake[tls::TLS_DIGEST_POS] ^= 0xFF; + + let fut = handle_tls_handshake( + &invalid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ); + + // Deterministic assertion: with max_ms == 0 there must be no sleep path, + // so the handshake should complete promptly under a generous timeout budget. + let res = tokio::time::timeout(Duration::from_millis(250), fut) + .await + .expect("max_ms=0 should bypass artificial delay and complete quickly"); + + assert!(matches!(res, HandshakeResult::BadClient { .. })); +} + +#[test] +fn auth_probe_backoff_extreme_fail_streak_clamps_safely() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let state = auth_probe_state_map(); + let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 99)); + let now = Instant::now(); + + state.insert( + peer_ip, + AuthProbeState { + fail_streak: u32::MAX - 1, + blocked_until: now, + last_seen: now, + }, + ); + + auth_probe_record_failure_with_state(&state, peer_ip, now); + + let updated = state.get(&peer_ip).unwrap(); + assert_eq!(updated.fail_streak, u32::MAX); + + let expected_blocked_until = now + Duration::from_millis(AUTH_PROBE_BACKOFF_MAX_MS); + assert_eq!(updated.blocked_until, expected_blocked_until, "Extreme fail streak must clamp cleanly to AUTH_PROBE_BACKOFF_MAX_MS"); +} + +#[test] +fn generate_tg_nonce_cryptographic_uniqueness_and_entropy() { + let client_enc_key = [0x2Bu8; 32]; + let client_enc_iv = 1337u128; + let rng = SecureRandom::new(); + + let mut nonces = HashSet::new(); + let mut total_set_bits = 0usize; + let iterations = 5_000; + + for _ in 0..iterations { + let (nonce, _, _, _, _) = generate_tg_nonce( + ProtoTag::Secure, + 2, + &client_enc_key, + client_enc_iv, + &rng, + false, + ); + + for byte in nonce.iter() { + total_set_bits += byte.count_ones() as usize; + } + + assert!(nonces.insert(nonce), "generate_tg_nonce emitted a duplicate nonce! RNG is stuck."); + } + + let total_bits = iterations * HANDSHAKE_LEN * 8; + let ratio = (total_set_bits as f64) / (total_bits as f64); + assert!(ratio > 0.48 && ratio < 0.52, "Nonce entropy is degraded. Set bit ratio: {}", ratio); +} + +#[tokio::test] +async fn mtproto_multi_user_decryption_isolation() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let mut config = ProxyConfig::default(); + config.general.modes.secure = true; + config.access.ignore_time_skew = true; + + config.access.users.insert("user_a".to_string(), "11111111111111111111111111111111".to_string()); + config.access.users.insert("user_b".to_string(), "22222222222222222222222222222222".to_string()); + let good_secret_hex = "33333333333333333333333333333333"; + config.access.users.insert("user_c".to_string(), good_secret_hex.to_string()); + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.104:12345".parse().unwrap(); + + let valid_handshake = make_valid_mtproto_handshake(good_secret_hex, ProtoTag::Secure, 1); + + let res = handle_mtproto_handshake( + &valid_handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + match res { + HandshakeResult::Success((_, _, success)) => { + assert_eq!(success.user, "user_c", "Decryption attempts on previous users must not corrupt the handshake buffer for the valid user"); + } + _ => panic!("Multi-user MTProto handshake failed. Decryption buffer might be mutating in place."), + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn invalid_secret_warning_lock_contention_and_bound() { + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_warned_secrets_for_testing(); + + let tasks = 50; + let iterations_per_task = 100; + let barrier = Arc::new(Barrier::new(tasks)); + let mut handles = Vec::new(); + + for t in 0..tasks { + let b = barrier.clone(); + handles.push(tokio::spawn(async move { + b.wait().await; + for i in 0..iterations_per_task { + let user_name = format!("contention_user_{}_{}", t, i); + warn_invalid_secret_once(&user_name, "invalid_hex", ACCESS_SECRET_BYTES, None); + } + })); + } + + for handle in handles { + handle.await.unwrap(); + } + + let warned = INVALID_SECRET_WARNED.get().unwrap(); + let guard = warned.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + + assert_eq!( + guard.len(), + WARNED_SECRET_MAX_ENTRIES, + "Concurrent spam of invalid secrets must strictly bound the HashSet memory to WARNED_SECRET_MAX_ENTRIES" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn mtproto_strict_concurrent_replay_race_condition() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret_hex = "4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A"; + let config = Arc::new(test_config_with_secret_hex(secret_hex)); + let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); + let valid_handshake = Arc::new(make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1)); + + let tasks = 100; + let barrier = Arc::new(Barrier::new(tasks)); + let mut handles = Vec::new(); + + for i in 0..tasks { + let b = barrier.clone(); + let cfg = config.clone(); + let rc = replay_checker.clone(); + let hs = valid_handshake.clone(); + + handles.push(tokio::spawn(async move { + let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 250) as u8)), 10000 + i as u16); + b.wait().await; + handle_mtproto_handshake( + &hs, + tokio::io::empty(), + tokio::io::sink(), + peer, + &cfg, + &rc, + false, + None, + ) + .await + })); + } + + let mut successes = 0; + let mut failures = 0; + + for handle in handles { + match handle.await.unwrap() { + HandshakeResult::Success(_) => successes += 1, + HandshakeResult::BadClient { .. } => failures += 1, + _ => panic!("Unexpected error result in concurrent MTProto replay test"), + } + } + + assert_eq!(successes, 1, "Replay cache race condition allowed multiple identical MTProto handshakes to succeed"); + assert_eq!(failures, tasks - 1, "Replay cache failed to forcefully reject concurrent duplicates"); +} + +#[tokio::test] +async fn tls_alpn_zero_length_protocol_handled_safely() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x5Bu8; 16]; + let mut config = test_config_with_secret_hex("5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b"); + config.censorship.alpn_enforce = true; + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.107:12345".parse().unwrap(); + + let handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b""]); + + let res = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::BadClient { .. }), "0-length ALPN must be safely rejected without panicking"); +} + +#[tokio::test] +async fn tls_sni_massive_hostname_does_not_panic() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x6Cu8; 16]; + let config = test_config_with_secret_hex("6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.108:12345".parse().unwrap(); + + let massive_hostname = String::from_utf8(vec![b'a'; 65000]).unwrap(); + let handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, &massive_hostname, &[]); + + let res = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::Success(_) | HandshakeResult::BadClient { .. }), "Massive SNI hostname must be processed or ignored without stack overflow or panic"); +} + +#[tokio::test] +async fn tls_progressive_truncation_fuzzing_no_panics() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x7Du8; 16]; + let config = test_config_with_secret_hex("7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.109:12345".parse().unwrap(); + + let valid_handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b"h2"]); + let full_len = valid_handshake.len(); + + // Truncated corpus only: full_len is a valid baseline and should not be + // asserted as BadClient in a truncation-specific test. + for i in (0..full_len).rev() { + let truncated = &valid_handshake[..i]; + let res = handle_tls_handshake( + truncated, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + assert!(matches!(res, HandshakeResult::BadClient { .. }), "Truncated TLS handshake at len {} must fail safely without panicking", i); + } +} + +#[tokio::test] +async fn mtproto_pure_entropy_fuzzing_no_panics() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.110:12345".parse().unwrap(); + + let mut seeded = StdRng::seed_from_u64(0xDEADBEEFCAFE); + + for _ in 0..10_000 { + let mut noise = [0u8; HANDSHAKE_LEN]; + seeded.fill_bytes(&mut noise); + + let res = handle_mtproto_handshake( + &noise, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::BadClient { .. }), "Pure entropy MTProto payload must fail closed and never panic"); + } +} + +#[test] +fn decode_user_secret_odd_length_hex_rejection() { + let _guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_warned_secrets_for_testing(); + + let mut config = ProxyConfig::default(); + config.access.users.clear(); + config.access.users.insert("odd_user".to_string(), "1234567890123456789012345678901".to_string()); + + let decoded = decode_user_secrets(&config, None); + assert!(decoded.is_empty(), "Odd-length hex string must be gracefully rejected by hex::decode without unwrapping"); +} + +#[test] +fn saturation_grace_pre_existing_high_fail_streak_immediate_throttle() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let state = auth_probe_state_map(); + let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 112)); + let now = Instant::now(); + + let extreme_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS + 5; + state.insert( + peer_ip, + AuthProbeState { + fail_streak: extreme_streak, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }, + ); + + { + let mut guard = auth_probe_saturation_state_lock(); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let is_throttled = auth_probe_should_apply_preauth_throttle(peer_ip, now); + assert!(is_throttled, "A peer with a pre-existing high fail streak must be immediately throttled when saturation begins, receiving no unearned grace period"); +} + +#[test] +fn auth_probe_saturation_note_resets_retention_window() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let base_time = Instant::now(); + + auth_probe_note_saturation(base_time); + let later = base_time + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS - 1); + auth_probe_note_saturation(later); + + let check_time = base_time + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 5); + + // This call may return false if backoff has elapsed, but it must not clear + // the saturation state because `later` refreshed last_seen. + let _ = auth_probe_saturation_is_throttled_at_for_testing(check_time); + let guard = auth_probe_saturation_state_lock(); + assert!( + guard.is_some(), + "Ongoing saturation notes must refresh last_seen so saturation state remains retained past the original window" + ); +} + +#[test] +fn mtproto_classic_tags_rejected_when_only_secure_mode_enabled() { + let mut config = ProxyConfig::default(); + config.general.modes.classic = false; + config.general.modes.secure = true; + config.general.modes.tls = false; + + assert!(!mode_enabled_for_proto(&config, ProtoTag::Abridged, false)); + assert!(!mode_enabled_for_proto(&config, ProtoTag::Intermediate, false)); +} + +#[test] +fn mtproto_secure_tag_rejected_when_only_classic_mode_enabled() { + let mut config = ProxyConfig::default(); + config.general.modes.classic = true; + config.general.modes.secure = false; + config.general.modes.tls = false; + + assert!(!mode_enabled_for_proto(&config, ProtoTag::Secure, false)); +} + +#[test] +fn ipv6_localhost_and_unspecified_normalization() { + let localhost = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)); + let unspecified = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)); + + let norm_local = normalize_auth_probe_ip(localhost); + let norm_unspec = normalize_auth_probe_ip(unspecified); + + let expected_bucket = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)); + + assert_eq!(norm_local, expected_bucket); + assert_eq!(norm_unspec, expected_bucket); +} diff --git a/src/proxy/tests/handshake_real_bug_stress_tests.rs b/src/proxy/tests/handshake_real_bug_stress_tests.rs new file mode 100644 index 0000000..d7234ff --- /dev/null +++ b/src/proxy/tests/handshake_real_bug_stress_tests.rs @@ -0,0 +1,337 @@ +use super::*; +use crate::crypto::{sha256, sha256_hmac, AesCtr, SecureRandom}; +use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::sync::Barrier; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { + let mut cfg = ProxyConfig::default(); + cfg.access.users.clear(); + cfg.access + .users + .insert("user".to_string(), secret_hex.to_string()); + cfg.access.ignore_time_skew = true; + cfg.general.modes.secure = true; + cfg.general.modes.classic = true; + cfg.general.modes.tls = true; + cfg +} + +fn make_valid_tls_client_hello_with_alpn( + secret: &[u8], + timestamp: u32, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + + record +} + +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +#[tokio::test] +async fn tls_alpn_reject_does_not_pollute_replay_cache() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let secret = [0x11u8; 16]; + let mut config = test_config_with_secret_hex("11111111111111111111111111111111"); + config.censorship.alpn_enforce = true; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.201:12345".parse().unwrap(); + + let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]); + let before = replay_checker.stats(); + + let res = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + let after = replay_checker.stats(); + + assert!(matches!(res, HandshakeResult::BadClient { .. })); + assert_eq!( + before.total_additions, after.total_additions, + "ALPN policy reject must not add TLS digest into replay cache" + ); +} + +#[tokio::test] +async fn tls_truncated_session_id_len_fails_closed_without_panic() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let config = test_config_with_secret_hex("33333333333333333333333333333333"); + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "192.0.2.203:12345".parse().unwrap(); + + let min_len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1; + let mut malicious = vec![0x42u8; min_len]; + malicious[min_len - 1] = u8::MAX; + + let res = handle_tls_handshake( + &malicious, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::BadClient { .. })); +} + +#[test] +fn auth_probe_eviction_identical_timestamps_keeps_map_bounded() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let state = auth_probe_state_map(); + let same = Instant::now(); + + for i in 0..AUTH_PROBE_TRACK_MAX_ENTRIES { + let ip = IpAddr::V4(Ipv4Addr::new(10, 1, (i >> 8) as u8, (i & 0xFF) as u8)); + state.insert( + ip, + AuthProbeState { + fail_streak: 7, + blocked_until: same, + last_seen: same, + }, + ); + } + + let new_ip = IpAddr::V4(Ipv4Addr::new(192, 168, 21, 21)); + auth_probe_record_failure_with_state(state, new_ip, same + Duration::from_millis(1)); + + assert_eq!(state.len(), AUTH_PROBE_TRACK_MAX_ENTRIES); + assert!(state.contains_key(&new_ip)); +} + +#[test] +fn clear_auth_probe_state_recovers_from_poisoned_saturation_lock() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let saturation = auth_probe_saturation_state(); + let poison_thread = std::thread::spawn(move || { + let _hold = saturation + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + panic!("intentional poison for regression coverage"); + }); + let _ = poison_thread.join(); + + clear_auth_probe_state_for_testing(); + + let guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + assert!(guard.is_none()); +} + +#[tokio::test] +async fn mtproto_invalid_length_secret_is_ignored_and_valid_user_still_auths() { + let _probe_guard = auth_probe_test_guard(); + let _warn_guard = warned_secrets_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + clear_auth_probe_state_for_testing(); + clear_warned_secrets_for_testing(); + + let mut config = ProxyConfig::default(); + config.general.modes.secure = true; + config.access.ignore_time_skew = true; + + config.access.users.insert( + "short_user".to_string(), + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string(), + ); + + let valid_secret_hex = "77777777777777777777777777777777"; + config + .access + .users + .insert("good_user".to_string(), valid_secret_hex.to_string()); + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let peer: SocketAddr = "192.0.2.207:12345".parse().unwrap(); + let handshake = make_valid_mtproto_handshake(valid_secret_hex, ProtoTag::Secure, 1); + + let res = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + false, + None, + ) + .await; + + assert!(matches!(res, HandshakeResult::Success(_))); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn saturation_grace_exhaustion_under_concurrency_keeps_peer_throttled() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 80)); + let now = Instant::now(); + + { + let mut guard = auth_probe_saturation_state() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); + *guard = Some(AuthProbeSaturationState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS, + blocked_until: now + Duration::from_secs(5), + last_seen: now, + }); + } + + let state = auth_probe_state_map(); + state.insert( + peer_ip, + AuthProbeState { + fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS - 1, + blocked_until: now, + last_seen: now, + }, + ); + + let tasks = 32; + let barrier = Arc::new(Barrier::new(tasks)); + let mut handles = Vec::new(); + + for _ in 0..tasks { + let b = barrier.clone(); + handles.push(tokio::spawn(async move { + b.wait().await; + auth_probe_record_failure(peer_ip, Instant::now()); + })); + } + + for handle in handles { + handle.await.unwrap(); + } + + let final_state = state.get(&peer_ip).expect("state must exist"); + assert!( + final_state.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS + ); + assert!(auth_probe_should_apply_preauth_throttle(peer_ip, Instant::now())); +} diff --git a/src/proxy/tests/handshake_timing_manual_bench_tests.rs b/src/proxy/tests/handshake_timing_manual_bench_tests.rs new file mode 100644 index 0000000..95e9f49 --- /dev/null +++ b/src/proxy/tests/handshake_timing_manual_bench_tests.rs @@ -0,0 +1,318 @@ +use super::*; +use crate::crypto::{sha256, sha256_hmac, AesCtr, SecureRandom}; +use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION}; +use std::net::SocketAddr; +use std::time::{Duration, Instant}; + +fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> { + auth_probe_test_lock() + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) +} + +fn make_valid_mtproto_handshake( + secret_hex: &str, + proto_tag: ProtoTag, + dc_idx: i16, + salt: u8, +) -> [u8; HANDSHAKE_LEN] { + let secret = hex::decode(secret_hex).expect("secret hex must decode"); + let mut handshake = [0x5Au8; HANDSHAKE_LEN]; + + for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN] + .iter_mut() + .enumerate() + { + *b = (idx as u8).wrapping_add(1).wrapping_add(salt); + } + + let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN]; + let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; + + let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len()); + dec_key_input.extend_from_slice(dec_prekey); + dec_key_input.extend_from_slice(&secret); + let dec_key = sha256(&dec_key_input); + + let mut dec_iv_arr = [0u8; IV_LEN]; + dec_iv_arr.copy_from_slice(dec_iv_bytes); + let dec_iv = u128::from_be_bytes(dec_iv_arr); + + let mut stream = AesCtr::new(&dec_key, dec_iv); + let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]); + + let mut target_plain = [0u8; HANDSHAKE_LEN]; + target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes()); + target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes()); + + for idx in PROTO_TAG_POS..HANDSHAKE_LEN { + handshake[idx] = target_plain[idx] ^ keystream[idx]; + } + + handshake +} + +fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec { + let session_id_len: usize = 32; + let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len; + let mut handshake = vec![0x42u8; len]; + + handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8; + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + + let computed = sha256_hmac(secret, &handshake); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + + handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] + .copy_from_slice(&digest); + handshake +} + +fn make_valid_tls_client_hello_with_sni_and_alpn( + secret: &[u8], + timestamp: u32, + sni_host: &str, + alpn_protocols: &[&[u8]], +) -> Vec { + let mut body = Vec::new(); + body.extend_from_slice(&TLS_VERSION); + body.extend_from_slice(&[0u8; 32]); + body.push(32); + body.extend_from_slice(&[0x42u8; 32]); + body.extend_from_slice(&2u16.to_be_bytes()); + body.extend_from_slice(&[0x13, 0x01]); + body.push(1); + body.push(0); + + let mut ext_blob = Vec::new(); + let host_bytes = sni_host.as_bytes(); + let mut sni_payload = Vec::new(); + sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes()); + sni_payload.push(0); + sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes()); + sni_payload.extend_from_slice(host_bytes); + ext_blob.extend_from_slice(&0x0000u16.to_be_bytes()); + ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&sni_payload); + + if !alpn_protocols.is_empty() { + let mut alpn_list = Vec::new(); + for proto in alpn_protocols { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + let mut alpn_data = Vec::new(); + alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + alpn_data.extend_from_slice(&alpn_list); + + ext_blob.extend_from_slice(&0x0010u16.to_be_bytes()); + ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes()); + ext_blob.extend_from_slice(&alpn_data); + } + + body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes()); + body.extend_from_slice(&ext_blob); + + let mut handshake = Vec::new(); + handshake.push(0x01); + let body_len = (body.len() as u32).to_be_bytes(); + handshake.extend_from_slice(&body_len[1..4]); + handshake.extend_from_slice(&body); + + let mut record = Vec::new(); + record.push(TLS_RECORD_HANDSHAKE); + record.extend_from_slice(&[0x03, 0x01]); + record.extend_from_slice(&(handshake.len() as u16).to_be_bytes()); + record.extend_from_slice(&handshake); + + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0); + let computed = sha256_hmac(secret, &record); + let mut digest = computed; + let ts = timestamp.to_le_bytes(); + for i in 0..4 { + digest[28 + i] ^= ts[i]; + } + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); + + record +} + +fn median_ns(samples: &mut [u128]) -> u128 { + samples.sort_unstable(); + samples[samples.len() / 2] +} + +#[tokio::test] +#[ignore = "manual benchmark: timing-sensitive and host-dependent"] +async fn mtproto_user_scan_timing_manual_benchmark() { + let _guard = auth_probe_test_guard(); + clear_auth_probe_state_for_testing(); + + const DECOY_USERS: usize = 8_000; + const ITERATIONS: usize = 250; + + let preferred_user = "target_user"; + let target_secret_hex = "dededededededededededededededede"; + + let mut config = ProxyConfig::default(); + config.general.modes.secure = true; + config.access.ignore_time_skew = true; + + for i in 0..DECOY_USERS { + config.access.users.insert( + format!("decoy_{i}"), + "00000000000000000000000000000000".to_string(), + ); + } + + config.access.users.insert( + preferred_user.to_string(), + target_secret_hex.to_string(), + ); + + let replay_checker_preferred = ReplayChecker::new(65_536, Duration::from_secs(60)); + let replay_checker_full_scan = ReplayChecker::new(65_536, Duration::from_secs(60)); + let peer_a: SocketAddr = "192.0.2.241:12345".parse().unwrap(); + let peer_b: SocketAddr = "192.0.2.242:12345".parse().unwrap(); + + let mut preferred_samples = Vec::with_capacity(ITERATIONS); + let mut full_scan_samples = Vec::with_capacity(ITERATIONS); + + for i in 0..ITERATIONS { + let handshake = make_valid_mtproto_handshake( + target_secret_hex, + ProtoTag::Secure, + 1 + i as i16, + (i % 251) as u8, + ); + + let started_preferred = Instant::now(); + let preferred = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer_a, + &config, + &replay_checker_preferred, + false, + Some(preferred_user), + ) + .await; + preferred_samples.push(started_preferred.elapsed().as_nanos()); + assert!(matches!(preferred, HandshakeResult::Success(_))); + + let started_scan = Instant::now(); + let full_scan = handle_mtproto_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer_b, + &config, + &replay_checker_full_scan, + false, + None, + ) + .await; + full_scan_samples.push(started_scan.elapsed().as_nanos()); + assert!(matches!(full_scan, HandshakeResult::Success(_))); + } + + let preferred_median = median_ns(&mut preferred_samples); + let full_scan_median = median_ns(&mut full_scan_samples); + + let ratio = if preferred_median == 0 { + 0.0 + } else { + full_scan_median as f64 / preferred_median as f64 + }; + + println!( + "manual timing benchmark: decoys={DECOY_USERS}, iters={ITERATIONS}, preferred_median_ns={preferred_median}, full_scan_median_ns={full_scan_median}, ratio={ratio:.3}" + ); + + assert!( + full_scan_median >= preferred_median, + "full user scan should not be faster than preferred-user path in this benchmark" + ); +} + +#[tokio::test] +#[ignore = "manual benchmark: timing-sensitive and host-dependent"] +async fn tls_sni_preferred_vs_no_sni_fallback_manual_benchmark() { + let _guard = auth_probe_test_guard(); + + const DECOY_USERS: usize = 8_000; + const ITERATIONS: usize = 250; + + let preferred_user = "user-b"; + let target_secret_hex = "abababababababababababababababab"; + let target_secret = [0xABu8; 16]; + + let mut config = ProxyConfig::default(); + config.general.modes.tls = true; + config.access.ignore_time_skew = true; + + for i in 0..DECOY_USERS { + config.access.users.insert( + format!("decoy_{i}"), + "00000000000000000000000000000000".to_string(), + ); + } + + config + .access + .users + .insert(preferred_user.to_string(), target_secret_hex.to_string()); + + let mut sni_samples = Vec::with_capacity(ITERATIONS); + let mut no_sni_samples = Vec::with_capacity(ITERATIONS); + + for i in 0..ITERATIONS { + let with_sni = make_valid_tls_client_hello_with_sni_and_alpn( + &target_secret, + i as u32, + preferred_user, + &[b"h2"], + ); + let no_sni = make_valid_tls_handshake(&target_secret, (i as u32).wrapping_add(10_000)); + + let started_sni = Instant::now(); + let sni_secrets = decode_user_secrets(&config, Some(preferred_user)); + let sni_result = tls::validate_tls_handshake_with_replay_window( + &with_sni, + &sni_secrets, + config.access.ignore_time_skew, + config.access.replay_window_secs, + ); + sni_samples.push(started_sni.elapsed().as_nanos()); + assert!(sni_result.is_some()); + + let started_no_sni = Instant::now(); + let no_sni_secrets = decode_user_secrets(&config, None); + let no_sni_result = tls::validate_tls_handshake_with_replay_window( + &no_sni, + &no_sni_secrets, + config.access.ignore_time_skew, + config.access.replay_window_secs, + ); + no_sni_samples.push(started_no_sni.elapsed().as_nanos()); + assert!(no_sni_result.is_some()); + } + + let sni_median = median_ns(&mut sni_samples); + let no_sni_median = median_ns(&mut no_sni_samples); + + let ratio = if sni_median == 0 { + 0.0 + } else { + no_sni_median as f64 / sni_median as f64 + }; + + println!( + "manual tls benchmark: decoys={DECOY_USERS}, iters={ITERATIONS}, sni_median_ns={sni_median}, no_sni_median_ns={no_sni_median}, ratio_no_sni_over_sni={ratio:.3}" + ); +}