mirror of https://github.com/telemt/telemt.git
feat(proxy): enhance logging and deduplication for unknown datacenters
- Implemented a mechanism to log unknown datacenter indices with a distinct limit to avoid excessive logging. - Introduced tests to ensure that logging is deduplicated per datacenter index and respects the distinct limit. - Updated the fallback logic for datacenter resolution to prevent panics when only a single datacenter is available. feat(proxy): add authentication probe throttling - Added a pre-authentication probe throttling mechanism to limit the rate of invalid TLS and MTProto handshake attempts. - Introduced a backoff strategy for repeated failures and ensured that successful handshakes reset the failure count. - Implemented tests to validate the behavior of the authentication probe under various conditions. fix(proxy): ensure proper flushing of masked writes - Added a flush operation after writing initial data to the mask writer to ensure data integrity. refactor(proxy): optimize desynchronization deduplication - Replaced the Mutex-based deduplication structure with a DashMap for improved concurrency and performance. - Implemented a bounded cache for deduplication to limit memory usage and prevent stale entries from persisting. test(proxy): enhance security tests for middle relay and handshake - Added comprehensive tests for the middle relay and handshake processes, including scenarios for deduplication and authentication probe behavior. - Ensured that the tests cover edge cases and validate the expected behavior of the system under load.
This commit is contained in:
parent
e4a50f9286
commit
205fc88718
|
|
@ -1156,6 +1156,13 @@ pub struct ServerConfig {
|
|||
#[serde(default = "default_proxy_protocol_header_timeout_ms")]
|
||||
pub proxy_protocol_header_timeout_ms: u64,
|
||||
|
||||
/// Trusted source CIDRs allowed to send incoming PROXY protocol headers.
|
||||
///
|
||||
/// When non-empty, connections from addresses outside this allowlist are
|
||||
/// rejected before `src_addr` is applied.
|
||||
#[serde(default)]
|
||||
pub proxy_protocol_trusted_cidrs: Vec<IpNetwork>,
|
||||
|
||||
#[serde(default)]
|
||||
pub metrics_port: Option<u16>,
|
||||
|
||||
|
|
@ -1180,6 +1187,7 @@ impl Default for ServerConfig {
|
|||
listen_tcp: None,
|
||||
proxy_protocol: false,
|
||||
proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(),
|
||||
proxy_protocol_trusted_cidrs: Vec::new(),
|
||||
metrics_port: None,
|
||||
metrics_whitelist: default_metrics_whitelist(),
|
||||
api: ApiConfig::default(),
|
||||
|
|
|
|||
|
|
@ -285,6 +285,26 @@ pub fn validate_tls_handshake(
|
|||
handshake: &[u8],
|
||||
secrets: &[(String, Vec<u8>)],
|
||||
ignore_time_skew: bool,
|
||||
) -> Option<TlsValidation> {
|
||||
validate_tls_handshake_with_replay_window(
|
||||
handshake,
|
||||
secrets,
|
||||
ignore_time_skew,
|
||||
u64::from(BOOT_TIME_MAX_SECS),
|
||||
)
|
||||
}
|
||||
|
||||
/// Validate TLS ClientHello and cap the boot-time bypass by replay-cache TTL.
|
||||
///
|
||||
/// A boot-time timestamp is only accepted when it falls below both
|
||||
/// `BOOT_TIME_MAX_SECS` and the configured replay window, preventing timestamp
|
||||
/// reuse outside replay cache coverage.
|
||||
#[must_use]
|
||||
pub fn validate_tls_handshake_with_replay_window(
|
||||
handshake: &[u8],
|
||||
secrets: &[(String, Vec<u8>)],
|
||||
ignore_time_skew: bool,
|
||||
replay_window_secs: u64,
|
||||
) -> Option<TlsValidation> {
|
||||
// Only pay the clock syscall when we will actually compare against it.
|
||||
// If `ignore_time_skew` is set, a broken or unavailable system clock
|
||||
|
|
@ -295,7 +315,16 @@ pub fn validate_tls_handshake(
|
|||
0_i64
|
||||
};
|
||||
|
||||
validate_tls_handshake_at_time(handshake, secrets, ignore_time_skew, now)
|
||||
let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX);
|
||||
let boot_time_cap_secs = BOOT_TIME_MAX_SECS.min(replay_window_u32);
|
||||
|
||||
validate_tls_handshake_at_time_with_boot_cap(
|
||||
handshake,
|
||||
secrets,
|
||||
ignore_time_skew,
|
||||
now,
|
||||
boot_time_cap_secs,
|
||||
)
|
||||
}
|
||||
|
||||
fn system_time_to_unix_secs(now: SystemTime) -> Option<i64> {
|
||||
|
|
@ -311,6 +340,22 @@ fn validate_tls_handshake_at_time(
|
|||
secrets: &[(String, Vec<u8>)],
|
||||
ignore_time_skew: bool,
|
||||
now: i64,
|
||||
) -> Option<TlsValidation> {
|
||||
validate_tls_handshake_at_time_with_boot_cap(
|
||||
handshake,
|
||||
secrets,
|
||||
ignore_time_skew,
|
||||
now,
|
||||
BOOT_TIME_MAX_SECS,
|
||||
)
|
||||
}
|
||||
|
||||
fn validate_tls_handshake_at_time_with_boot_cap(
|
||||
handshake: &[u8],
|
||||
secrets: &[(String, Vec<u8>)],
|
||||
ignore_time_skew: bool,
|
||||
now: i64,
|
||||
boot_time_cap_secs: u32,
|
||||
) -> Option<TlsValidation> {
|
||||
if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 {
|
||||
return None;
|
||||
|
|
@ -366,7 +411,7 @@ fn validate_tls_handshake_at_time(
|
|||
if !ignore_time_skew {
|
||||
// Allow very small timestamps (boot time instead of unix time)
|
||||
// This is a quirk in some clients that use uptime instead of real time
|
||||
let is_boot_time = timestamp < BOOT_TIME_MAX_SECS;
|
||||
let is_boot_time = timestamp < boot_time_cap_secs;
|
||||
if !is_boot_time {
|
||||
let time_diff = now - i64::from(timestamp);
|
||||
if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) {
|
||||
|
|
|
|||
|
|
@ -654,6 +654,34 @@ fn timestamp_at_boot_threshold_triggers_skew_check() {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn replay_window_cap_disables_boot_bypass_for_old_timestamps() {
|
||||
let secret = b"boot_cap_disabled_test";
|
||||
let ts: u32 = 900;
|
||||
let h = make_valid_tls_handshake(secret, ts);
|
||||
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||
|
||||
let result = validate_tls_handshake_with_replay_window(&h, &secrets, false, 300);
|
||||
assert!(
|
||||
result.is_none(),
|
||||
"timestamp above replay-window cap must not use boot-time bypass"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn replay_window_cap_still_allows_small_boot_timestamp() {
|
||||
let secret = b"boot_cap_enabled_test";
|
||||
let ts: u32 = 120;
|
||||
let h = make_valid_tls_handshake(secret, ts);
|
||||
let secrets = vec![("u".to_string(), secret.to_vec())];
|
||||
|
||||
let result = validate_tls_handshake_with_replay_window(&h, &secrets, false, 300);
|
||||
assert!(
|
||||
result.is_some(),
|
||||
"timestamp below replay-window cap must retain boot-time compatibility"
|
||||
);
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------
|
||||
// Extreme timestamp values
|
||||
// ------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -4,7 +4,10 @@ use std::future::Future;
|
|||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::time::Duration;
|
||||
use ipnetwork::IpNetwork;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::time::timeout;
|
||||
|
|
@ -73,6 +76,20 @@ fn record_handshake_failure_class(
|
|||
record_beobachten_class(beobachten, config, peer_ip, class);
|
||||
}
|
||||
|
||||
fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool {
|
||||
if trusted.is_empty() {
|
||||
static EMPTY_PROXY_TRUST_WARNED: OnceLock<AtomicBool> = OnceLock::new();
|
||||
let warned = EMPTY_PROXY_TRUST_WARNED.get_or_init(|| AtomicBool::new(false));
|
||||
if !warned.swap(true, Ordering::Relaxed) {
|
||||
warn!(
|
||||
"PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers by default"
|
||||
);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
trusted.iter().any(|cidr| cidr.contains(peer_ip))
|
||||
}
|
||||
|
||||
pub async fn handle_client_stream<S>(
|
||||
mut stream: S,
|
||||
peer: SocketAddr,
|
||||
|
|
@ -106,6 +123,17 @@ where
|
|||
);
|
||||
match timeout(proxy_header_timeout, parse_proxy_protocol(&mut stream, peer)).await {
|
||||
Ok(Ok(info)) => {
|
||||
if !is_trusted_proxy_source(peer.ip(), &config.server.proxy_protocol_trusted_cidrs)
|
||||
{
|
||||
stats.increment_connects_bad();
|
||||
warn!(
|
||||
peer = %peer,
|
||||
trusted = ?config.server.proxy_protocol_trusted_cidrs,
|
||||
"Rejecting PROXY protocol header from untrusted source"
|
||||
);
|
||||
record_beobachten_class(&beobachten, &config, peer.ip(), "other");
|
||||
return Err(ProxyError::InvalidProxyProtocol);
|
||||
}
|
||||
debug!(
|
||||
peer = %peer,
|
||||
client = %info.src_addr,
|
||||
|
|
@ -462,6 +490,24 @@ impl RunningClientHandler {
|
|||
.await
|
||||
{
|
||||
Ok(Ok(info)) => {
|
||||
if !is_trusted_proxy_source(
|
||||
self.peer.ip(),
|
||||
&self.config.server.proxy_protocol_trusted_cidrs,
|
||||
) {
|
||||
self.stats.increment_connects_bad();
|
||||
warn!(
|
||||
peer = %self.peer,
|
||||
trusted = ?self.config.server.proxy_protocol_trusted_cidrs,
|
||||
"Rejecting PROXY protocol header from untrusted source"
|
||||
);
|
||||
record_beobachten_class(
|
||||
&self.beobachten,
|
||||
&self.config,
|
||||
self.peer.ip(),
|
||||
"other",
|
||||
);
|
||||
return Err(ProxyError::InvalidProxyProtocol);
|
||||
}
|
||||
debug!(
|
||||
peer = %self.peer,
|
||||
client = %info.src_addr,
|
||||
|
|
@ -768,7 +814,7 @@ impl RunningClientHandler {
|
|||
client_writer,
|
||||
success,
|
||||
pool.clone(),
|
||||
stats,
|
||||
stats.clone(),
|
||||
config,
|
||||
buffer_pool,
|
||||
local_addr,
|
||||
|
|
@ -785,7 +831,7 @@ impl RunningClientHandler {
|
|||
client_writer,
|
||||
success,
|
||||
upstream_manager,
|
||||
stats,
|
||||
stats.clone(),
|
||||
config,
|
||||
buffer_pool,
|
||||
rng,
|
||||
|
|
@ -802,7 +848,7 @@ impl RunningClientHandler {
|
|||
client_writer,
|
||||
success,
|
||||
upstream_manager,
|
||||
stats,
|
||||
stats.clone(),
|
||||
config,
|
||||
buffer_pool,
|
||||
rng,
|
||||
|
|
@ -813,6 +859,7 @@ impl RunningClientHandler {
|
|||
.await
|
||||
};
|
||||
|
||||
stats.decrement_user_curr_connects(&user);
|
||||
ip_tracker.remove_ip(&user, peer_addr.ip()).await;
|
||||
relay_result
|
||||
}
|
||||
|
|
@ -832,14 +879,6 @@ impl RunningClientHandler {
|
|||
});
|
||||
}
|
||||
|
||||
if let Some(limit) = config.access.user_max_tcp_conns.get(user)
|
||||
&& stats.get_user_curr_connects(user) >= *limit as u64
|
||||
{
|
||||
return Err(ProxyError::ConnectionLimitExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(quota) = config.access.user_data_quota.get(user)
|
||||
&& stats.get_user_total_octets(user) >= *quota
|
||||
{
|
||||
|
|
@ -848,9 +887,21 @@ impl RunningClientHandler {
|
|||
});
|
||||
}
|
||||
|
||||
let limit = config
|
||||
.access
|
||||
.user_max_tcp_conns
|
||||
.get(user)
|
||||
.map(|v| *v as u64);
|
||||
if !stats.try_acquire_user_curr_connects(user, limit) {
|
||||
return Err(ProxyError::ConnectionLimitExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
|
||||
Ok(()) => {}
|
||||
Err(reason) => {
|
||||
stats.decrement_user_curr_connects(user);
|
||||
warn!(
|
||||
user = %user,
|
||||
ip = %peer_addr.ip(),
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ use super::*;
|
|||
use crate::config::{UpstreamConfig, UpstreamType};
|
||||
use crate::crypto::sha256_hmac;
|
||||
use crate::protocol::tls;
|
||||
use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
|
|
@ -92,6 +93,206 @@ async fn short_tls_probe_is_masked_through_client_pipeline() {
|
|||
accept_task.await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handle_client_stream_increments_connects_all_exactly_once() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let backend_addr = listener.local_addr().unwrap();
|
||||
let probe = vec![0x16, 0x03, 0x01, 0x00, 0x10];
|
||||
|
||||
let accept_task = tokio::spawn({
|
||||
let probe = probe.clone();
|
||||
async move {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
let mut got = vec![0u8; probe.len()];
|
||||
stream.read_exact(&mut got).await.unwrap();
|
||||
assert_eq!(got, probe);
|
||||
}
|
||||
});
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.beobachten = false;
|
||||
cfg.censorship.mask = true;
|
||||
cfg.censorship.mask_unix_sock = None;
|
||||
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
cfg.censorship.mask_port = backend_addr.port();
|
||||
cfg.censorship.mask_proxy_protocol = 0;
|
||||
|
||||
let config = Arc::new(cfg);
|
||||
let stats = Arc::new(Stats::new());
|
||||
let before = stats.get_connects_all();
|
||||
let upstream_manager = Arc::new(UpstreamManager::new(
|
||||
vec![UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct {
|
||||
interface: None,
|
||||
bind_addresses: None,
|
||||
},
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
selected_scope: String::new(),
|
||||
}],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
));
|
||||
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
let beobachten = Arc::new(BeobachtenStore::new());
|
||||
|
||||
let (server_side, mut client_side) = duplex(4096);
|
||||
let peer: SocketAddr = "203.0.113.177:55001".parse().unwrap();
|
||||
|
||||
let handler = tokio::spawn(handle_client_stream(
|
||||
server_side,
|
||||
peer,
|
||||
config,
|
||||
stats.clone(),
|
||||
upstream_manager,
|
||||
replay_checker,
|
||||
buffer_pool,
|
||||
rng,
|
||||
None,
|
||||
route_runtime,
|
||||
None,
|
||||
ip_tracker,
|
||||
beobachten,
|
||||
false,
|
||||
));
|
||||
|
||||
client_side.write_all(&probe).await.unwrap();
|
||||
drop(client_side);
|
||||
|
||||
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
tokio::time::timeout(Duration::from_secs(3), accept_task)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
stats.get_connects_all(),
|
||||
before + 1,
|
||||
"handle_client_stream must increment connects_all exactly once"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn running_client_handler_increments_connects_all_exactly_once() {
|
||||
let mask_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let backend_addr = mask_listener.local_addr().unwrap();
|
||||
|
||||
let front_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let front_addr = front_listener.local_addr().unwrap();
|
||||
|
||||
let probe = [0x16, 0x03, 0x01, 0x00, 0x10];
|
||||
|
||||
let mask_accept_task = tokio::spawn(async move {
|
||||
let (mut stream, _) = mask_listener.accept().await.unwrap();
|
||||
let mut got = [0u8; 5];
|
||||
stream.read_exact(&mut got).await.unwrap();
|
||||
assert_eq!(got, probe);
|
||||
});
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.beobachten = false;
|
||||
cfg.censorship.mask = true;
|
||||
cfg.censorship.mask_unix_sock = None;
|
||||
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
cfg.censorship.mask_port = backend_addr.port();
|
||||
cfg.censorship.mask_proxy_protocol = 0;
|
||||
|
||||
let config = Arc::new(cfg);
|
||||
let stats = Arc::new(Stats::new());
|
||||
let before = stats.get_connects_all();
|
||||
let upstream_manager = Arc::new(UpstreamManager::new(
|
||||
vec![UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct {
|
||||
interface: None,
|
||||
bind_addresses: None,
|
||||
},
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
selected_scope: String::new(),
|
||||
}],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
));
|
||||
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
let beobachten = Arc::new(BeobachtenStore::new());
|
||||
|
||||
let server_task = {
|
||||
let config = config.clone();
|
||||
let stats = stats.clone();
|
||||
let upstream_manager = upstream_manager.clone();
|
||||
let replay_checker = replay_checker.clone();
|
||||
let buffer_pool = buffer_pool.clone();
|
||||
let rng = rng.clone();
|
||||
let route_runtime = route_runtime.clone();
|
||||
let ip_tracker = ip_tracker.clone();
|
||||
let beobachten = beobachten.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let (stream, peer) = front_listener.accept().await.unwrap();
|
||||
let real_peer_report = Arc::new(std::sync::Mutex::new(None));
|
||||
ClientHandler::new(
|
||||
stream,
|
||||
peer,
|
||||
config,
|
||||
stats,
|
||||
upstream_manager,
|
||||
replay_checker,
|
||||
buffer_pool,
|
||||
rng,
|
||||
None,
|
||||
route_runtime,
|
||||
None,
|
||||
ip_tracker,
|
||||
beobachten,
|
||||
false,
|
||||
real_peer_report,
|
||||
)
|
||||
.run()
|
||||
.await
|
||||
})
|
||||
};
|
||||
|
||||
let mut client = TcpStream::connect(front_addr).await.unwrap();
|
||||
client.write_all(&probe).await.unwrap();
|
||||
drop(client);
|
||||
|
||||
let _ = tokio::time::timeout(Duration::from_secs(3), server_task)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
tokio::time::timeout(Duration::from_secs(3), mask_accept_task)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
stats.get_connects_all(),
|
||||
before + 1,
|
||||
"ClientHandler::run must increment connects_all exactly once"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn partial_tls_header_stall_triggers_handshake_timeout() {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
|
|
@ -1058,6 +1259,186 @@ async fn concurrent_limit_rejections_from_mixed_ips_leave_no_ip_footprint() {
|
|||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn atomic_limit_gate_allows_only_one_concurrent_acquire() {
|
||||
let mut config = ProxyConfig::default();
|
||||
config
|
||||
.access
|
||||
.user_max_tcp_conns
|
||||
.insert("user".to_string(), 1);
|
||||
|
||||
let config = Arc::new(config);
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
|
||||
let mut tasks = tokio::task::JoinSet::new();
|
||||
for i in 0..64u16 {
|
||||
let config = config.clone();
|
||||
let stats = stats.clone();
|
||||
let ip_tracker = ip_tracker.clone();
|
||||
tasks.spawn(async move {
|
||||
let peer = SocketAddr::new(
|
||||
IpAddr::V4(std::net::Ipv4Addr::new(203, 0, 113, (i + 1) as u8)),
|
||||
30000 + i,
|
||||
);
|
||||
RunningClientHandler::check_user_limits_static("user", &config, &stats, peer, &ip_tracker)
|
||||
.await
|
||||
.is_ok()
|
||||
});
|
||||
}
|
||||
|
||||
let mut successes = 0u64;
|
||||
while let Some(joined) = tasks.join_next().await {
|
||||
if joined.unwrap() {
|
||||
successes += 1;
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
successes, 1,
|
||||
"exactly one concurrent acquire must pass for a limit=1 user"
|
||||
);
|
||||
assert_eq!(stats.get_user_curr_connects("user"), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn untrusted_proxy_header_source_is_rejected() {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.beobachten = false;
|
||||
cfg.server.proxy_protocol_trusted_cidrs = vec!["10.10.0.0/16".parse().unwrap()];
|
||||
|
||||
let config = Arc::new(cfg);
|
||||
let stats = Arc::new(Stats::new());
|
||||
let upstream_manager = Arc::new(UpstreamManager::new(
|
||||
vec![UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct {
|
||||
interface: None,
|
||||
bind_addresses: None,
|
||||
},
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
selected_scope: String::new(),
|
||||
}],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
));
|
||||
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
let beobachten = Arc::new(BeobachtenStore::new());
|
||||
|
||||
let (server_side, mut client_side) = duplex(2048);
|
||||
let peer: SocketAddr = "198.51.100.44:55000".parse().unwrap();
|
||||
|
||||
let handler = tokio::spawn(handle_client_stream(
|
||||
server_side,
|
||||
peer,
|
||||
config,
|
||||
stats,
|
||||
upstream_manager,
|
||||
replay_checker,
|
||||
buffer_pool,
|
||||
rng,
|
||||
None,
|
||||
route_runtime,
|
||||
None,
|
||||
ip_tracker,
|
||||
beobachten,
|
||||
true,
|
||||
));
|
||||
|
||||
let proxy_header = ProxyProtocolV1Builder::new()
|
||||
.tcp4(
|
||||
"203.0.113.9:32000".parse().unwrap(),
|
||||
"192.0.2.8:443".parse().unwrap(),
|
||||
)
|
||||
.build();
|
||||
client_side.write_all(&proxy_header).await.unwrap();
|
||||
drop(client_side);
|
||||
|
||||
let result = tokio::time::timeout(Duration::from_secs(3), handler)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn empty_proxy_trusted_cidrs_rejects_proxy_header_by_default() {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.beobachten = false;
|
||||
cfg.server.proxy_protocol_trusted_cidrs.clear();
|
||||
|
||||
let config = Arc::new(cfg);
|
||||
let stats = Arc::new(Stats::new());
|
||||
let upstream_manager = Arc::new(UpstreamManager::new(
|
||||
vec![UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct {
|
||||
interface: None,
|
||||
bind_addresses: None,
|
||||
},
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
selected_scope: String::new(),
|
||||
}],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
));
|
||||
let replay_checker = Arc::new(ReplayChecker::new(128, Duration::from_secs(60)));
|
||||
let buffer_pool = Arc::new(BufferPool::new());
|
||||
let rng = Arc::new(SecureRandom::new());
|
||||
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
let beobachten = Arc::new(BeobachtenStore::new());
|
||||
|
||||
let (server_side, mut client_side) = duplex(2048);
|
||||
let peer: SocketAddr = "198.51.100.45:55000".parse().unwrap();
|
||||
|
||||
let handler = tokio::spawn(handle_client_stream(
|
||||
server_side,
|
||||
peer,
|
||||
config,
|
||||
stats,
|
||||
upstream_manager,
|
||||
replay_checker,
|
||||
buffer_pool,
|
||||
rng,
|
||||
None,
|
||||
route_runtime,
|
||||
None,
|
||||
ip_tracker,
|
||||
beobachten,
|
||||
true,
|
||||
));
|
||||
|
||||
let proxy_header = ProxyProtocolV1Builder::new()
|
||||
.tcp4(
|
||||
"203.0.113.9:32000".parse().unwrap(),
|
||||
"192.0.2.8:443".parse().unwrap(),
|
||||
)
|
||||
.build();
|
||||
client_side.write_all(&proxy_header).await.unwrap();
|
||||
drop(client_side);
|
||||
|
||||
let result = tokio::time::timeout(Duration::from_secs(3), handler)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn oversized_tls_record_is_masked_in_generic_stream_pipeline() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ use std::fs::OpenOptions;
|
|||
use std::io::Write;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
|
|
@ -22,6 +24,45 @@ use crate::stats::Stats;
|
|||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
||||
use crate::transport::UpstreamManager;
|
||||
|
||||
const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024;
|
||||
static LOGGED_UNKNOWN_DCS: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new();
|
||||
|
||||
// In tests, this function shares global mutable state. Callers that also use
|
||||
// cache-reset helpers must hold `unknown_dc_test_lock()` to keep assertions
|
||||
// deterministic under parallel execution.
|
||||
fn should_log_unknown_dc(dc_idx: i16) -> bool {
|
||||
let set = LOGGED_UNKNOWN_DCS.get_or_init(|| Mutex::new(HashSet::new()));
|
||||
match set.lock() {
|
||||
Ok(mut guard) => {
|
||||
if guard.contains(&dc_idx) {
|
||||
return false;
|
||||
}
|
||||
if guard.len() >= UNKNOWN_DC_LOG_DISTINCT_LIMIT {
|
||||
return false;
|
||||
}
|
||||
guard.insert(dc_idx)
|
||||
}
|
||||
// If the lock is poisoned, keep logging rather than silently dropping
|
||||
// operator-visible diagnostics.
|
||||
Err(_) => true,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn clear_unknown_dc_log_cache_for_testing() {
|
||||
if let Some(set) = LOGGED_UNKNOWN_DCS.get()
|
||||
&& let Ok(mut guard) = set.lock()
|
||||
{
|
||||
guard.clear();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn unknown_dc_test_lock() -> &'static Mutex<()> {
|
||||
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
TEST_LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_via_direct<R, W>(
|
||||
client_reader: CryptoReader<R>,
|
||||
client_writer: CryptoWriter<W>,
|
||||
|
|
@ -64,7 +105,6 @@ where
|
|||
debug!(peer = %success.peer, "TG handshake complete, starting relay");
|
||||
|
||||
stats.increment_user_connects(user);
|
||||
stats.increment_user_curr_connects(user);
|
||||
stats.increment_current_connections_direct();
|
||||
|
||||
let relay_result = relay_bidirectional(
|
||||
|
|
@ -109,7 +149,6 @@ where
|
|||
};
|
||||
|
||||
stats.decrement_current_connections_direct();
|
||||
stats.decrement_user_curr_connects(user);
|
||||
|
||||
match &relay_result {
|
||||
Ok(()) => debug!(user = %user, "Direct relay completed"),
|
||||
|
|
@ -160,6 +199,7 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
|
|||
warn!(dc_idx = dc_idx, "Requested non-standard DC with no override; falling back to default cluster");
|
||||
if config.general.unknown_dc_file_log_enabled
|
||||
&& let Some(path) = &config.general.unknown_dc_log_path
|
||||
&& should_log_unknown_dc(dc_idx)
|
||||
&& let Ok(handle) = tokio::runtime::Handle::try_current()
|
||||
{
|
||||
let path = path.clone();
|
||||
|
|
@ -175,7 +215,7 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
|
|||
let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs {
|
||||
default_dc - 1
|
||||
} else {
|
||||
1
|
||||
0
|
||||
};
|
||||
|
||||
info!(
|
||||
|
|
@ -203,8 +243,6 @@ async fn do_tg_handshake_static(
|
|||
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce(
|
||||
success.proto_tag,
|
||||
success.dc_idx,
|
||||
&success.dec_key,
|
||||
success.dec_iv,
|
||||
&success.enc_key,
|
||||
success.enc_iv,
|
||||
rng,
|
||||
|
|
@ -230,3 +268,7 @@ async fn do_tg_handshake_static(
|
|||
CryptoWriter::new(write_half, tg_encryptor, max_pending),
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "direct_relay_security_tests.rs"]
|
||||
mod security_tests;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,51 @@
|
|||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn unknown_dc_log_is_deduplicated_per_dc_idx() {
|
||||
let _guard = unknown_dc_test_lock()
|
||||
.lock()
|
||||
.expect("unknown dc test lock must be available");
|
||||
clear_unknown_dc_log_cache_for_testing();
|
||||
|
||||
assert!(should_log_unknown_dc(777));
|
||||
assert!(
|
||||
!should_log_unknown_dc(777),
|
||||
"same unknown dc_idx must not be logged repeatedly"
|
||||
);
|
||||
assert!(
|
||||
should_log_unknown_dc(778),
|
||||
"different unknown dc_idx must still be loggable"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_dc_log_respects_distinct_limit() {
|
||||
let _guard = unknown_dc_test_lock()
|
||||
.lock()
|
||||
.expect("unknown dc test lock must be available");
|
||||
clear_unknown_dc_log_cache_for_testing();
|
||||
|
||||
for dc in 1..=UNKNOWN_DC_LOG_DISTINCT_LIMIT {
|
||||
assert!(
|
||||
should_log_unknown_dc(dc as i16),
|
||||
"expected first-time unknown dc_idx to be loggable"
|
||||
);
|
||||
}
|
||||
|
||||
assert!(
|
||||
!should_log_unknown_dc(i16::MAX),
|
||||
"distinct unknown dc_idx entries above limit must not be logged"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fallback_dc_never_panics_with_single_dc_list() {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.network.prefer = 6;
|
||||
cfg.network.ipv6 = Some(true);
|
||||
cfg.default_dc = Some(42);
|
||||
|
||||
let addr = get_dc_addr_static(999, &cfg).expect("fallback dc must resolve safely");
|
||||
let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT);
|
||||
assert_eq!(addr, expected);
|
||||
}
|
||||
|
|
@ -4,9 +4,11 @@
|
|||
|
||||
use std::net::SocketAddr;
|
||||
use std::collections::HashSet;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use std::time::Duration;
|
||||
use std::time::{Duration, Instant};
|
||||
use dashmap::DashMap;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tracing::{debug, warn, trace};
|
||||
use zeroize::Zeroize;
|
||||
|
|
@ -22,10 +24,138 @@ use crate::config::ProxyConfig;
|
|||
use crate::tls_front::{TlsFrontCache, emulator};
|
||||
|
||||
const ACCESS_SECRET_BYTES: usize = 16;
|
||||
static INVALID_SECRET_WARNED: OnceLock<Mutex<HashSet<String>>> = OnceLock::new();
|
||||
static INVALID_SECRET_WARNED: OnceLock<Mutex<HashSet<(String, String)>>> = OnceLock::new();
|
||||
|
||||
const AUTH_PROBE_TRACK_RETENTION_SECS: u64 = 10 * 60;
|
||||
const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 65_536;
|
||||
const AUTH_PROBE_BACKOFF_START_FAILS: u32 = 4;
|
||||
|
||||
#[cfg(test)]
|
||||
const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1;
|
||||
#[cfg(not(test))]
|
||||
const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 25;
|
||||
|
||||
#[cfg(test)]
|
||||
const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 16;
|
||||
#[cfg(not(test))]
|
||||
const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 1_000;
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct AuthProbeState {
|
||||
fail_streak: u32,
|
||||
blocked_until: Instant,
|
||||
last_seen: Instant,
|
||||
}
|
||||
|
||||
static AUTH_PROBE_STATE: OnceLock<DashMap<IpAddr, AuthProbeState>> = OnceLock::new();
|
||||
|
||||
fn auth_probe_state_map() -> &'static DashMap<IpAddr, AuthProbeState> {
|
||||
AUTH_PROBE_STATE.get_or_init(DashMap::new)
|
||||
}
|
||||
|
||||
fn auth_probe_backoff(fail_streak: u32) -> Duration {
|
||||
if fail_streak < AUTH_PROBE_BACKOFF_START_FAILS {
|
||||
return Duration::ZERO;
|
||||
}
|
||||
let shift = (fail_streak - AUTH_PROBE_BACKOFF_START_FAILS).min(10);
|
||||
let multiplier = 1u64.checked_shl(shift).unwrap_or(u64::MAX);
|
||||
let ms = AUTH_PROBE_BACKOFF_BASE_MS
|
||||
.saturating_mul(multiplier)
|
||||
.min(AUTH_PROBE_BACKOFF_MAX_MS);
|
||||
Duration::from_millis(ms)
|
||||
}
|
||||
|
||||
fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool {
|
||||
let retention = Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS);
|
||||
now.duration_since(state.last_seen) > retention
|
||||
}
|
||||
|
||||
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
|
||||
let state = auth_probe_state_map();
|
||||
let Some(entry) = state.get(&peer_ip) else {
|
||||
return false;
|
||||
};
|
||||
if auth_probe_state_expired(&entry, now) {
|
||||
drop(entry);
|
||||
state.remove(&peer_ip);
|
||||
return false;
|
||||
}
|
||||
now < entry.blocked_until
|
||||
}
|
||||
|
||||
fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) {
|
||||
let state = auth_probe_state_map();
|
||||
if let Some(mut entry) = state.get_mut(&peer_ip) {
|
||||
if auth_probe_state_expired(&entry, now) {
|
||||
*entry = AuthProbeState {
|
||||
fail_streak: 1,
|
||||
blocked_until: now + auth_probe_backoff(1),
|
||||
last_seen: now,
|
||||
};
|
||||
return;
|
||||
}
|
||||
entry.fail_streak = entry.fail_streak.saturating_add(1);
|
||||
entry.last_seen = now;
|
||||
entry.blocked_until = now + auth_probe_backoff(entry.fail_streak);
|
||||
return;
|
||||
};
|
||||
|
||||
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||
return;
|
||||
}
|
||||
|
||||
state.insert(peer_ip, AuthProbeState {
|
||||
fail_streak: 0,
|
||||
blocked_until: now,
|
||||
last_seen: now,
|
||||
});
|
||||
|
||||
if let Some(mut entry) = state.get_mut(&peer_ip) {
|
||||
entry.fail_streak = 1;
|
||||
entry.blocked_until = now + auth_probe_backoff(1);
|
||||
}
|
||||
}
|
||||
|
||||
fn auth_probe_record_success(peer_ip: IpAddr) {
|
||||
let state = auth_probe_state_map();
|
||||
state.remove(&peer_ip);
|
||||
}
|
||||
|
||||
#[cfg(test)]
fn clear_auth_probe_state_for_testing() {
    // Reset the global probe map so each test starts from a clean slate.
    if let Some(map) = AUTH_PROBE_STATE.get() {
        map.clear();
    }
}
|
||||
|
||||
#[cfg(test)]
fn auth_probe_fail_streak_for_testing(peer_ip: IpAddr) -> Option<u32> {
    // Test-only peek at the tracked failure streak for an IP, if any.
    AUTH_PROBE_STATE
        .get()
        .and_then(|map| map.get(&peer_ip).map(|entry| entry.fail_streak))
}
|
||||
|
||||
#[cfg(test)]
fn auth_probe_is_throttled_for_testing(peer_ip: IpAddr) -> bool {
    // Thin test shim: evaluate the throttle against the current time.
    auth_probe_is_throttled(peer_ip, Instant::now())
}
|
||||
|
||||
#[cfg(test)]
fn auth_probe_test_lock() -> &'static Mutex<()> {
    // Serializes tests that mutate the global AUTH_PROBE_STATE map.
    static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    TEST_LOCK.get_or_init(Mutex::default)
}
|
||||
|
||||
#[cfg(test)]
fn clear_warned_secrets_for_testing() {
    // Empty the warn-once set so warning-dedup tests are independent.
    let Some(warned) = INVALID_SECRET_WARNED.get() else {
        return;
    };
    if let Ok(mut guard) = warned.lock() {
        guard.clear();
    }
}
|
||||
|
||||
fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option<usize>) {
|
||||
let key = format!("{}:{}", name, reason);
|
||||
let key = (name.to_string(), reason.to_string());
|
||||
let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new()));
|
||||
let should_warn = match warned.lock() {
|
||||
Ok(mut guard) => guard.insert(key),
|
||||
|
|
@ -170,6 +300,11 @@ where
|
|||
{
|
||||
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
|
||||
|
||||
if auth_probe_is_throttled(peer.ip(), Instant::now()) {
|
||||
debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle");
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
|
||||
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
|
||||
debug!(peer = %peer, "TLS handshake too short");
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
|
|
@ -177,13 +312,15 @@ where
|
|||
|
||||
let secrets = decode_user_secrets(config, None);
|
||||
|
||||
let validation = match tls::validate_tls_handshake(
|
||||
let validation = match tls::validate_tls_handshake_with_replay_window(
|
||||
handshake,
|
||||
&secrets,
|
||||
config.access.ignore_time_skew,
|
||||
config.access.replay_window_secs,
|
||||
) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||
debug!(
|
||||
peer = %peer,
|
||||
ignore_time_skew = config.access.ignore_time_skew,
|
||||
|
|
@ -197,6 +334,7 @@ where
|
|||
// letting unauthenticated probes evict valid entries from the replay cache.
|
||||
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
|
||||
if replay_checker.check_and_add_tls_digest(digest_half) {
|
||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
|
|
@ -307,6 +445,8 @@ where
|
|||
"TLS handshake successful"
|
||||
);
|
||||
|
||||
auth_probe_record_success(peer.ip());
|
||||
|
||||
HandshakeResult::Success((
|
||||
FakeTlsReader::new(reader),
|
||||
FakeTlsWriter::new(writer),
|
||||
|
|
@ -331,6 +471,11 @@ where
|
|||
{
|
||||
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
|
||||
|
||||
if auth_probe_is_throttled(peer.ip(), Instant::now()) {
|
||||
debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle");
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
|
||||
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
|
||||
|
||||
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
|
||||
|
|
@ -396,6 +541,7 @@ where
|
|||
// entry from the cache. We accept the cost of performing the full
|
||||
// authentication check first to avoid poisoning the replay cache.
|
||||
if replay_checker.check_and_add_handshake(dec_prekey_iv) {
|
||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||
warn!(peer = %peer, user = %user, "MTProto replay attack detected");
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
|
|
@ -421,6 +567,8 @@ where
|
|||
"MTProto handshake successful"
|
||||
);
|
||||
|
||||
auth_probe_record_success(peer.ip());
|
||||
|
||||
let max_pending = config.general.crypto_pending_buffer;
|
||||
return HandshakeResult::Success((
|
||||
CryptoReader::new(reader, decryptor),
|
||||
|
|
@ -429,6 +577,7 @@ where
|
|||
));
|
||||
}
|
||||
|
||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||
debug!(peer = %peer, "MTProto handshake: no matching user found");
|
||||
HandshakeResult::BadClient { reader, writer }
|
||||
}
|
||||
|
|
@ -437,8 +586,6 @@ where
|
|||
pub fn generate_tg_nonce(
|
||||
proto_tag: ProtoTag,
|
||||
dc_idx: i16,
|
||||
_client_dec_key: &[u8; 32],
|
||||
_client_dec_iv: u128,
|
||||
client_enc_key: &[u8; 32],
|
||||
client_enc_iv: u128,
|
||||
rng: &SecureRandom,
|
||||
|
|
|
|||
|
|
@ -82,6 +82,7 @@ fn make_valid_tls_client_hello_with_alpn(
|
|||
}
|
||||
|
||||
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
|
||||
clear_auth_probe_state_for_testing();
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.access.users.clear();
|
||||
cfg.access
|
||||
|
|
@ -93,8 +94,6 @@ fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
|
|||
|
||||
#[test]
|
||||
fn test_generate_tg_nonce() {
|
||||
let client_dec_key = [0x42u8; 32];
|
||||
let client_dec_iv = 12345u128;
|
||||
let client_enc_key = [0x24u8; 32];
|
||||
let client_enc_iv = 54321u128;
|
||||
|
||||
|
|
@ -102,8 +101,6 @@ fn test_generate_tg_nonce() {
|
|||
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce(
|
||||
ProtoTag::Secure,
|
||||
2,
|
||||
&client_dec_key,
|
||||
client_dec_iv,
|
||||
&client_enc_key,
|
||||
client_enc_iv,
|
||||
&rng,
|
||||
|
|
@ -118,8 +115,6 @@ fn test_generate_tg_nonce() {
|
|||
|
||||
#[test]
|
||||
fn test_encrypt_tg_nonce() {
|
||||
let client_dec_key = [0x42u8; 32];
|
||||
let client_dec_iv = 12345u128;
|
||||
let client_enc_key = [0x24u8; 32];
|
||||
let client_enc_iv = 54321u128;
|
||||
|
||||
|
|
@ -127,8 +122,6 @@ fn test_encrypt_tg_nonce() {
|
|||
let (nonce, _, _, _, _) = generate_tg_nonce(
|
||||
ProtoTag::Secure,
|
||||
2,
|
||||
&client_dec_key,
|
||||
client_dec_iv,
|
||||
&client_enc_key,
|
||||
client_enc_iv,
|
||||
&rng,
|
||||
|
|
@ -164,8 +157,6 @@ fn test_handshake_success_drop_does_not_panic() {
|
|||
|
||||
#[test]
|
||||
fn test_generate_tg_nonce_enc_dec_material_is_consistent() {
|
||||
let client_dec_key = [0x12u8; 32];
|
||||
let client_dec_iv = 0x11223344556677889900aabbccddeeffu128;
|
||||
let client_enc_key = [0x34u8; 32];
|
||||
let client_enc_iv = 0xffeeddccbbaa00998877665544332211u128;
|
||||
let rng = SecureRandom::new();
|
||||
|
|
@ -173,8 +164,6 @@ fn test_generate_tg_nonce_enc_dec_material_is_consistent() {
|
|||
let (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv) = generate_tg_nonce(
|
||||
ProtoTag::Secure,
|
||||
7,
|
||||
&client_dec_key,
|
||||
client_dec_iv,
|
||||
&client_enc_key,
|
||||
client_enc_iv,
|
||||
&rng,
|
||||
|
|
@ -209,8 +198,6 @@ fn test_generate_tg_nonce_enc_dec_material_is_consistent() {
|
|||
|
||||
#[test]
|
||||
fn test_generate_tg_nonce_fast_mode_embeds_reversed_client_enc_material() {
|
||||
let client_dec_key = [0x22u8; 32];
|
||||
let client_dec_iv = 0x0102030405060708090a0b0c0d0e0f10u128;
|
||||
let client_enc_key = [0xABu8; 32];
|
||||
let client_enc_iv = 0x11223344556677889900aabbccddeeffu128;
|
||||
let rng = SecureRandom::new();
|
||||
|
|
@ -218,8 +205,6 @@ fn test_generate_tg_nonce_fast_mode_embeds_reversed_client_enc_material() {
|
|||
let (nonce, _, _, _, _) = generate_tg_nonce(
|
||||
ProtoTag::Secure,
|
||||
9,
|
||||
&client_dec_key,
|
||||
client_dec_iv,
|
||||
&client_enc_key,
|
||||
client_enc_iv,
|
||||
&rng,
|
||||
|
|
@ -236,8 +221,6 @@ fn test_generate_tg_nonce_fast_mode_embeds_reversed_client_enc_material() {
|
|||
|
||||
#[test]
|
||||
fn test_encrypt_tg_nonce_with_ciphers_matches_manual_suffix_encryption() {
|
||||
let client_dec_key = [0x42u8; 32];
|
||||
let client_dec_iv = 12345u128;
|
||||
let client_enc_key = [0x24u8; 32];
|
||||
let client_enc_iv = 54321u128;
|
||||
|
||||
|
|
@ -245,8 +228,6 @@ fn test_encrypt_tg_nonce_with_ciphers_matches_manual_suffix_encryption() {
|
|||
let (nonce, _, _, _, _) = generate_tg_nonce(
|
||||
ProtoTag::Secure,
|
||||
2,
|
||||
&client_dec_key,
|
||||
client_dec_iv,
|
||||
&client_enc_key,
|
||||
client_enc_iv,
|
||||
&rng,
|
||||
|
|
@ -386,6 +367,7 @@ async fn invalid_tls_probe_does_not_pollute_replay_cache() {
|
|||
|
||||
#[tokio::test]
|
||||
async fn empty_decoded_secret_is_rejected() {
|
||||
clear_warned_secrets_for_testing();
|
||||
let config = test_config_with_secret_hex("");
|
||||
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
|
||||
let rng = SecureRandom::new();
|
||||
|
|
@ -409,6 +391,7 @@ async fn empty_decoded_secret_is_rejected() {
|
|||
|
||||
#[tokio::test]
|
||||
async fn wrong_length_decoded_secret_is_rejected() {
|
||||
clear_warned_secrets_for_testing();
|
||||
let config = test_config_with_secret_hex("aa");
|
||||
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
|
||||
let rng = SecureRandom::new();
|
||||
|
|
@ -458,6 +441,8 @@ async fn invalid_mtproto_probe_does_not_pollute_replay_cache() {
|
|||
|
||||
#[tokio::test]
|
||||
async fn mixed_secret_lengths_keep_valid_user_authenticating() {
|
||||
clear_warned_secrets_for_testing();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let good_secret = [0x22u8; 16];
|
||||
let mut config = ProxyConfig::default();
|
||||
config.access.users.clear();
|
||||
|
|
@ -582,6 +567,7 @@ async fn malformed_tls_classes_complete_within_bounded_time() {
|
|||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "timing-sensitive; run manually on low-jitter hosts"]
|
||||
async fn malformed_tls_classes_share_close_latency_buckets() {
|
||||
const ITER: usize = 24;
|
||||
const BUCKET_MS: u128 = 10;
|
||||
|
|
@ -680,3 +666,114 @@ fn secure_tag_requires_secure_mode_on_direct_transport() {
|
|||
"Secure tag without TLS must be accepted when secure mode is enabled"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
fn invalid_secret_warning_keys_do_not_collide_on_colon_boundaries() {
    clear_warned_secrets_for_testing();

    // Both calls would collapse to the same "a:b:c" string key; the tuple
    // key must keep them distinct.
    warn_invalid_secret_once("a:b", "c", ACCESS_SECRET_BYTES, Some(1));
    warn_invalid_secret_once("a", "b:c", ACCESS_SECRET_BYTES, Some(2));

    let warned = INVALID_SECRET_WARNED
        .get()
        .expect("warned set must be initialized");
    let entries = warned.lock().expect("warned set lock must be available");
    assert_eq!(
        entries.len(),
        2,
        "(name, reason) pairs that stringify to the same colon-joined key must remain distinct"
    );
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn repeated_invalid_tls_probes_trigger_pre_auth_throttle() {
|
||||
let _guard = auth_probe_test_lock()
|
||||
.lock()
|
||||
.expect("auth probe test lock must be available");
|
||||
clear_auth_probe_state_for_testing();
|
||||
|
||||
let config = test_config_with_secret_hex("11111111111111111111111111111111");
|
||||
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
|
||||
let rng = SecureRandom::new();
|
||||
let peer: SocketAddr = "127.0.0.1:44361".parse().unwrap();
|
||||
|
||||
let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32];
|
||||
invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32;
|
||||
|
||||
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
|
||||
let result = handle_tls_handshake(
|
||||
&invalid,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
peer,
|
||||
&config,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
assert!(matches!(result, HandshakeResult::BadClient { .. }));
|
||||
}
|
||||
|
||||
assert!(
|
||||
auth_probe_is_throttled_for_testing(peer.ip()),
|
||||
"invalid probe burst must activate per-IP pre-auth throttle"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn successful_tls_handshake_clears_pre_auth_failure_streak() {
|
||||
let _guard = auth_probe_test_lock()
|
||||
.lock()
|
||||
.expect("auth probe test lock must be available");
|
||||
clear_auth_probe_state_for_testing();
|
||||
|
||||
let secret = [0x23u8; 16];
|
||||
let config = test_config_with_secret_hex("23232323232323232323232323232323");
|
||||
let replay_checker = ReplayChecker::new(256, Duration::from_secs(60));
|
||||
let rng = SecureRandom::new();
|
||||
let peer: SocketAddr = "127.0.0.1:44362".parse().unwrap();
|
||||
|
||||
let mut invalid = vec![0x42u8; tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + 32];
|
||||
invalid[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = 32;
|
||||
|
||||
for expected in 1..AUTH_PROBE_BACKOFF_START_FAILS {
|
||||
let result = handle_tls_handshake(
|
||||
&invalid,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
peer,
|
||||
&config,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
assert!(matches!(result, HandshakeResult::BadClient { .. }));
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing(peer.ip()),
|
||||
Some(expected),
|
||||
"failure streak must grow before a successful authentication"
|
||||
);
|
||||
}
|
||||
|
||||
let valid = make_valid_tls_handshake(&secret, 0);
|
||||
let success = handle_tls_handshake(
|
||||
&valid,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
peer,
|
||||
&config,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(success, HandshakeResult::Success(_)));
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing(peer.ip()),
|
||||
None,
|
||||
"successful authentication must clear accumulated pre-auth failures"
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -232,6 +232,9 @@ where
|
|||
if mask_write.write_all(initial_data).await.is_err() {
|
||||
return;
|
||||
}
|
||||
if mask_write.flush().await.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut client_buf = vec![0u8; MASK_BUFFER_SIZE];
|
||||
let mut mask_buf = vec![0u8; MASK_BUFFER_SIZE];
|
||||
|
|
|
|||
|
|
@ -122,6 +122,12 @@ fn detect_client_type_covers_ssh_port_scanner_and_unknown() {
|
|||
assert_eq!(detect_client_type(b"random-binary-payload"), "unknown");
|
||||
}
|
||||
|
||||
#[test]
fn detect_client_type_len_boundary_9_vs_10_bytes() {
    // A 9-byte payload still classifies as "port-scanner"; at 10 bytes the
    // classifier falls through to "unknown".
    assert_eq!(detect_client_type(b"123456789"), "port-scanner");
    assert_eq!(detect_client_type(b"1234567890"), "unknown");
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn beobachten_records_scanner_class_when_mask_is_disabled() {
|
||||
let mut config = ProxyConfig::default();
|
||||
|
|
|
|||
|
|
@ -1,12 +1,14 @@
|
|||
use std::collections::HashMap;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::time::{Duration, Instant};
|
||||
#[cfg(test)]
|
||||
use std::sync::Mutex;
|
||||
|
||||
use bytes::Bytes;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use dashmap::DashMap;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::sync::{mpsc, oneshot, watch};
|
||||
use tracing::{debug, trace, warn};
|
||||
|
|
@ -30,13 +32,15 @@ enum C2MeCommand {
|
|||
}
|
||||
|
||||
const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60);
|
||||
const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536;
|
||||
const DESYNC_DEDUP_PRUNE_SCAN_LIMIT: usize = 1024;
|
||||
const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
|
||||
const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
|
||||
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
|
||||
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
|
||||
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
|
||||
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
|
||||
static DESYNC_DEDUP: OnceLock<Mutex<HashMap<u64, Instant>>> = OnceLock::new();
|
||||
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
|
||||
|
||||
struct RelayForensicsState {
|
||||
trace_id: u64,
|
||||
|
|
@ -90,24 +94,46 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
|
|||
return true;
|
||||
}
|
||||
|
||||
let dedup = DESYNC_DEDUP.get_or_init(|| Mutex::new(HashMap::new()));
|
||||
let mut guard = dedup.lock().expect("desync dedup mutex poisoned");
|
||||
guard.retain(|_, seen_at| now.duration_since(*seen_at) < DESYNC_DEDUP_WINDOW);
|
||||
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
|
||||
|
||||
match guard.get_mut(&key) {
|
||||
Some(seen_at) => {
|
||||
if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW {
|
||||
*seen_at = now;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
if let Some(mut seen_at) = dedup.get_mut(&key) {
|
||||
if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW {
|
||||
*seen_at = now;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
|
||||
let mut stale_keys = Vec::new();
|
||||
for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) {
|
||||
if now.duration_since(*entry.value()) >= DESYNC_DEDUP_WINDOW {
|
||||
stale_keys.push(*entry.key());
|
||||
}
|
||||
}
|
||||
None => {
|
||||
guard.insert(key, now);
|
||||
true
|
||||
for stale_key in stale_keys {
|
||||
dedup.remove(&stale_key);
|
||||
}
|
||||
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
dedup.insert(key, now);
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
fn clear_desync_dedup_for_testing() {
    // Reset the global desync dedup cache so tests start from a clean slate.
    if let Some(cache) = DESYNC_DEDUP.get() {
        cache.clear();
    }
}
|
||||
|
||||
#[cfg(test)]
fn desync_dedup_test_lock() -> &'static Mutex<()> {
    // Serializes tests that mutate the global DESYNC_DEDUP cache.
    static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    TEST_LOCK.get_or_init(Mutex::default)
}
|
||||
|
||||
fn report_desync_frame_too_large(
|
||||
|
|
@ -229,7 +255,7 @@ pub(crate) async fn handle_via_middle_proxy<R, W>(
|
|||
me_pool: Arc<MePool>,
|
||||
stats: Arc<Stats>,
|
||||
config: Arc<ProxyConfig>,
|
||||
_buffer_pool: Arc<BufferPool>,
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
local_addr: SocketAddr,
|
||||
rng: Arc<SecureRandom>,
|
||||
mut route_rx: watch::Receiver<RouteCutoverState>,
|
||||
|
|
@ -271,7 +297,6 @@ where
|
|||
};
|
||||
|
||||
stats.increment_user_connects(&user);
|
||||
stats.increment_user_curr_connects(&user);
|
||||
stats.increment_current_connections_me();
|
||||
|
||||
if let Some(cutover) = affected_cutover_state(
|
||||
|
|
@ -291,7 +316,6 @@ where
|
|||
let _ = me_pool.send_close(conn_id).await;
|
||||
me_pool.registry().unregister(conn_id).await;
|
||||
stats.decrement_current_connections_me();
|
||||
stats.decrement_user_curr_connects(&user);
|
||||
return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string()));
|
||||
}
|
||||
|
||||
|
|
@ -557,6 +581,7 @@ where
|
|||
&mut crypto_reader,
|
||||
proto_tag,
|
||||
frame_limit,
|
||||
&buffer_pool,
|
||||
&forensics,
|
||||
&mut frame_counter,
|
||||
&stats,
|
||||
|
|
@ -638,7 +663,6 @@ where
|
|||
);
|
||||
me_pool.registry().unregister(conn_id).await;
|
||||
stats.decrement_current_connections_me();
|
||||
stats.decrement_user_curr_connects(&user);
|
||||
result
|
||||
}
|
||||
|
||||
|
|
@ -646,6 +670,7 @@ async fn read_client_payload<R>(
|
|||
client_reader: &mut CryptoReader<R>,
|
||||
proto_tag: ProtoTag,
|
||||
max_frame: usize,
|
||||
buffer_pool: &Arc<BufferPool>,
|
||||
forensics: &RelayForensicsState,
|
||||
frame_counter: &mut u64,
|
||||
stats: &Stats,
|
||||
|
|
@ -737,18 +762,27 @@ where
|
|||
len
|
||||
};
|
||||
|
||||
let mut payload = vec![0u8; len];
|
||||
client_reader
|
||||
.read_exact(&mut payload)
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
let chunk_cap = buffer_pool.buffer_size().max(1024);
|
||||
let mut payload = BytesMut::with_capacity(len.min(chunk_cap));
|
||||
let mut remaining = len;
|
||||
while remaining > 0 {
|
||||
let chunk_len = remaining.min(chunk_cap);
|
||||
let mut chunk = buffer_pool.get();
|
||||
chunk.resize(chunk_len, 0);
|
||||
client_reader
|
||||
.read_exact(&mut chunk[..chunk_len])
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
payload.extend_from_slice(&chunk[..chunk_len]);
|
||||
remaining -= chunk_len;
|
||||
}
|
||||
|
||||
// Secure Intermediate: strip validated trailing padding bytes.
|
||||
if proto_tag == ProtoTag::Secure {
|
||||
payload.truncate(secure_payload_len);
|
||||
}
|
||||
*frame_counter += 1;
|
||||
return Ok(Some((Bytes::from(payload), quickack)));
|
||||
return Ok(Some((payload.freeze(), quickack)));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -940,82 +974,5 @@ where
|
|||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::time::{Duration as TokioDuration, timeout};
|
||||
|
||||
#[test]
|
||||
fn should_yield_sender_only_on_budget_with_backlog() {
|
||||
assert!(!should_yield_c2me_sender(0, true));
|
||||
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true));
|
||||
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false));
|
||||
assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enqueue_c2me_command_uses_try_send_fast_path() {
|
||||
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(2);
|
||||
enqueue_c2me_command(
|
||||
&tx,
|
||||
C2MeCommand::Data {
|
||||
payload: Bytes::from_static(&[1, 2, 3]),
|
||||
flags: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let recv = timeout(TokioDuration::from_millis(50), rx.recv())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
match recv {
|
||||
C2MeCommand::Data { payload, flags } => {
|
||||
assert_eq!(payload.as_ref(), &[1, 2, 3]);
|
||||
assert_eq!(flags, 0);
|
||||
}
|
||||
C2MeCommand::Close => panic!("unexpected close command"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
|
||||
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
|
||||
tx.send(C2MeCommand::Data {
|
||||
payload: Bytes::from_static(&[9]),
|
||||
flags: 9,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tx2 = tx.clone();
|
||||
let producer = tokio::spawn(async move {
|
||||
enqueue_c2me_command(
|
||||
&tx2,
|
||||
C2MeCommand::Data {
|
||||
payload: Bytes::from_static(&[7, 7]),
|
||||
flags: 7,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let _ = timeout(TokioDuration::from_millis(100), rx.recv())
|
||||
.await
|
||||
.unwrap();
|
||||
producer.await.unwrap();
|
||||
|
||||
let recv = timeout(TokioDuration::from_millis(100), rx.recv())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
match recv {
|
||||
C2MeCommand::Data { payload, flags } => {
|
||||
assert_eq!(payload.as_ref(), &[7, 7]);
|
||||
assert_eq!(flags, 7);
|
||||
}
|
||||
C2MeCommand::Close => panic!("unexpected close command"),
|
||||
}
|
||||
}
|
||||
}
|
||||
#[path = "middle_relay_security_tests.rs"]
|
||||
mod security_tests;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,103 @@
|
|||
use super::*;
|
||||
use tokio::time::{Duration as TokioDuration, timeout};
|
||||
|
||||
#[test]
fn should_yield_sender_only_on_budget_with_backlog() {
    // Yield only when the fairness budget is exhausted AND work is backlogged.
    assert!(!should_yield_c2me_sender(0, true));
    assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true));
    assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false));
    assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true));
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enqueue_c2me_command_uses_try_send_fast_path() {
|
||||
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(2);
|
||||
enqueue_c2me_command(
|
||||
&tx,
|
||||
C2MeCommand::Data {
|
||||
payload: Bytes::from_static(&[1, 2, 3]),
|
||||
flags: 0,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let recv = timeout(TokioDuration::from_millis(50), rx.recv())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
match recv {
|
||||
C2MeCommand::Data { payload, flags } => {
|
||||
assert_eq!(payload.as_ref(), &[1, 2, 3]);
|
||||
assert_eq!(flags, 0);
|
||||
}
|
||||
C2MeCommand::Close => panic!("unexpected close command"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
|
||||
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
|
||||
tx.send(C2MeCommand::Data {
|
||||
payload: Bytes::from_static(&[9]),
|
||||
flags: 9,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tx2 = tx.clone();
|
||||
let producer = tokio::spawn(async move {
|
||||
enqueue_c2me_command(
|
||||
&tx2,
|
||||
C2MeCommand::Data {
|
||||
payload: Bytes::from_static(&[7, 7]),
|
||||
flags: 7,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let _ = timeout(TokioDuration::from_millis(100), rx.recv())
|
||||
.await
|
||||
.unwrap();
|
||||
producer.await.unwrap();
|
||||
|
||||
let recv = timeout(TokioDuration::from_millis(100), rx.recv())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
match recv {
|
||||
C2MeCommand::Data { payload, flags } => {
|
||||
assert_eq!(payload.as_ref(), &[7, 7]);
|
||||
assert_eq!(flags, 7);
|
||||
}
|
||||
C2MeCommand::Close => panic!("unexpected close command"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
fn desync_dedup_cache_is_bounded() {
    let _serial = desync_dedup_test_lock()
        .lock()
        .expect("desync dedup test lock must be available");
    clear_desync_dedup_for_testing();

    let now = Instant::now();
    // Fill the cache to capacity with distinct keys; each must be accepted.
    for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
        assert!(
            should_emit_full_desync(key, false, now),
            "unique keys up to cap must be tracked"
        );
    }

    // A brand-new key beyond the cap must be dropped to bound memory.
    assert!(
        !should_emit_full_desync(u64::MAX, false, now),
        "new key above cap must be suppressed to bound memory"
    );
    // A key already inside the window stays deduplicated regardless.
    assert!(
        !should_emit_full_desync(7, false, now),
        "already tracked key inside dedup window must stay suppressed"
    );
}
|
||||
|
|
@ -1256,6 +1256,33 @@ impl Stats {
|
|||
Self::touch_user_stats(stats.value());
|
||||
stats.curr_connects.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn try_acquire_user_curr_connects(&self, user: &str, limit: Option<u64>) -> bool {
|
||||
if !self.telemetry_user_enabled() {
|
||||
return true;
|
||||
}
|
||||
|
||||
self.maybe_cleanup_user_stats();
|
||||
let stats = self.user_stats.entry(user.to_string()).or_default();
|
||||
Self::touch_user_stats(stats.value());
|
||||
|
||||
let counter = &stats.curr_connects;
|
||||
let mut current = counter.load(Ordering::Relaxed);
|
||||
loop {
|
||||
if let Some(max) = limit && current >= max {
|
||||
return false;
|
||||
}
|
||||
match counter.compare_exchange_weak(
|
||||
current,
|
||||
current.saturating_add(1),
|
||||
Ordering::Relaxed,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
Ok(_) => return true,
|
||||
Err(actual) => current = actual,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decrement_user_curr_connects(&self, user: &str) {
|
||||
self.maybe_cleanup_user_stats();
|
||||
|
|
|
|||
|
|
@ -513,6 +513,7 @@ impl FrameCodecTrait for SecureCodec {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashSet;
|
||||
use tokio_util::codec::{FramedRead, FramedWrite};
|
||||
use tokio::io::duplex;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
|
|
@ -630,4 +631,31 @@ mod tests {
|
|||
let result = codec.decode(&mut buf);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
fn secure_codec_always_adds_padding_and_jitters_wire_length() {
    let codec = SecureCodec::new(Arc::new(SecureRandom::new()));
    let payload = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
    let mut observed_lens = HashSet::new();

    for _ in 0..64 {
        let mut encoded = BytesMut::new();
        codec.encode(&Frame::new(payload.clone()), &mut encoded).unwrap();

        // 4-byte length prefix + payload + at least one padding byte.
        assert!(encoded.len() >= 4 + payload.len() + 1);
        let wire_len =
            u32::from_le_bytes([encoded[0], encoded[1], encoded[2], encoded[3]]) as usize;
        assert!(
            (payload.len() + 1..=payload.len() + 3).contains(&wire_len),
            "Secure wire length must be payload+1..3, got {wire_len}"
        );
        assert_ne!(wire_len % 4, 0, "Secure wire length must be non-4-aligned");
        observed_lens.insert(wire_len);
    }

    assert!(
        observed_lens.len() >= 2,
        "Secure padding should create observable wire-length jitter"
    );
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue