From 81ae483201d89c9fcf9b8744cdef83e09ac2b0b0 Mon Sep 17 00:00:00 2001 From: Alexey <247128645+axkurcom@users.noreply.github.com> Date: Tue, 30 Jun 2026 13:13:11 +0300 Subject: [PATCH] Add regression coverage for ME routing, D2C padding, synlimit, and MSS bulk validation Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com> --- src/config/load.rs | 4 + src/config/tests/load_basic_tests.rs | 27 ++++ src/proxy/middle_relay.rs | 8 +- ..._relay_d2c_flush_padding_security_tests.rs | 148 ++++++++++++++++++ src/synlimit_control/iptables.rs | 34 ++++ src/synlimit_control/model.rs | 108 +++++++++++++ src/synlimit_control/nftables.rs | 33 ++++ src/transport/middle_proxy/handshake.rs | 109 +++++++++++++ .../tests/send_adversarial_tests.rs | 60 +++++++ 9 files changed, 530 insertions(+), 1 deletion(-) create mode 100644 src/proxy/tests/middle_relay_d2c_flush_padding_security_tests.rs diff --git a/src/config/load.rs b/src/config/load.rs index 21932ec..20a571a 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -958,6 +958,10 @@ impl ProxyConfig { .server .client_mss_value() .map_err(|error| ProxyError::Config(format!("server.client_mss {error}")))?; + config + .server + .client_mss_bulk_value() + .map_err(|error| ProxyError::Config(format!("server.client_mss_bulk {error}")))?; for (idx, listener) in config.server.listeners.iter().enumerate() { if listener.client_mss.is_some() { listener diff --git a/src/config/tests/load_basic_tests.rs b/src/config/tests/load_basic_tests.rs index 4e46bfe..bb33770 100644 --- a/src/config/tests/load_basic_tests.rs +++ b/src/config/tests/load_basic_tests.rs @@ -1652,6 +1652,7 @@ fn client_mss_custom_value_is_accepted() { let toml = r#" [server] client_mss = "4096" + client_mss_bulk = "1400" [censorship] tls_domain = "example.com" @@ -1665,6 +1666,7 @@ fn client_mss_custom_value_is_accepted() { let cfg = ProxyConfig::load(&path).unwrap(); assert_eq!(cfg.server.client_mss_value(), Ok(Some(4096))); + assert_eq!(cfg.server.client_mss_bulk_value(), Ok(Some(1400))); let _ = std::fs::remove_file(path); } @@ -1693,6 +1695,31 @@ fn client_mss_out_of_range_is_rejected() { } } +#[test] +fn client_mss_bulk_out_of_range_is_rejected() { + for value in ["87", "4097"] { + let toml = format!( + r#" + [server] + client_mss_bulk = "{value}" + + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "# + ); + let dir = std::env::temp_dir(); + let path = dir.join(format!("telemt_client_mss_bulk_out_of_range_{value}_test.toml")); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + + assert!(err.contains("server.client_mss_bulk custom value must be within [88, 4096]")); + let _ = std::fs::remove_file(path); + } +} + #[test] fn client_mss_unquoted_number_is_rejected() { let toml = r#" diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 060d21e..f59bb6b 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -69,7 +69,9 @@ use self::quota::{ #[cfg(test)] use self::c2me::enqueue_c2me_command; #[cfg(test)] -use self::d2c::{compute_intermediate_secure_wire_len, process_me_writer_response}; +use self::d2c::{ + compute_intermediate_secure_wire_len, process_me_writer_response, write_client_payload, +}; #[cfg(test)] pub(crate) use self::desync::{ clear_desync_dedup_for_testing_in_shared, desync_dedup_get_for_testing, @@ -166,3 +168,7 @@ mod middle_relay_atomic_quota_invariant_tests; #[cfg(test)] #[path = "tests/middle_relay_baseline_invariant_tests.rs"] mod middle_relay_baseline_invariant_tests; + +#[cfg(test)] +#[path = "tests/middle_relay_d2c_flush_padding_security_tests.rs"] +mod middle_relay_d2c_flush_padding_security_tests; diff --git a/src/proxy/tests/middle_relay_d2c_flush_padding_security_tests.rs b/src/proxy/tests/middle_relay_d2c_flush_padding_security_tests.rs new file mode 100644 index 0000000..12a0ca7 --- /dev/null +++ b/src/proxy/tests/middle_relay_d2c_flush_padding_security_tests.rs @@ -0,0 +1,148 @@ +use std::io; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; + +use tokio::io::AsyncWrite; + +use super::*; +use crate::crypto::AesCtr; +use crate::protocol::framing::INTERMEDIATE_WIRE_LEN_MASK; + +#[derive(Clone, Default)] +struct RecordingWriter { + writes: Arc>>, + flushes: Arc, +} + +impl RecordingWriter { + fn captured(&self) -> Vec { + self.writes + .lock() + .expect("test writer capture lock must not be poisoned") + .clone() + } +} + +impl AsyncWrite for RecordingWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.writes + .lock() + .expect("test writer capture lock must not be poisoned") + .extend_from_slice(buf); + Poll::Ready(Ok(buf.len())) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + self.flushes.fetch_add(1, Ordering::Relaxed); + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +fn crypto_writer(inner: RecordingWriter) -> CryptoWriter { + let key = [0u8; 32]; + CryptoWriter::new(inner, AesCtr::new(&key, 0), 8 * 1024 * 1024) +} + +fn decrypt_capture(mut encrypted: Vec) -> Vec { + let key = [0u8; 32]; + let mut cipher = AesCtr::new(&key, 0); + cipher.apply(&mut encrypted); + encrypted +} + +fn secure_wire_len(cleartext: &[u8]) -> usize { + let header = cleartext + .get(..4) + .expect("secure frame must include an intermediate header"); + (u32::from_le_bytes( + header + .try_into() + .expect("secure frame header must be four bytes"), + ) & INTERMEDIATE_WIRE_LEN_MASK) as usize +} + +async fn write_secure_payload(payload_len: usize) -> (MeD2cWriteMode, Vec) { + let inner = RecordingWriter::default(); + let capture = inner.clone(); + let mut writer = crypto_writer(inner); + let payload = vec![0xa5; payload_len]; + let mut frame_buf = Vec::new(); + let cancel = CancellationToken::new(); + let rng = SecureRandom::new(); + + let mode = write_client_payload( + &mut writer, + ProtoTag::Secure, + 0, + &payload, + &rng, + &mut frame_buf, + &cancel, + ) + .await + .expect("secure payload write must succeed"); + flush_client_or_cancel(&mut writer, &cancel) + .await + .expect("secure payload flush must succeed"); + + (mode, decrypt_capture(capture.captured())) +} + +fn assert_secure_payload_with_tail_padding(cleartext: &[u8], payload_len: usize) { + let wire_len = secure_wire_len(cleartext); + assert_eq!(cleartext.len(), 4 + wire_len); + assert!(cleartext[4..4 + payload_len] + .iter() + .all(|byte| *byte == 0xa5)); + + let padding_len = wire_len + .checked_sub(payload_len) + .expect("secure wire length must include payload bytes"); + assert!((1..=3).contains(&padding_len)); + assert_ne!(wire_len % 4, 0); +} + +#[tokio::test] +async fn queue_drain_flush_reason_performs_physical_client_flush() { + let inner = RecordingWriter::default(); + let flushes = inner.flushes.clone(); + let mut writer = crypto_writer(inner); + let cancel = CancellationToken::new(); + + assert!(me_d2c_flush_reason_requires_client_flush( + MeD2cFlushReason::QueueDrain + )); + flush_client_or_cancel(&mut writer, &cancel) + .await + .expect("client flush must succeed"); + + assert_eq!(flushes.load(Ordering::Relaxed), 1); +} + +#[tokio::test] +async fn secure_payload_coalesced_path_keeps_tail_padding() { + let payload_len = 8; + let (mode, cleartext) = write_secure_payload(payload_len).await; + + assert!(matches!(mode, MeD2cWriteMode::Coalesced)); + assert_secure_payload_with_tail_padding(&cleartext, payload_len); +} + +#[tokio::test] +async fn secure_payload_split_path_keeps_tail_padding() { + let payload_len = ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES; + let (mode, cleartext) = write_secure_payload(payload_len).await; + + assert!(matches!(mode, MeD2cWriteMode::Split)); + assert_secure_payload_with_tail_padding(&cleartext, payload_len); +} diff --git a/src/synlimit_control/iptables.rs b/src/synlimit_control/iptables.rs index 21b289f..a1bf7ce 100644 --- a/src/synlimit_control/iptables.rs +++ b/src/synlimit_control/iptables.rs @@ -292,6 +292,10 @@ mod tests { .any(|pair| pair[0].as_str() == key && pair[1].as_str() == value) } + fn has_key(args: &[String], key: &str) -> bool { + args.iter().any(|arg| arg == key) + } + #[test] fn iptables_rules_use_synfix_order_and_rejects() { let target = test_rule(Some(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 7))), 443); @@ -320,4 +324,34 @@ mod tests { assert!(has_pair(&rules[0], "--hl-lt", "65")); assert!(has_pair(&rules[0], "-d", "::1")); } + + #[test] + fn iptables_missing_rule_errors_are_cleanup_benign() { + assert!(is_missing_command_or_iptables_rule( + "iptables is not available" + )); + assert!(is_missing_command_or_iptables_rule( + "iptables: No chain/target/match by that name." + )); + assert!(is_missing_command_or_iptables_rule( + "iptables: Chain TELEMT_SYNLIMIT does not exist." + )); + assert!(is_missing_command_or_iptables_rule( + "Couldn't load target `TELEMT_SYNLIMIT': No such file or directory" + )); + assert!(!is_missing_command_or_iptables_rule( + "iptables: Permission denied" + )); + } + + #[test] + fn iptables_wildcard_rule_omits_destination_match() { + let target = test_rule(None, 443); + let rules = iptables_synfix_rule_args(&target, 0, IpTablesFamily::V4); + + for rule in rules { + assert!(!has_key(&rule, "-d")); + assert!(has_pair(&rule, "--dport", "443")); + } + } } diff --git a/src/synlimit_control/model.rs b/src/synlimit_control/model.rs index 3678bb8..ced8ec5 100644 --- a/src/synlimit_control/model.rs +++ b/src/synlimit_control/model.rs @@ -124,3 +124,111 @@ pub(super) fn test_rule(ip: Option, port: u16) -> SynLimitRule { hashlimit_size: 32_768, } } + +#[cfg(test)] +mod tests { + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + + use super::*; + use crate::config::ListenerConfig; + + fn listener(ip: IpAddr, port: Option, synlimit: SynLimitMode) -> ListenerConfig { + ListenerConfig { + ip, + port, + client_mss: None, + synlimit, + synlimit_seconds: 60, + synlimit_hitcount: 48, + synlimit_burst: 1, + synlimit_ios_seconds: 1, + synlimit_ios_hitcount: 12, + synlimit_ios_burst: 24, + synlimit_hashlimit_expire_ms: 60_000, + synlimit_hashlimit_size: 32_768, + announce: None, + announce_ip: None, + proxy_protocol: None, + reuse_allow: false, + } + } + + #[test] + fn synlimit_targets_deduplicate_and_use_legacy_port_fallback() { + let mut cfg = ProxyConfig::default(); + cfg.server.port = 9443; + cfg.server.listeners = vec![ + listener( + IpAddr::V4(Ipv4Addr::UNSPECIFIED), + None, + SynLimitMode::Iptables, + ), + listener( + IpAddr::V4(Ipv4Addr::UNSPECIFIED), + None, + SynLimitMode::Iptables, + ), + ]; + + let targets = synlimit_targets(&cfg); + + assert_eq!(targets.iptables_v4.len(), 1); + assert_eq!(targets.iptables_v4[0].ip, None); + assert_eq!(targets.iptables_v4[0].port, 9443); + assert!(targets.iptables_v6.is_empty()); + assert!(targets.nft_v4.is_empty()); + assert!(targets.nft_v6.is_empty()); + } + + #[test] + fn synlimit_targets_separate_backends_and_ip_families() { + let mut cfg = ProxyConfig::default(); + cfg.server.listeners = vec![ + listener( + IpAddr::V4(Ipv4Addr::new(203, 0, 113, 1)), + Some(443), + SynLimitMode::Iptables, + ), + listener( + IpAddr::V6(Ipv6Addr::LOCALHOST), + Some(443), + SynLimitMode::Iptables, + ), + listener( + IpAddr::V4(Ipv4Addr::new(203, 0, 113, 2)), + Some(444), + SynLimitMode::Nftables, + ), + listener( + IpAddr::V6(Ipv6Addr::UNSPECIFIED), + Some(444), + SynLimitMode::Nftables, + ), + ]; + + let targets = synlimit_targets(&cfg); + + assert_eq!(targets.iptables_v4.len(), 1); + assert_eq!(targets.iptables_v6.len(), 1); + assert_eq!(targets.nft_v4.len(), 1); + assert_eq!(targets.nft_v6.len(), 1); + assert_eq!( + targets.iptables_v4[0].ip, + Some(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 1))) + ); + assert_eq!(targets.iptables_v6[0].ip, Some(IpAddr::V6(Ipv6Addr::LOCALHOST))); + assert_eq!( + targets.nft_v4[0].ip, + Some(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 2))) + ); + assert_eq!(targets.nft_v6[0].ip, None); + } + + #[test] + fn synlimit_rate_arg_uses_native_units_without_fractional_rates() { + assert_eq!(synlimit_rate_arg(1, 12), "12/second"); + assert_eq!(synlimit_rate_arg(60, 48), "48/minute"); + assert_eq!(synlimit_rate_arg(3600, 121), "121/hour"); + assert_eq!(synlimit_rate_arg(86400, 241), "241/day"); + } +} diff --git a/src/synlimit_control/nftables.rs b/src/synlimit_control/nftables.rs index 428b3c4..c8f8510 100644 --- a/src/synlimit_control/nftables.rs +++ b/src/synlimit_control/nftables.rs @@ -259,4 +259,37 @@ mod tests { assert!(script.contains("ip6 saddr limit rate over 12/second burst 24 packets")); assert!(script.contains("ip6 saddr limit rate over 48/minute burst 1 packets")); } + + #[test] + fn nft_missing_table_errors_are_cleanup_benign() { + assert!(is_missing_command_or_nft_table("nft is not available")); + assert!(is_missing_command_or_nft_table( + "Error: No such file or directory" + )); + assert!(!is_missing_command_or_nft_table( + "Error: Operation not permitted" + )); + } + + #[test] + fn nft_apply_plan_keeps_dual_stack_rules_in_inet_table() { + let v4_rule = test_rule(Some(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 7))), 443); + let v6_rule = test_rule(Some(IpAddr::V6(Ipv6Addr::LOCALHOST)), 443); + let v4_rules = [v4_rule]; + let v6_rules = [v6_rule]; + let plans = nft_apply_plan( + NftTableFamilies { + inet: false, + ip: false, + ip6: false, + }, + &v4_rules, + &v6_rules, + ); + + assert_eq!(plans.len(), 1); + assert_eq!(plans[0].family.as_str(), "inet"); + assert_eq!(plans[0].v4_targets, v4_rules.as_slice()); + assert_eq!(plans[0].v6_targets, v6_rules.as_slice()); + } } diff --git a/src/transport/middle_proxy/handshake.rs b/src/transport/middle_proxy/handshake.rs index 9fd598e..4cde68d 100644 --- a/src/transport/middle_proxy/handshake.rs +++ b/src/transport/middle_proxy/handshake.rs @@ -675,8 +675,117 @@ fn hex_dump(data: &[u8]) -> String { mod tests { use super::*; use std::io::ErrorKind; + use std::net::{Ipv4Addr, Ipv6Addr}; use tokio::net::{TcpListener, TcpStream}; + fn upstream_egress( + route_kind: UpstreamRouteKind, + socks_bound_addr: Option, + ) -> UpstreamEgressInfo { + UpstreamEgressInfo { + upstream_id: 7, + route_kind, + local_addr: None, + direct_bind_ip: None, + socks_bound_addr, + socks_proxy_addr: None, + } + } + + #[test] + fn socks_bound_addr_is_used_only_for_public_same_family_tuple() { + let v4_bound = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)), 443); + let v6_bound = SocketAddr::new( + IpAddr::V6( + "2606:4700:4700::1111" + .parse::() + .expect("test IPv6 address must parse"), + ), + 443, + ); + + assert_eq!( + MePool::select_socks_bound_addr( + IpFamily::V4, + Some(upstream_egress(UpstreamRouteKind::Socks5, Some(v4_bound))) + ), + Some(v4_bound) + ); + assert_eq!( + MePool::select_socks_bound_addr( + IpFamily::V6, + Some(upstream_egress(UpstreamRouteKind::Socks5, Some(v6_bound))) + ), + Some(v6_bound) + ); + } + + #[test] + fn socks_bound_addr_rejects_bogon_unspecified_wrong_family_and_non_socks_routes() { + let bogon_bound = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 443); + let unspecified_bound = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 443); + let public_v4_bound = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)), 443); + + assert_eq!( + MePool::select_socks_bound_addr( + IpFamily::V4, + Some(upstream_egress(UpstreamRouteKind::Socks5, Some(bogon_bound))) + ), + None + ); + assert_eq!( + MePool::select_socks_bound_addr( + IpFamily::V4, + Some(upstream_egress( + UpstreamRouteKind::Socks5, + Some(unspecified_bound) + )) + ), + None + ); + assert_eq!( + MePool::select_socks_bound_addr( + IpFamily::V6, + Some(upstream_egress( + UpstreamRouteKind::Socks5, + Some(public_v4_bound) + )) + ), + None + ); + assert_eq!( + MePool::select_socks_bound_addr( + IpFamily::V4, + Some(upstream_egress( + UpstreamRouteKind::Direct, + Some(public_v4_bound) + )) + ), + None + ); + } + + #[test] + fn kdf_client_port_source_tracks_only_valid_socks_bound_port() { + let bound = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)), 443); + let zero_port_bound = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)), 0); + + assert_eq!( + KdfClientPortSource::from_socks_bound_port(Some(bound.port())), + KdfClientPortSource::SocksBound + ); + assert_eq!( + KdfClientPortSource::from_socks_bound_port( + Some(zero_port_bound).filter(|addr| addr.port() != 0).map(|addr| addr.port()) + ), + KdfClientPortSource::LocalSocket + ); + assert_eq!( + KdfClientPortSource::from_socks_bound_port(None), + KdfClientPortSource::LocalSocket + ); + } + #[tokio::test] async fn test_configure_keepalive_loopback() { let listener = match TcpListener::bind("127.0.0.1:0").await { diff --git a/src/transport/middle_proxy/tests/send_adversarial_tests.rs b/src/transport/middle_proxy/tests/send_adversarial_tests.rs index 963007f..6c64555 100644 --- a/src/transport/middle_proxy/tests/send_adversarial_tests.rs +++ b/src/transport/middle_proxy/tests/send_adversarial_tests.rs @@ -368,3 +368,63 @@ async fn send_proxy_req_uses_writer_source_ip_when_advertised_our_addr_differs() SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 31)), our_addr.port()) ); } + +#[tokio::test] +async fn send_proxy_req_blocking_fallback_uses_writer_source_ip() { + let (pool, _rng) = make_pool().await; + pool.rr.store(0, Ordering::Relaxed); + + let (conn_id, _rx) = pool.registry.register().await; + let mut live_rx = insert_writer( + &pool, + 32, + 2, + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 2, 32)), 443), + true, + ) + .await; + let source_ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 32)); + + let tx = { + let mut writers = pool.writers.write().await; + let writer = writers + .iter_mut() + .find(|writer| writer.id == 32) + .expect("test writer must exist"); + writer.source_ip = source_ip; + writer.tx.clone() + }; + for _ in 0..8 { + tx.try_send(WriterCommand::Close) + .expect("test writer channel must accept preload"); + } + + let our_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 100, 8)), 9443); + let pool_for_send = pool.clone(); + let send_task = tokio::spawn(async move { + pool_for_send + .send_proxy_req( + conn_id, + 2, + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 8)), 30003), + our_addr, + b"blocking", + 0, + None, + ) + .await + }); + + tokio::time::sleep(Duration::from_millis(10)).await; + assert!(matches!(live_rx.recv().await, Some(WriterCommand::Close))); + + let result = send_task.await.expect("send task must not panic"); + assert!(result.is_ok()); + let payload = recv_first_data_payload(&mut live_rx, Duration::from_millis(50)) + .await + .expect("writer must receive blocking fallback payload"); + assert_eq!( + proxy_req_our_addr_from_payload(&payload), + SocketAddr::new(source_ip, our_addr.port()) + ); +}