feat(proxy): enhance logging and deduplication for unknown datacenters

- Implemented a mechanism to log unknown datacenter indices with a distinct limit to avoid excessive logging.
- Introduced tests to ensure that logging is deduplicated per datacenter index and respects the distinct limit.
- Updated the fallback logic for datacenter resolution to prevent panics when only a single datacenter is available.

feat(proxy): add authentication probe throttling

- Added a pre-authentication probe throttling mechanism to limit the rate of invalid TLS and MTProto handshake attempts.
- Introduced a backoff strategy for repeated failures and ensured that successful handshakes reset the failure count.
- Implemented tests to validate the behavior of the authentication probe under various conditions.

fix(proxy): ensure proper flushing of masked writes

- Added a flush operation after writing initial data to the mask writer to ensure data integrity.

refactor(proxy): optimize desynchronization deduplication

- Replaced the Mutex-based deduplication structure with a DashMap for improved concurrency and performance.
- Implemented a bounded cache for deduplication to limit memory usage and prevent stale entries from persisting.

test(proxy): enhance security tests for middle relay and handshake

- Added comprehensive tests for the middle relay and handshake processes, including scenarios for deduplication and authentication probe behavior.
- Ensured that the tests cover edge cases and validate the expected behavior of the system under load.
This commit is contained in:
David Osipov 2026-03-17 01:29:30 +04:00
parent e4a50f9286
commit 205fc88718
No known key found for this signature in database
GPG Key ID: 0E55C4A47454E82E
15 changed files with 1124 additions and 150 deletions

View File

@ -1156,6 +1156,13 @@ pub struct ServerConfig {
#[serde(default = "default_proxy_protocol_header_timeout_ms")] #[serde(default = "default_proxy_protocol_header_timeout_ms")]
pub proxy_protocol_header_timeout_ms: u64, pub proxy_protocol_header_timeout_ms: u64,
/// Trusted source CIDRs allowed to send incoming PROXY protocol headers.
///
/// When non-empty, connections from addresses outside this allowlist are
/// rejected before `src_addr` is applied.
#[serde(default)]
pub proxy_protocol_trusted_cidrs: Vec<IpNetwork>,
#[serde(default)] #[serde(default)]
pub metrics_port: Option<u16>, pub metrics_port: Option<u16>,
@ -1180,6 +1187,7 @@ impl Default for ServerConfig {
listen_tcp: None, listen_tcp: None,
proxy_protocol: false, proxy_protocol: false,
proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(), proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(),
proxy_protocol_trusted_cidrs: Vec::new(),
metrics_port: None, metrics_port: None,
metrics_whitelist: default_metrics_whitelist(), metrics_whitelist: default_metrics_whitelist(),
api: ApiConfig::default(), api: ApiConfig::default(),

View File

@ -285,6 +285,26 @@ pub fn validate_tls_handshake(
handshake: &[u8], handshake: &[u8],
secrets: &[(String, Vec<u8>)], secrets: &[(String, Vec<u8>)],
ignore_time_skew: bool, ignore_time_skew: bool,
) -> Option<TlsValidation> {
validate_tls_handshake_with_replay_window(
handshake,
secrets,
ignore_time_skew,
u64::from(BOOT_TIME_MAX_SECS),
)
}
/// Validate TLS ClientHello and cap the boot-time bypass by replay-cache TTL.
///
/// A boot-time timestamp is only accepted when it falls below both
/// `BOOT_TIME_MAX_SECS` and the configured replay window, preventing timestamp
/// reuse outside replay cache coverage.
#[must_use]
pub fn validate_tls_handshake_with_replay_window(
handshake: &[u8],
secrets: &[(String, Vec<u8>)],
ignore_time_skew: bool,
replay_window_secs: u64,
) -> Option<TlsValidation> { ) -> Option<TlsValidation> {
// Only pay the clock syscall when we will actually compare against it. // Only pay the clock syscall when we will actually compare against it.
// If `ignore_time_skew` is set, a broken or unavailable system clock // If `ignore_time_skew` is set, a broken or unavailable system clock
@ -295,7 +315,16 @@ pub fn validate_tls_handshake(
0_i64 0_i64
}; };
validate_tls_handshake_at_time(handshake, secrets, ignore_time_skew, now) let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX);
let boot_time_cap_secs = BOOT_TIME_MAX_SECS.min(replay_window_u32);
validate_tls_handshake_at_time_with_boot_cap(
handshake,
secrets,
ignore_time_skew,
now,
boot_time_cap_secs,
)
} }
fn system_time_to_unix_secs(now: SystemTime) -> Option<i64> { fn system_time_to_unix_secs(now: SystemTime) -> Option<i64> {
@ -311,6 +340,22 @@ fn validate_tls_handshake_at_time(
secrets: &[(String, Vec<u8>)], secrets: &[(String, Vec<u8>)],
ignore_time_skew: bool, ignore_time_skew: bool,
now: i64, now: i64,
) -> Option<TlsValidation> {
validate_tls_handshake_at_time_with_boot_cap(
handshake,
secrets,
ignore_time_skew,
now,
BOOT_TIME_MAX_SECS,
)
}
fn validate_tls_handshake_at_time_with_boot_cap(
handshake: &[u8],
secrets: &[(String, Vec<u8>)],
ignore_time_skew: bool,
now: i64,
boot_time_cap_secs: u32,
) -> Option<TlsValidation> { ) -> Option<TlsValidation> {
if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 { if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 {
return None; return None;
@ -366,7 +411,7 @@ fn validate_tls_handshake_at_time(
if !ignore_time_skew { if !ignore_time_skew {
// Allow very small timestamps (boot time instead of unix time) // Allow very small timestamps (boot time instead of unix time)
// This is a quirk in some clients that use uptime instead of real time // This is a quirk in some clients that use uptime instead of real time
let is_boot_time = timestamp < BOOT_TIME_MAX_SECS; let is_boot_time = timestamp < boot_time_cap_secs;
if !is_boot_time { if !is_boot_time {
let time_diff = now - i64::from(timestamp); let time_diff = now - i64::from(timestamp);
if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) { if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) {

View File

@ -654,6 +654,34 @@ fn timestamp_at_boot_threshold_triggers_skew_check() {
); );
} }
#[test]
fn replay_window_cap_disables_boot_bypass_for_old_timestamps() {
let secret = b"boot_cap_disabled_test";
let ts: u32 = 900;
let h = make_valid_tls_handshake(secret, ts);
let secrets = vec![("u".to_string(), secret.to_vec())];
let result = validate_tls_handshake_with_replay_window(&h, &secrets, false, 300);
assert!(
result.is_none(),
"timestamp above replay-window cap must not use boot-time bypass"
);
}
#[test]
fn replay_window_cap_still_allows_small_boot_timestamp() {
let secret = b"boot_cap_enabled_test";
let ts: u32 = 120;
let h = make_valid_tls_handshake(secret, ts);
let secrets = vec![("u".to_string(), secret.to_vec())];
let result = validate_tls_handshake_with_replay_window(&h, &secrets, false, 300);
assert!(
result.is_some(),
"timestamp below replay-window cap must retain boot-time compatibility"
);
}
// ------------------------------------------------------------------ // ------------------------------------------------------------------
// Extreme timestamp values // Extreme timestamp values
// ------------------------------------------------------------------ // ------------------------------------------------------------------

View File

@ -4,7 +4,10 @@ use std::future::Future;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration; use std::time::Duration;
use ipnetwork::IpNetwork;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::time::timeout; use tokio::time::timeout;
@ -73,6 +76,20 @@ fn record_handshake_failure_class(
record_beobachten_class(beobachten, config, peer_ip, class); record_beobachten_class(beobachten, config, peer_ip, class);
} }
fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool {
if trusted.is_empty() {
static EMPTY_PROXY_TRUST_WARNED: OnceLock<AtomicBool> = OnceLock::new();
let warned = EMPTY_PROXY_TRUST_WARNED.get_or_init(|| AtomicBool::new(false));
if !warned.swap(true, Ordering::Relaxed) {
warn!(
"PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers by default"
);
}
return false;
}
trusted.iter().any(|cidr| cidr.contains(peer_ip))
}
pub async fn handle_client_stream<S>( pub async fn handle_client_stream<S>(
mut stream: S, mut stream: S,
peer: SocketAddr, peer: SocketAddr,
@ -106,6 +123,17 @@ where
); );
match timeout(proxy_header_timeout, parse_proxy_protocol(&mut stream, peer)).await { match timeout(proxy_header_timeout, parse_proxy_protocol(&mut stream, peer)).await {
Ok(Ok(info)) => { Ok(Ok(info)) => {
if !is_trusted_proxy_source(peer.ip(), &config.server.proxy_protocol_trusted_cidrs)
{
stats.increment_connects_bad();
warn!(
peer = %peer,
trusted = ?config.server.proxy_protocol_trusted_cidrs,
"Rejecting PROXY protocol header from untrusted source"
);
record_beobachten_class(&beobachten, &config, peer.ip(), "other");
return Err(ProxyError::InvalidProxyProtocol);
}
debug!( debug!(
peer = %peer, peer = %peer,
client = %info.src_addr, client = %info.src_addr,
@ -462,6 +490,24 @@ impl RunningClientHandler {
.await .await
{ {
Ok(Ok(info)) => { Ok(Ok(info)) => {
if !is_trusted_proxy_source(
self.peer.ip(),
&self.config.server.proxy_protocol_trusted_cidrs,
) {
self.stats.increment_connects_bad();
warn!(
peer = %self.peer,
trusted = ?self.config.server.proxy_protocol_trusted_cidrs,
"Rejecting PROXY protocol header from untrusted source"
);
record_beobachten_class(
&self.beobachten,
&self.config,
self.peer.ip(),
"other",
);
return Err(ProxyError::InvalidProxyProtocol);
}
debug!( debug!(
peer = %self.peer, peer = %self.peer,
client = %info.src_addr, client = %info.src_addr,
@ -768,7 +814,7 @@ impl RunningClientHandler {
client_writer, client_writer,
success, success,
pool.clone(), pool.clone(),
stats, stats.clone(),
config, config,
buffer_pool, buffer_pool,
local_addr, local_addr,
@ -785,7 +831,7 @@ impl RunningClientHandler {
client_writer, client_writer,
success, success,
upstream_manager, upstream_manager,
stats, stats.clone(),
config, config,
buffer_pool, buffer_pool,
rng, rng,
@ -802,7 +848,7 @@ impl RunningClientHandler {
client_writer, client_writer,
success, success,
upstream_manager, upstream_manager,
stats, stats.clone(),
config, config,
buffer_pool, buffer_pool,
rng, rng,
@ -813,6 +859,7 @@ impl RunningClientHandler {
.await .await
}; };
stats.decrement_user_curr_connects(&user);
ip_tracker.remove_ip(&user, peer_addr.ip()).await; ip_tracker.remove_ip(&user, peer_addr.ip()).await;
relay_result relay_result
} }
@ -832,14 +879,6 @@ impl RunningClientHandler {
}); });
} }
if let Some(limit) = config.access.user_max_tcp_conns.get(user)
&& stats.get_user_curr_connects(user) >= *limit as u64
{
return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(),
});
}
if let Some(quota) = config.access.user_data_quota.get(user) if let Some(quota) = config.access.user_data_quota.get(user)
&& stats.get_user_total_octets(user) >= *quota && stats.get_user_total_octets(user) >= *quota
{ {
@ -848,9 +887,21 @@ impl RunningClientHandler {
}); });
} }
let limit = config
.access
.user_max_tcp_conns
.get(user)
.map(|v| *v as u64);
if !stats.try_acquire_user_curr_connects(user, limit) {
return Err(ProxyError::ConnectionLimitExceeded {
user: user.to_string(),
});
}
match ip_tracker.check_and_add(user, peer_addr.ip()).await { match ip_tracker.check_and_add(user, peer_addr.ip()).await {
Ok(()) => {} Ok(()) => {}
Err(reason) => { Err(reason) => {
stats.decrement_user_curr_connects(user);
warn!( warn!(
user = %user, user = %user,
ip = %peer_addr.ip(), ip = %peer_addr.ip(),

View File

@ -2,6 +2,7 @@ use super::*;
use crate::config::{UpstreamConfig, UpstreamType}; use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac; use crate::crypto::sha256_hmac;
use crate::protocol::tls; use crate::protocol::tls;
use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
@ -92,6 +93,206 @@ async fn short_tls_probe_is_masked_through_client_pipeline() {
accept_task.await.unwrap(); accept_task.await.unwrap();
} }
#[tokio::test]
async fn handle_client_stream_increments_connects_all_exactly_once() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = vec![0x16, 0x03, 0x01, 0x00, 0x10];
let accept_task = tokio::spawn({
let probe = probe.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; probe.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, probe);
}
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let before = stats.get_connects_all();
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let peer: SocketAddr = "203.0.113.177:55001".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
));
client_side.write_all(&probe).await.unwrap();
drop(client_side);
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
assert_eq!(
stats.get_connects_all(),
before + 1,
"handle_client_stream must increment connects_all exactly once"
);
}
#[tokio::test]
async fn running_client_handler_increments_connects_all_exactly_once() {
let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = mask_listener.local_addr().unwrap();
let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let front_addr = front_listener.local_addr().unwrap();
let probe = [0x16, 0x03, 0x01, 0x00, 0x10];
let mask_accept_task = tokio::spawn(async move {
let (mut stream, _) = mask_listener.accept().await.unwrap();
let mut got = [0u8; 5];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, probe);
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_proxy_protocol = 0;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let before = stats.get_connects_all();
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let server_task = {
let config = config.clone();
let stats = stats.clone();
let upstream_manager = upstream_manager.clone();
let replay_checker = replay_checker.clone();
let buffer_pool = buffer_pool.clone();
let rng = rng.clone();
let route_runtime = route_runtime.clone();
let ip_tracker = ip_tracker.clone();
let beobachten = beobachten.clone();
tokio::spawn(async move {
let (stream, peer) = front_listener.accept().await.unwrap();
let real_peer_report = Arc::new(std::sync::Mutex::new(None));
ClientHandler::new(
stream,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
false,
real_peer_report,
)
.run()
.await
})
};
let mut client = TcpStream::connect(front_addr).await.unwrap();
client.write_all(&probe).await.unwrap();
drop(client);
let _ = tokio::time::timeout(Duration::from_secs(3), server_task)
.await
.unwrap()
.unwrap();
tokio::time::timeout(Duration::from_secs(3), mask_accept_task)
.await
.unwrap()
.unwrap();
assert_eq!(
stats.get_connects_all(),
before + 1,
"ClientHandler::run must increment connects_all exactly once"
);
}
#[tokio::test] #[tokio::test]
async fn partial_tls_header_stall_triggers_handshake_timeout() { async fn partial_tls_header_stall_triggers_handshake_timeout() {
let mut cfg = ProxyConfig::default(); let mut cfg = ProxyConfig::default();
@ -1058,6 +1259,186 @@ async fn concurrent_limit_rejections_from_mixed_ips_leave_no_ip_footprint() {
); );
} }
#[tokio::test]
async fn atomic_limit_gate_allows_only_one_concurrent_acquire() {
let mut config = ProxyConfig::default();
config
.access
.user_max_tcp_conns
.insert("user".to_string(), 1);
let config = Arc::new(config);
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
let mut tasks = tokio::task::JoinSet::new();
for i in 0..64u16 {
let config = config.clone();
let stats = stats.clone();
let ip_tracker = ip_tracker.clone();
tasks.spawn(async move {
let peer = SocketAddr::new(
IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 113, (i + 1) as u8)),
30000 + i,
);
RunningClientHandler::check_user_limits_static("user", &config, &stats, peer, &ip_tracker)
.await
.is_ok()
});
}
let mut successes = 0u64;
while let Some(joined) = tasks.join_next().await {
if joined.unwrap() {
successes += 1;
}
}
assert_eq!(
successes, 1,
"exactly one concurrent acquire must pass for a limit=1 user"
);
assert_eq!(stats.get_user_curr_connects("user"), 1);
}
#[tokio::test]
async fn untrusted_proxy_header_source_is_rejected() {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.server.proxy_protocol_trusted_cidrs = vec!["10.10.0.0/16".parse().unwrap()];
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(2048);
let peer: SocketAddr = "198.51.100.44:55000".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
true,
));
let proxy_header = ProxyProtocolV1Builder::new()
.tcp4(
"203.0.113.9:32000".parse().unwrap(),
"192.0.2.8:443".parse().unwrap(),
)
.build();
client_side.write_all(&proxy_header).await.unwrap();
drop(client_side);
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol)));
}
#[tokio::test]
async fn empty_proxy_trusted_cidrs_rejects_proxy_header_by_default() {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.server.proxy_protocol_trusted_cidrs.clear();
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
let buffer_pool = Arc::new(BufferPool::new());
let rng = Arc::new(SecureRandom::new());
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
let ip_tracker = Arc::new(UserIpTracker::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(2048);
let peer: SocketAddr = "198.51.100.45:55000".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats,
upstream_manager,
replay_checker,
buffer_pool,
rng,
None,
route_runtime,
None,
ip_tracker,
beobachten,
true,
));
let proxy_header = ProxyProtocolV1Builder::new()
.tcp4(
"203.0.113.9:32000".parse().unwrap(),
"192.0.2.8:443".parse().unwrap(),
)
.build();
client_side.write_all(&proxy_header).await.unwrap();
drop(client_side);
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol)));
}
#[tokio::test] #[tokio::test]
async fn oversized_tls_record_is_masked_in_generic_stream_pipeline() { async fn oversized_tls_record_is_masked_in_generic_stream_pipeline() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();

View File

@ -2,6 +2,8 @@ use std::fs::OpenOptions;
use std::io::Write; use std::io::Write;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::collections::HashSet;
use std::sync::{Mutex, OnceLock};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@ -22,6 +24,45 @@ use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, CryptoWriter}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
use crate::transport::UpstreamManager; use crate::transport::UpstreamManager;
const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024;
static LOGGED_UNKNOWN_DCS: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new();
// In tests, this function shares global mutable state. Callers that also use
// cache-reset helpers must hold `unknown_dc_test_lock()` to keep assertions
// deterministic under parallel execution.
fn should_log_unknown_dc(dc_idx: i16) -> bool {
let set = LOGGED_UNKNOWN_DCS.get_or_init(|| Mutex::new(HashSet::new()));
match set.lock() {
Ok(mut guard) => {
if guard.contains(&dc_idx) {
return false;
}
if guard.len() >= UNKNOWN_DC_LOG_DISTINCT_LIMIT {
return false;
}
guard.insert(dc_idx)
}
// If the lock is poisoned, keep logging rather than silently dropping
// operator-visible diagnostics.
Err(_) => true,
}
}
#[cfg(test)]
fn clear_unknown_dc_log_cache_for_testing() {
if let Some(set) = LOGGED_UNKNOWN_DCS.get()
&& let Ok(mut guard) = set.lock()
{
guard.clear();
}
}
#[cfg(test)]
fn unknown_dc_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
pub(crate) async fn handle_via_direct<R, W>( pub(crate) async fn handle_via_direct<R, W>(
client_reader: CryptoReader<R>, client_reader: CryptoReader<R>,
client_writer: CryptoWriter<W>, client_writer: CryptoWriter<W>,
@ -64,7 +105,6 @@ where
debug!(peer = %success.peer, "TG handshake complete, starting relay"); debug!(peer = %success.peer, "TG handshake complete, starting relay");
stats.increment_user_connects(user); stats.increment_user_connects(user);
stats.increment_user_curr_connects(user);
stats.increment_current_connections_direct(); stats.increment_current_connections_direct();
let relay_result = relay_bidirectional( let relay_result = relay_bidirectional(
@ -109,7 +149,6 @@ where
}; };
stats.decrement_current_connections_direct(); stats.decrement_current_connections_direct();
stats.decrement_user_curr_connects(user);
match &relay_result { match &relay_result {
Ok(()) => debug!(user = %user, "Direct relay completed"), Ok(()) => debug!(user = %user, "Direct relay completed"),
@ -160,6 +199,7 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
warn!(dc_idx = dc_idx, "Requested non-standard DC with no override; falling back to default cluster"); warn!(dc_idx = dc_idx, "Requested non-standard DC with no override; falling back to default cluster");
if config.general.unknown_dc_file_log_enabled if config.general.unknown_dc_file_log_enabled
&& let Some(path) = &config.general.unknown_dc_log_path && let Some(path) = &config.general.unknown_dc_log_path
&& should_log_unknown_dc(dc_idx)
&& let Ok(handle) = tokio::runtime::Handle::try_current() && let Ok(handle) = tokio::runtime::Handle::try_current()
{ {
let path = path.clone(); let path = path.clone();
@ -175,7 +215,7 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs { let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs {
default_dc - 1 default_dc - 1
} else { } else {
1 0
}; };
info!( info!(
@ -203,8 +243,6 @@ async fn do_tg_handshake_static(
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce( let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce(
success.proto_tag, success.proto_tag,
success.dc_idx, success.dc_idx,
&success.dec_key,
success.dec_iv,
&success.enc_key, &success.enc_key,
success.enc_iv, success.enc_iv,
rng, rng,
@ -230,3 +268,7 @@ async fn do_tg_handshake_static(
CryptoWriter::new(write_half, tg_encryptor, max_pending), CryptoWriter::new(write_half, tg_encryptor, max_pending),
)) ))
} }
#[cfg(test)]
#[path = "direct_relay_security_tests.rs"]
mod security_tests;

View File

@ -0,0 +1,51 @@
use super::*;
#[test]
fn unknown_dc_log_is_deduplicated_per_dc_idx() {
let _guard = unknown_dc_test_lock()
.lock()
.expect("unknown dc test lock must be available");
clear_unknown_dc_log_cache_for_testing();
assert!(should_log_unknown_dc(777));
assert!(
!should_log_unknown_dc(777),
"same unknown dc_idx must not be logged repeatedly"
);
assert!(
should_log_unknown_dc(778),
"different unknown dc_idx must still be loggable"
);
}
#[test]
fn unknown_dc_log_respects_distinct_limit() {
let _guard = unknown_dc_test_lock()
.lock()
.expect("unknown dc test lock must be available");
clear_unknown_dc_log_cache_for_testing();
for dc in 1..=UNKNOWN_DC_LOG_DISTINCT_LIMIT {
assert!(
should_log_unknown_dc(dc as i16),
"expected first-time unknown dc_idx to be loggable"
);
}
assert!(
!should_log_unknown_dc(i16::MAX),
"distinct unknown dc_idx entries above limit must not be logged"
);
}
#[test]
fn fallback_dc_never_panics_with_single_dc_list() {
let mut cfg = ProxyConfig::default();
cfg.network.prefer = 6;
cfg.network.ipv6 = Some(true);
cfg.default_dc = Some(42);
let addr = get_dc_addr_static(999, &cfg).expect("fallback dc must resolve safely");
let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT);
assert_eq!(addr, expected);
}

View File

@ -4,9 +4,11 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use std::collections::HashSet; use std::collections::HashSet;
use std::net::IpAddr;
use std::sync::Arc; use std::sync::Arc;
use std::sync::{Mutex, OnceLock}; use std::sync::{Mutex, OnceLock};
use std::time::Duration; use std::time::{Duration, Instant};
use dashmap::DashMap;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, warn, trace}; use tracing::{debug, warn, trace};
use zeroize::Zeroize; use zeroize::Zeroize;
@ -22,10 +24,138 @@ use crate::config::ProxyConfig;
use crate::tls_front::{TlsFrontCache, emulator}; use crate::tls_front::{TlsFrontCache, emulator};
const ACCESS_SECRET_BYTES: usize = 16; const ACCESS_SECRET_BYTES: usize = 16;
static INVALID_SECRET_WARNED: OnceLock<Mutex<HashSet<String>>> = OnceLock::new(); static INVALID_SECRET_WARNED: OnceLock<Mutex<HashSet<(String, String)>>> = OnceLock::new();
const AUTH_PROBE_TRACK_RETENTION_SECS: u64 = 10 * 60;
const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 65_536;
const AUTH_PROBE_BACKOFF_START_FAILS: u32 = 4;
#[cfg(test)]
const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1;
#[cfg(not(test))]
const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 25;
#[cfg(test)]
const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 16;
#[cfg(not(test))]
const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 1_000;
#[derive(Clone, Copy)]
struct AuthProbeState {
fail_streak: u32,
blocked_until: Instant,
last_seen: Instant,
}
static AUTH_PROBE_STATE: OnceLock<DashMap<IpAddr, AuthProbeState>> = OnceLock::new();
fn auth_probe_state_map() -> &'static DashMap<IpAddr, AuthProbeState> {
AUTH_PROBE_STATE.get_or_init(DashMap::new)
}
fn auth_probe_backoff(fail_streak: u32) -> Duration {
if fail_streak < AUTH_PROBE_BACKOFF_START_FAILS {
return Duration::ZERO;
}
let shift = (fail_streak - AUTH_PROBE_BACKOFF_START_FAILS).min(10);
let multiplier = 1u64.checked_shl(shift).unwrap_or(u64::MAX);
let ms = AUTH_PROBE_BACKOFF_BASE_MS
.saturating_mul(multiplier)
.min(AUTH_PROBE_BACKOFF_MAX_MS);
Duration::from_millis(ms)
}
fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool {
let retention = Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS);
now.duration_since(state.last_seen) > retention
}
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
let state = auth_probe_state_map();
let Some(entry) = state.get(&peer_ip) else {
return false;
};
if auth_probe_state_expired(&entry, now) {
drop(entry);
state.remove(&peer_ip);
return false;
}
now < entry.blocked_until
}
fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) {
let state = auth_probe_state_map();
if let Some(mut entry) = state.get_mut(&peer_ip) {
if auth_probe_state_expired(&entry, now) {
*entry = AuthProbeState {
fail_streak: 1,
blocked_until: now + auth_probe_backoff(1),
last_seen: now,
};
return;
}
entry.fail_streak = entry.fail_streak.saturating_add(1);
entry.last_seen = now;
entry.blocked_until = now + auth_probe_backoff(entry.fail_streak);
return;
};
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
return;
}
state.insert(peer_ip, AuthProbeState {
fail_streak: 0,
blocked_until: now,
last_seen: now,
});
if let Some(mut entry) = state.get_mut(&peer_ip) {
entry.fail_streak = 1;
entry.blocked_until = now + auth_probe_backoff(1);
}
}
fn auth_probe_record_success(peer_ip: IpAddr) {
let state = auth_probe_state_map();
state.remove(&peer_ip);
}
#[cfg(test)]
fn clear_auth_probe_state_for_testing() {
if let Some(state) = AUTH_PROBE_STATE.get() {
state.clear();
}
}
#[cfg(test)]
fn auth_probe_fail_streak_for_testing(peer_ip: IpAddr) -> Option<u32> {
let state = AUTH_PROBE_STATE.get()?;
state.get(&peer_ip).map(|entry| entry.fail_streak)
}
#[cfg(test)]
fn auth_probe_is_throttled_for_testing(peer_ip: IpAddr) -> bool {
auth_probe_is_throttled(peer_ip, Instant::now())
}
#[cfg(test)]
fn auth_probe_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[cfg(test)]
fn clear_warned_secrets_for_testing() {
if let Some(warned) = INVALID_SECRET_WARNED.get()
&& let Ok(mut guard) = warned.lock()
{
guard.clear();
}
}
fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option<usize>) { fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option<usize>) {
let key = format!("{}:{}", name, reason); let key = (name.to_string(), reason.to_string());
let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new())); let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new()));
let should_warn = match warned.lock() { let should_warn = match warned.lock() {
Ok(mut guard) => guard.insert(key), Ok(mut guard) => guard.insert(key),
@ -170,6 +300,11 @@ where
{ {
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake"); debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
if auth_probe_is_throttled(peer.ip(), Instant::now()) {
debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle");
return HandshakeResult::BadClient { reader, writer };
}
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 { if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
debug!(peer = %peer, "TLS handshake too short"); debug!(peer = %peer, "TLS handshake too short");
return HandshakeResult::BadClient { reader, writer }; return HandshakeResult::BadClient { reader, writer };
@ -177,13 +312,15 @@ where
let secrets = decode_user_secrets(config, None); let secrets = decode_user_secrets(config, None);
let validation = match tls::validate_tls_handshake( let validation = match tls::validate_tls_handshake_with_replay_window(
handshake, handshake,
&secrets, &secrets,
config.access.ignore_time_skew, config.access.ignore_time_skew,
config.access.replay_window_secs,
) { ) {
Some(v) => v, Some(v) => v,
None => { None => {
auth_probe_record_failure(peer.ip(), Instant::now());
debug!( debug!(
peer = %peer, peer = %peer,
ignore_time_skew = config.access.ignore_time_skew, ignore_time_skew = config.access.ignore_time_skew,
@ -197,6 +334,7 @@ where
// letting unauthenticated probes evict valid entries from the replay cache. // letting unauthenticated probes evict valid entries from the replay cache.
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_and_add_tls_digest(digest_half) { if replay_checker.check_and_add_tls_digest(digest_half) {
auth_probe_record_failure(peer.ip(), Instant::now());
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer }; return HandshakeResult::BadClient { reader, writer };
} }
@ -307,6 +445,8 @@ where
"TLS handshake successful" "TLS handshake successful"
); );
auth_probe_record_success(peer.ip());
HandshakeResult::Success(( HandshakeResult::Success((
FakeTlsReader::new(reader), FakeTlsReader::new(reader),
FakeTlsWriter::new(writer), FakeTlsWriter::new(writer),
@ -331,6 +471,11 @@ where
{ {
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes"); trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
if auth_probe_is_throttled(peer.ip(), Instant::now()) {
debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle");
return HandshakeResult::BadClient { reader, writer };
}
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]; let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect(); let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
@ -396,6 +541,7 @@ where
// entry from the cache. We accept the cost of performing the full // entry from the cache. We accept the cost of performing the full
// authentication check first to avoid poisoning the replay cache. // authentication check first to avoid poisoning the replay cache.
if replay_checker.check_and_add_handshake(dec_prekey_iv) { if replay_checker.check_and_add_handshake(dec_prekey_iv) {
auth_probe_record_failure(peer.ip(), Instant::now());
warn!(peer = %peer, user = %user, "MTProto replay attack detected"); warn!(peer = %peer, user = %user, "MTProto replay attack detected");
return HandshakeResult::BadClient { reader, writer }; return HandshakeResult::BadClient { reader, writer };
} }
@ -421,6 +567,8 @@ where
"MTProto handshake successful" "MTProto handshake successful"
); );
auth_probe_record_success(peer.ip());
let max_pending = config.general.crypto_pending_buffer; let max_pending = config.general.crypto_pending_buffer;
return HandshakeResult::Success(( return HandshakeResult::Success((
CryptoReader::new(reader, decryptor), CryptoReader::new(reader, decryptor),
@ -429,6 +577,7 @@ where
)); ));
} }
auth_probe_record_failure(peer.ip(), Instant::now());
debug!(peer = %peer, "MTProto handshake: no matching user found"); debug!(peer = %peer, "MTProto handshake: no matching user found");
HandshakeResult::BadClient { reader, writer } HandshakeResult::BadClient { reader, writer }
} }
@ -437,8 +586,6 @@ where
pub fn generate_tg_nonce( pub fn generate_tg_nonce(
proto_tag: ProtoTag, proto_tag: ProtoTag,
dc_idx: i16, dc_idx: i16,
_client_dec_key: &[u8; 32],
_client_dec_iv: u128,
client_enc_key: &[u8; 32], client_enc_key: &[u8; 32],
client_enc_iv: u128, client_enc_iv: u128,
rng: &SecureRandom, rng: &SecureRandom,

View File

@ -82,6 +82,7 @@ fn make_valid_tls_client_hello_with_alpn(
} }
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig { fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
clear_auth_probe_state_for_testing();
let mut cfg = ProxyConfig::default(); let mut cfg = ProxyConfig::default();
cfg.access.users.clear(); cfg.access.users.clear();
cfg.access cfg.access
@ -93,8 +94,6 @@ fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
#[test] #[test]
fn test_generate_tg_nonce() { fn test_generate_tg_nonce() {
let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128;
let client_enc_key = [0x24u8; 32]; let client_enc_key = [0x24u8; 32];
let client_enc_iv = 54321u128; let client_enc_iv = 54321u128;
@ -102,8 +101,6 @@ fn test_generate_tg_nonce() {
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce( let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce(
ProtoTag::Secure, ProtoTag::Secure,
2, 2,
&client_dec_key,
client_dec_iv,
&client_enc_key, &client_enc_key,
client_enc_iv, client_enc_iv,
&rng, &rng,
@ -118,8 +115,6 @@ fn test_generate_tg_nonce() {
#[test] #[test]
fn test_encrypt_tg_nonce() { fn test_encrypt_tg_nonce() {
let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128;
let client_enc_key = [0x24u8; 32]; let client_enc_key = [0x24u8; 32];
let client_enc_iv = 54321u128; let client_enc_iv = 54321u128;
@ -127,8 +122,6 @@ fn test_encrypt_tg_nonce() {
let (nonce, _, _, _, _) = generate_tg_nonce( let (nonce, _, _, _, _) = generate_tg_nonce(
ProtoTag::Secure, ProtoTag::Secure,
2, 2,
&client_dec_key,
client_dec_iv,
&client_enc_key, &client_enc_key,
client_enc_iv, client_enc_iv,
&rng, &rng,
@ -164,8 +157,6 @@ fn test_handshake_success_drop_does_not_panic() {
#[test] #[test]
fn test_generate_tg_nonce_enc_dec_material_is_consistent() { fn test_generate_tg_nonce_enc_dec_material_is_consistent() {
let client_dec_key = [0x12u8; 32];
let client_dec_iv = 0x11223344556677889900aabbccddeeffu128;
let client_enc_key = [0x34u8; 32]; let client_enc_key = [0x34u8; 32];
let client_enc_iv = 0xffeeddccbbaa00998877665544332211u128; let client_enc_iv = 0xffeeddccbbaa00998877665544332211u128;
let rng = SecureRandom::new(); let rng = SecureRandom::new();
@ -173,8 +164,6 @@ fn test_generate_tg_nonce_enc_dec_material_is_consistent() {
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce( let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
ProtoTag::Secure, ProtoTag::Secure,
7, 7,
&client_dec_key,
client_dec_iv,
&client_enc_key, &client_enc_key,
client_enc_iv, client_enc_iv,
&rng, &rng,
@ -209,8 +198,6 @@ fn test_generate_tg_nonce_enc_dec_material_is_consistent() {
#[test] #[test]
fn test_generate_tg_nonce_fast_mode_embeds_reversed_client_enc_material() { fn test_generate_tg_nonce_fast_mode_embeds_reversed_client_enc_material() {
let client_dec_key = [0x22u8; 32];
let client_dec_iv = 0x0102030405060708090a0b0c0d0e0f10u128;
let client_enc_key = [0xABu8; 32]; let client_enc_key = [0xABu8; 32];
let client_enc_iv = 0x11223344556677889900aabbccddeeffu128; let client_enc_iv = 0x11223344556677889900aabbccddeeffu128;
let rng = SecureRandom::new(); let rng = SecureRandom::new();
@ -218,8 +205,6 @@ fn test_generate_tg_nonce_fast_mode_embeds_reversed_client_enc_material() {
let (nonce, _, _, _, _) = generate_tg_nonce( let (nonce, _, _, _, _) = generate_tg_nonce(
ProtoTag::Secure, ProtoTag::Secure,
9, 9,
&client_dec_key,
client_dec_iv,
&client_enc_key, &client_enc_key,
client_enc_iv, client_enc_iv,
&rng, &rng,
@ -236,8 +221,6 @@ fn test_generate_tg_nonce_fast_mode_embeds_reversed_client_enc_material() {
#[test] #[test]
fn test_encrypt_tg_nonce_with_ciphers_matches_manual_suffix_encryption() { fn test_encrypt_tg_nonce_with_ciphers_matches_manual_suffix_encryption() {
let client_dec_key = [0x42u8; 32];
let client_dec_iv = 12345u128;
let client_enc_key = [0x24u8; 32]; let client_enc_key = [0x24u8; 32];
let client_enc_iv = 54321u128; let client_enc_iv = 54321u128;
@ -245,8 +228,6 @@ fn test_encrypt_tg_nonce_with_ciphers_matches_manual_suffix_encryption() {
let (nonce, _, _, _, _) = generate_tg_nonce( let (nonce, _, _, _, _) = generate_tg_nonce(
ProtoTag::Secure, ProtoTag::Secure,
2, 2,
&client_dec_key,
client_dec_iv,
&client_enc_key, &client_enc_key,
client_enc_iv, client_enc_iv,
&rng, &rng,
@ -386,6 +367,7 @@ async fn invalid_tls_probe_does_not_pollute_replay_cache() {
#[tokio::test] #[tokio::test]
async fn empty_decoded_secret_is_rejected() { async fn empty_decoded_secret_is_rejected() {
clear_warned_secrets_for_testing();
let config = test_config_with_secret_hex(""); let config = test_config_with_secret_hex("");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new(); let rng = SecureRandom::new();
@ -409,6 +391,7 @@ async fn empty_decoded_secret_is_rejected() {
#[tokio::test] #[tokio::test]
async fn wrong_length_decoded_secret_is_rejected() { async fn wrong_length_decoded_secret_is_rejected() {
clear_warned_secrets_for_testing();
let config = test_config_with_secret_hex("aa"); let config = test_config_with_secret_hex("aa");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new(); let rng = SecureRandom::new();
@ -458,6 +441,8 @@ async fn invalid_mtproto_probe_does_not_pollute_replay_cache() {
#[tokio::test] #[tokio::test]
async fn mixed_secret_lengths_keep_valid_user_authenticating() { async fn mixed_secret_lengths_keep_valid_user_authenticating() {
clear_warned_secrets_for_testing();
clear_auth_probe_state_for_testing();
let good_secret = [0x22u8; 16]; let good_secret = [0x22u8; 16];
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();
config.access.users.clear(); config.access.users.clear();
@ -582,6 +567,7 @@ async fn malformed_tls_classes_complete_within_bounded_time() {
} }
#[tokio::test] #[tokio::test]
#[ignore = "timing-sensitive; run manually on low-jitter hosts"]
async fn malformed_tls_classes_share_close_latency_buckets() { async fn malformed_tls_classes_share_close_latency_buckets() {
const ITER: usize = 24; const ITER: usize = 24;
const BUCKET_MS: u128 = 10; const BUCKET_MS: u128 = 10;
@ -680,3 +666,114 @@ fn secure_tag_requires_secure_mode_on_direct_transport() {
"Secure tag without TLS must be accepted when secure mode is enabled" "Secure tag without TLS must be accepted when secure mode is enabled"
); );
} }
#[test]
fn invalid_secret_warning_keys_do_not_collide_on_colon_boundaries() {
clear_warned_secrets_for_testing();
warn_invalid_secret_once("a:b", "c", ACCESS_SECRET_BYTES, Some(1));
warn_invalid_secret_once("a", "b:c", ACCESS_SECRET_BYTES, Some(2));
let warned = INVALID_SECRET_WARNED
.get()
.expect("warned set must be initialized");
let guard = warned.lock().expect("warned set lock must be available");
assert_eq!(
guard.len(),
2,
"(name, reason) pairs that stringify to the same colon-joined key must remain distinct"
);
}
#[tokio::test]
async fn repeated_invalid_tls_probes_trigger_pre_auth_throttle() {
let _guard = auth_probe_test_lock()
.lock()
.expect("auth probe test lock must be available");
clear_auth_probe_state_for_testing();
let config = test_config_with_secret_hex("11111111111111111111111111111111");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "127.0.0.1:44361".parse().unwrap();
let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32];
invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32;
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
let result = handle_tls_handshake(
&invalid,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
assert!(
auth_probe_is_throttled_for_testing(peer.ip()),
"invalid probe burst must activate per-IP pre-auth throttle"
);
}
#[tokio::test]
async fn successful_tls_handshake_clears_pre_auth_failure_streak() {
let _guard = auth_probe_test_lock()
.lock()
.expect("auth probe test lock must be available");
clear_auth_probe_state_for_testing();
let secret = [0x23u8; 16];
let config = test_config_with_secret_hex("23232323232323232323232323232323");
let replay_checker = ReplayChecker::new(256, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "127.0.0.1:44362".parse().unwrap();
let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32];
invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32;
for expected in 1..AUTH_PROBE_BACKOFF_START_FAILS {
let result = handle_tls_handshake(
&invalid,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
assert_eq!(
auth_probe_fail_streak_for_testing(peer.ip()),
Some(expected),
"failure streak must grow before a successful authentication"
);
}
let valid = make_valid_tls_handshake(&secret, 0);
let success = handle_tls_handshake(
&valid,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(success, HandshakeResult::Success(_)));
assert_eq!(
auth_probe_fail_streak_for_testing(peer.ip()),
None,
"successful authentication must clear accumulated pre-auth failures"
);
}

View File

@ -232,6 +232,9 @@ where
if mask_write.write_all(initial_data).await.is_err() { if mask_write.write_all(initial_data).await.is_err() {
return; return;
} }
if mask_write.flush().await.is_err() {
return;
}
let mut client_buf = vec![0u8; MASK_BUFFER_SIZE]; let mut client_buf = vec![0u8; MASK_BUFFER_SIZE];
let mut mask_buf = vec![0u8; MASK_BUFFER_SIZE]; let mut mask_buf = vec![0u8; MASK_BUFFER_SIZE];

View File

@ -122,6 +122,12 @@ fn detect_client_type_covers_ssh_port_scanner_and_unknown() {
assert_eq!(detect_client_type(b"random-binary-payload"), "unknown"); assert_eq!(detect_client_type(b"random-binary-payload"), "unknown");
} }
#[test]
fn detect_client_type_len_boundary_9_vs_10_bytes() {
assert_eq!(detect_client_type(b"123456789"), "port-scanner");
assert_eq!(detect_client_type(b"1234567890"), "unknown");
}
#[tokio::test] #[tokio::test]
async fn beobachten_records_scanner_class_when_mask_is_disabled() { async fn beobachten_records_scanner_class_when_mask_is_disabled() {
let mut config = ProxyConfig::default(); let mut config = ProxyConfig::default();

View File

@ -1,12 +1,14 @@
use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher; use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock}; use std::sync::{Arc, OnceLock};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
#[cfg(test)]
use std::sync::Mutex;
use bytes::Bytes; use bytes::{Bytes, BytesMut};
use dashmap::DashMap;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot, watch}; use tokio::sync::{mpsc, oneshot, watch};
use tracing::{debug, trace, warn}; use tracing::{debug, trace, warn};
@ -30,13 +32,15 @@ enum C2MeCommand {
} }
const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60); const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60);
const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536;
const DESYNC_DEDUP_PRUNE_SCAN_LIMIT: usize = 1024;
const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync"; const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128; const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64; const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32; const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1; const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
static DESYNC_DEDUP: OnceLock<Mutex<HashMap<u64, Instant>>> = OnceLock::new(); static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
struct RelayForensicsState { struct RelayForensicsState {
trace_id: u64, trace_id: u64,
@ -90,26 +94,48 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
return true; return true;
} }
let dedup = DESYNC_DEDUP.get_or_init(|| Mutex::new(HashMap::new())); let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
let mut guard = dedup.lock().expect("desync dedup mutex poisoned");
guard.retain(|_, seen_at| now.duration_since(*seen_at) < DESYNC_DEDUP_WINDOW);
match guard.get_mut(&key) { if let Some(mut seen_at) = dedup.get_mut(&key) {
Some(seen_at) => {
if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW { if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW {
*seen_at = now; *seen_at = now;
return true;
}
return false;
}
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
let mut stale_keys = Vec::new();
for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) {
if now.duration_since(*entry.value()) >= DESYNC_DEDUP_WINDOW {
stale_keys.push(*entry.key());
}
}
for stale_key in stale_keys {
dedup.remove(&stale_key);
}
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
return false;
}
}
dedup.insert(key, now);
true true
} else { }
false
} #[cfg(test)]
} fn clear_desync_dedup_for_testing() {
None => { if let Some(dedup) = DESYNC_DEDUP.get() {
guard.insert(key, now); dedup.clear();
true
}
} }
} }
#[cfg(test)]
fn desync_dedup_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
fn report_desync_frame_too_large( fn report_desync_frame_too_large(
state: &RelayForensicsState, state: &RelayForensicsState,
proto_tag: ProtoTag, proto_tag: ProtoTag,
@ -229,7 +255,7 @@ pub(crate) async fn handle_via_middle_proxy<R, W>(
me_pool: Arc<MePool>, me_pool: Arc<MePool>,
stats: Arc<Stats>, stats: Arc<Stats>,
config: Arc<ProxyConfig>, config: Arc<ProxyConfig>,
_buffer_pool: Arc<BufferPool>, buffer_pool: Arc<BufferPool>,
local_addr: SocketAddr, local_addr: SocketAddr,
rng: Arc<SecureRandom>, rng: Arc<SecureRandom>,
mut route_rx: watch::Receiver<RouteCutoverState>, mut route_rx: watch::Receiver<RouteCutoverState>,
@ -271,7 +297,6 @@ where
}; };
stats.increment_user_connects(&user); stats.increment_user_connects(&user);
stats.increment_user_curr_connects(&user);
stats.increment_current_connections_me(); stats.increment_current_connections_me();
if let Some(cutover) = affected_cutover_state( if let Some(cutover) = affected_cutover_state(
@ -291,7 +316,6 @@ where
let _ = me_pool.send_close(conn_id).await; let _ = me_pool.send_close(conn_id).await;
me_pool.registry().unregister(conn_id).await; me_pool.registry().unregister(conn_id).await;
stats.decrement_current_connections_me(); stats.decrement_current_connections_me();
stats.decrement_user_curr_connects(&user);
return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string())); return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string()));
} }
@ -557,6 +581,7 @@ where
&mut crypto_reader, &mut crypto_reader,
proto_tag, proto_tag,
frame_limit, frame_limit,
&buffer_pool,
&forensics, &forensics,
&mut frame_counter, &mut frame_counter,
&stats, &stats,
@ -638,7 +663,6 @@ where
); );
me_pool.registry().unregister(conn_id).await; me_pool.registry().unregister(conn_id).await;
stats.decrement_current_connections_me(); stats.decrement_current_connections_me();
stats.decrement_user_curr_connects(&user);
result result
} }
@ -646,6 +670,7 @@ async fn read_client_payload<R>(
client_reader: &mut CryptoReader<R>, client_reader: &mut CryptoReader<R>,
proto_tag: ProtoTag, proto_tag: ProtoTag,
max_frame: usize, max_frame: usize,
buffer_pool: &Arc<BufferPool>,
forensics: &RelayForensicsState, forensics: &RelayForensicsState,
frame_counter: &mut u64, frame_counter: &mut u64,
stats: &Stats, stats: &Stats,
@ -737,18 +762,27 @@ where
len len
}; };
let mut payload = vec![0u8; len]; let chunk_cap = buffer_pool.buffer_size().max(1024);
let mut payload = BytesMut::with_capacity(len.min(chunk_cap));
let mut remaining = len;
while remaining > 0 {
let chunk_len = remaining.min(chunk_cap);
let mut chunk = buffer_pool.get();
chunk.resize(chunk_len, 0);
client_reader client_reader
.read_exact(&mut payload) .read_exact(&mut chunk[..chunk_len])
.await .await
.map_err(ProxyError::Io)?; .map_err(ProxyError::Io)?;
payload.extend_from_slice(&chunk[..chunk_len]);
remaining -= chunk_len;
}
// Secure Intermediate: strip validated trailing padding bytes. // Secure Intermediate: strip validated trailing padding bytes.
if proto_tag == ProtoTag::Secure { if proto_tag == ProtoTag::Secure {
payload.truncate(secure_payload_len); payload.truncate(secure_payload_len);
} }
*frame_counter += 1; *frame_counter += 1;
return Ok(Some((Bytes::from(payload), quickack))); return Ok(Some((payload.freeze(), quickack)));
} }
} }
@ -940,82 +974,5 @@ where
} }
#[cfg(test)] #[cfg(test)]
mod tests { #[path = "middle_relay_security_tests.rs"]
use super::*; mod security_tests;
use tokio::time::{Duration as TokioDuration, timeout};
#[test]
fn should_yield_sender_only_on_budget_with_backlog() {
assert!(!should_yield_c2me_sender(0, true));
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true));
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false));
assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true));
}
#[tokio::test]
async fn enqueue_c2me_command_uses_try_send_fast_path() {
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(2);
enqueue_c2me_command(
&tx,
C2MeCommand::Data {
payload: Bytes::from_static(&[1, 2, 3]),
flags: 0,
},
)
.await
.unwrap();
let recv = timeout(TokioDuration::from_millis(50), rx.recv())
.await
.unwrap()
.unwrap();
match recv {
C2MeCommand::Data { payload, flags } => {
assert_eq!(payload.as_ref(), &[1, 2, 3]);
assert_eq!(flags, 0);
}
C2MeCommand::Close => panic!("unexpected close command"),
}
}
#[tokio::test]
async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
tx.send(C2MeCommand::Data {
payload: Bytes::from_static(&[9]),
flags: 9,
})
.await
.unwrap();
let tx2 = tx.clone();
let producer = tokio::spawn(async move {
enqueue_c2me_command(
&tx2,
C2MeCommand::Data {
payload: Bytes::from_static(&[7, 7]),
flags: 7,
},
)
.await
.unwrap();
});
let _ = timeout(TokioDuration::from_millis(100), rx.recv())
.await
.unwrap();
producer.await.unwrap();
let recv = timeout(TokioDuration::from_millis(100), rx.recv())
.await
.unwrap()
.unwrap();
match recv {
C2MeCommand::Data { payload, flags } => {
assert_eq!(payload.as_ref(), &[7, 7]);
assert_eq!(flags, 7);
}
C2MeCommand::Close => panic!("unexpected close command"),
}
}
}

View File

@ -0,0 +1,103 @@
use super::*;
use tokio::time::{Duration as TokioDuration, timeout};
#[test]
fn should_yield_sender_only_on_budget_with_backlog() {
assert!(!should_yield_c2me_sender(0, true));
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true));
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false));
assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true));
}
#[tokio::test]
async fn enqueue_c2me_command_uses_try_send_fast_path() {
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(2);
enqueue_c2me_command(
&tx,
C2MeCommand::Data {
payload: Bytes::from_static(&[1, 2, 3]),
flags: 0,
},
)
.await
.unwrap();
let recv = timeout(TokioDuration::from_millis(50), rx.recv())
.await
.unwrap()
.unwrap();
match recv {
C2MeCommand::Data { payload, flags } => {
assert_eq!(payload.as_ref(), &[1, 2, 3]);
assert_eq!(flags, 0);
}
C2MeCommand::Close => panic!("unexpected close command"),
}
}
#[tokio::test]
async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
tx.send(C2MeCommand::Data {
payload: Bytes::from_static(&[9]),
flags: 9,
})
.await
.unwrap();
let tx2 = tx.clone();
let producer = tokio::spawn(async move {
enqueue_c2me_command(
&tx2,
C2MeCommand::Data {
payload: Bytes::from_static(&[7, 7]),
flags: 7,
},
)
.await
.unwrap();
});
let _ = timeout(TokioDuration::from_millis(100), rx.recv())
.await
.unwrap();
producer.await.unwrap();
let recv = timeout(TokioDuration::from_millis(100), rx.recv())
.await
.unwrap()
.unwrap();
match recv {
C2MeCommand::Data { payload, flags } => {
assert_eq!(payload.as_ref(), &[7, 7]);
assert_eq!(flags, 7);
}
C2MeCommand::Close => panic!("unexpected close command"),
}
}
#[test]
fn desync_dedup_cache_is_bounded() {
let _guard = desync_dedup_test_lock()
.lock()
.expect("desync dedup test lock must be available");
clear_desync_dedup_for_testing();
let now = Instant::now();
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
assert!(
should_emit_full_desync(key, false, now),
"unique keys up to cap must be tracked"
);
}
assert!(
!should_emit_full_desync(u64::MAX, false, now),
"new key above cap must be suppressed to bound memory"
);
assert!(
!should_emit_full_desync(7, false, now),
"already tracked key inside dedup window must stay suppressed"
);
}

View File

@ -1257,6 +1257,33 @@ impl Stats {
stats.curr_connects.fetch_add(1, Ordering::Relaxed); stats.curr_connects.fetch_add(1, Ordering::Relaxed);
} }
pub fn try_acquire_user_curr_connects(&self, user: &str, limit: Option<u64>) -> bool {
if !self.telemetry_user_enabled() {
return true;
}
self.maybe_cleanup_user_stats();
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
let counter = &stats.curr_connects;
let mut current = counter.load(Ordering::Relaxed);
loop {
if let Some(max) = limit && current >= max {
return false;
}
match counter.compare_exchange_weak(
current,
current.saturating_add(1),
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => return true,
Err(actual) => current = actual,
}
}
}
pub fn decrement_user_curr_connects(&self, user: &str) { pub fn decrement_user_curr_connects(&self, user: &str) {
self.maybe_cleanup_user_stats(); self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) { if let Some(stats) = self.user_stats.get(user) {

View File

@ -513,6 +513,7 @@ impl FrameCodecTrait for SecureCodec {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use std::collections::HashSet;
use tokio_util::codec::{FramedRead, FramedWrite}; use tokio_util::codec::{FramedRead, FramedWrite};
use tokio::io::duplex; use tokio::io::duplex;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
@ -630,4 +631,31 @@ mod tests {
let result = codec.decode(&mut buf); let result = codec.decode(&mut buf);
assert!(result.is_err()); assert!(result.is_err());
} }
#[test]
fn secure_codec_always_adds_padding_and_jitters_wire_length() {
let codec = SecureCodec::new(Arc::new(SecureRandom::new()));
let payload = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
let mut wire_lens = HashSet::new();
for _ in 0..64 {
let frame = Frame::new(payload.clone());
let mut out = BytesMut::new();
codec.encode(&frame, &mut out).unwrap();
assert!(out.len() >= 4 + payload.len() + 1);
let wire_len = u32::from_le_bytes([out[0], out[1], out[2], out[3]]) as usize;
assert!(
(payload.len() + 1..=payload.len() + 3).contains(&wire_len),
"Secure wire length must be payload+1..3, got {wire_len}"
);
assert_ne!(wire_len % 4, 0, "Secure wire length must be non-4-aligned");
wire_lens.insert(wire_len);
}
assert!(
wire_lens.len() >= 2,
"Secure padding should create observable wire-length jitter"
);
}
} }