mirror of
https://github.com/telemt/telemt.git
synced 2026-06-21 02:11:11 +03:00
Merge branch 'flow' into feature/metrics-build-info
This commit is contained in:
@@ -81,10 +81,21 @@ pub(super) struct ZeroCoreData {
|
||||
pub(super) connections_total: u64,
|
||||
pub(super) connections_bad_total: u64,
|
||||
pub(super) handshake_timeouts_total: u64,
|
||||
pub(super) accept_permit_timeout_total: u64,
|
||||
pub(super) configured_users: usize,
|
||||
pub(super) telemetry_core_enabled: bool,
|
||||
pub(super) telemetry_user_enabled: bool,
|
||||
pub(super) telemetry_me_level: String,
|
||||
pub(super) conntrack_control_enabled: bool,
|
||||
pub(super) conntrack_control_available: bool,
|
||||
pub(super) conntrack_pressure_active: bool,
|
||||
pub(super) conntrack_event_queue_depth: u64,
|
||||
pub(super) conntrack_rule_apply_ok: bool,
|
||||
pub(super) conntrack_delete_attempt_total: u64,
|
||||
pub(super) conntrack_delete_success_total: u64,
|
||||
pub(super) conntrack_delete_not_found_total: u64,
|
||||
pub(super) conntrack_delete_error_total: u64,
|
||||
pub(super) conntrack_close_event_drop_total: u64,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
|
||||
@@ -39,10 +39,21 @@ pub(super) fn build_zero_all_data(stats: &Stats, configured_users: usize) -> Zer
|
||||
connections_total: stats.get_connects_all(),
|
||||
connections_bad_total: stats.get_connects_bad(),
|
||||
handshake_timeouts_total: stats.get_handshake_timeouts(),
|
||||
accept_permit_timeout_total: stats.get_accept_permit_timeout_total(),
|
||||
configured_users,
|
||||
telemetry_core_enabled: telemetry.core_enabled,
|
||||
telemetry_user_enabled: telemetry.user_enabled,
|
||||
telemetry_me_level: telemetry.me_level.to_string(),
|
||||
conntrack_control_enabled: stats.get_conntrack_control_enabled(),
|
||||
conntrack_control_available: stats.get_conntrack_control_available(),
|
||||
conntrack_pressure_active: stats.get_conntrack_pressure_active(),
|
||||
conntrack_event_queue_depth: stats.get_conntrack_event_queue_depth(),
|
||||
conntrack_rule_apply_ok: stats.get_conntrack_rule_apply_ok(),
|
||||
conntrack_delete_attempt_total: stats.get_conntrack_delete_attempt_total(),
|
||||
conntrack_delete_success_total: stats.get_conntrack_delete_success_total(),
|
||||
conntrack_delete_not_found_total: stats.get_conntrack_delete_not_found_total(),
|
||||
conntrack_delete_error_total: stats.get_conntrack_delete_error_total(),
|
||||
conntrack_close_event_drop_total: stats.get_conntrack_close_event_drop_total(),
|
||||
},
|
||||
upstream: build_zero_upstream_data(stats),
|
||||
middle_proxy: ZeroMiddleProxyData {
|
||||
|
||||
+23
-3
@@ -48,6 +48,10 @@ const DEFAULT_ME_POOL_DRAIN_SOFT_EVICT_BUDGET_PER_CORE: u16 = 16;
|
||||
const DEFAULT_ME_POOL_DRAIN_SOFT_EVICT_COOLDOWN_MS: u64 = 1000;
|
||||
const DEFAULT_USER_MAX_UNIQUE_IPS_WINDOW_SECS: u64 = 30;
|
||||
const DEFAULT_ACCEPT_PERMIT_TIMEOUT_MS: u64 = 250;
|
||||
const DEFAULT_CONNTRACK_CONTROL_ENABLED: bool = true;
|
||||
const DEFAULT_CONNTRACK_PRESSURE_HIGH_WATERMARK_PCT: u8 = 85;
|
||||
const DEFAULT_CONNTRACK_PRESSURE_LOW_WATERMARK_PCT: u8 = 70;
|
||||
const DEFAULT_CONNTRACK_DELETE_BUDGET_PER_SEC: u64 = 4096;
|
||||
const DEFAULT_UPSTREAM_CONNECT_RETRY_ATTEMPTS: u32 = 2;
|
||||
const DEFAULT_UPSTREAM_UNHEALTHY_FAIL_THRESHOLD: u32 = 5;
|
||||
const DEFAULT_UPSTREAM_CONNECT_BUDGET_MS: u64 = 3000;
|
||||
@@ -96,7 +100,7 @@ pub(crate) fn default_fake_cert_len() -> usize {
|
||||
}
|
||||
|
||||
pub(crate) fn default_tls_front_dir() -> String {
|
||||
"tlsfront".to_string()
|
||||
"/etc/telemt/tlsfront".to_string()
|
||||
}
|
||||
|
||||
pub(crate) fn default_replay_check_len() -> usize {
|
||||
@@ -221,6 +225,22 @@ pub(crate) fn default_accept_permit_timeout_ms() -> u64 {
|
||||
DEFAULT_ACCEPT_PERMIT_TIMEOUT_MS
|
||||
}
|
||||
|
||||
pub(crate) fn default_conntrack_control_enabled() -> bool {
|
||||
DEFAULT_CONNTRACK_CONTROL_ENABLED
|
||||
}
|
||||
|
||||
pub(crate) fn default_conntrack_pressure_high_watermark_pct() -> u8 {
|
||||
DEFAULT_CONNTRACK_PRESSURE_HIGH_WATERMARK_PCT
|
||||
}
|
||||
|
||||
pub(crate) fn default_conntrack_pressure_low_watermark_pct() -> u8 {
|
||||
DEFAULT_CONNTRACK_PRESSURE_LOW_WATERMARK_PCT
|
||||
}
|
||||
|
||||
pub(crate) fn default_conntrack_delete_budget_per_sec() -> u64 {
|
||||
DEFAULT_CONNTRACK_DELETE_BUDGET_PER_SEC
|
||||
}
|
||||
|
||||
pub(crate) fn default_prefer_4() -> u8 {
|
||||
4
|
||||
}
|
||||
@@ -282,7 +302,7 @@ pub(crate) fn default_me2dc_fallback() -> bool {
|
||||
}
|
||||
|
||||
pub(crate) fn default_me2dc_fast() -> bool {
|
||||
false
|
||||
true
|
||||
}
|
||||
|
||||
pub(crate) fn default_keepalive_interval() -> u64 {
|
||||
@@ -538,7 +558,7 @@ pub(crate) fn default_beobachten_flush_secs() -> u64 {
|
||||
}
|
||||
|
||||
pub(crate) fn default_beobachten_file() -> String {
|
||||
"cache/beobachten.txt".to_string()
|
||||
"/etc/telemt/beobachten.txt".to_string()
|
||||
}
|
||||
|
||||
pub(crate) fn default_tls_new_session_tickets() -> u8 {
|
||||
|
||||
@@ -540,6 +540,10 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
|
||||
cfg.access.user_max_unique_ips_mode = new.access.user_max_unique_ips_mode;
|
||||
cfg.access.user_max_unique_ips_window_secs = new.access.user_max_unique_ips_window_secs;
|
||||
|
||||
if cfg.rebuild_runtime_user_auth().is_err() {
|
||||
cfg.runtime_user_auth = None;
|
||||
}
|
||||
|
||||
cfg
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ use std::collections::{BTreeSet, HashMap, HashSet};
|
||||
use std::hash::{DefaultHasher, Hash, Hasher};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use rand::RngExt;
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -15,6 +16,13 @@ use crate::error::{ProxyError, Result};
|
||||
use super::defaults::*;
|
||||
use super::types::*;
|
||||
|
||||
const ACCESS_SECRET_BYTES: usize = 16;
|
||||
const MAX_ME_WRITER_CMD_CHANNEL_CAPACITY: usize = 16_384;
|
||||
const MAX_ME_ROUTE_CHANNEL_CAPACITY: usize = 8_192;
|
||||
const MAX_ME_C2ME_CHANNEL_CAPACITY: usize = 8_192;
|
||||
const MIN_MAX_CLIENT_FRAME_BYTES: usize = 4 * 1024;
|
||||
const MAX_MAX_CLIENT_FRAME_BYTES: usize = 16 * 1024 * 1024;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct LoadedConfig {
|
||||
pub(crate) config: ProxyConfig,
|
||||
@@ -22,6 +30,111 @@ pub(crate) struct LoadedConfig {
|
||||
pub(crate) rendered_hash: u64,
|
||||
}
|
||||
|
||||
/// Precomputed, immutable user authentication data used by handshake hot paths.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub(crate) struct UserAuthSnapshot {
|
||||
entries: Vec<UserAuthEntry>,
|
||||
by_name: HashMap<String, u32>,
|
||||
sni_index: HashMap<u64, Vec<u32>>,
|
||||
sni_initial_index: HashMap<u8, Vec<u32>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct UserAuthEntry {
|
||||
pub(crate) user: String,
|
||||
pub(crate) secret: [u8; ACCESS_SECRET_BYTES],
|
||||
}
|
||||
|
||||
impl UserAuthSnapshot {
|
||||
fn from_users(users: &HashMap<String, String>) -> Result<Self> {
|
||||
let mut entries = Vec::with_capacity(users.len());
|
||||
let mut by_name = HashMap::with_capacity(users.len());
|
||||
let mut sni_index = HashMap::with_capacity(users.len());
|
||||
let mut sni_initial_index = HashMap::with_capacity(users.len());
|
||||
|
||||
for (user, secret_hex) in users {
|
||||
let decoded = hex::decode(secret_hex).map_err(|_| ProxyError::InvalidSecret {
|
||||
user: user.clone(),
|
||||
reason: "Must be 32 hex characters".to_string(),
|
||||
})?;
|
||||
if decoded.len() != ACCESS_SECRET_BYTES {
|
||||
return Err(ProxyError::InvalidSecret {
|
||||
user: user.clone(),
|
||||
reason: "Must be 32 hex characters".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
let user_id = u32::try_from(entries.len()).map_err(|_| {
|
||||
ProxyError::Config("Too many users for runtime auth snapshot".to_string())
|
||||
})?;
|
||||
|
||||
let mut secret = [0u8; ACCESS_SECRET_BYTES];
|
||||
secret.copy_from_slice(&decoded);
|
||||
entries.push(UserAuthEntry {
|
||||
user: user.clone(),
|
||||
secret,
|
||||
});
|
||||
by_name.insert(user.clone(), user_id);
|
||||
sni_index
|
||||
.entry(Self::sni_lookup_hash(user))
|
||||
.or_insert_with(Vec::new)
|
||||
.push(user_id);
|
||||
if let Some(initial) = user
|
||||
.as_bytes()
|
||||
.first()
|
||||
.map(|byte| byte.to_ascii_lowercase())
|
||||
{
|
||||
sni_initial_index
|
||||
.entry(initial)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(user_id);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
entries,
|
||||
by_name,
|
||||
sni_index,
|
||||
sni_initial_index,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn entries(&self) -> &[UserAuthEntry] {
|
||||
&self.entries
|
||||
}
|
||||
|
||||
pub(crate) fn user_id_by_name(&self, user: &str) -> Option<u32> {
|
||||
self.by_name.get(user).copied()
|
||||
}
|
||||
|
||||
pub(crate) fn entry_by_id(&self, user_id: u32) -> Option<&UserAuthEntry> {
|
||||
let idx = usize::try_from(user_id).ok()?;
|
||||
self.entries.get(idx)
|
||||
}
|
||||
|
||||
pub(crate) fn sni_candidates(&self, sni: &str) -> Option<&[u32]> {
|
||||
self.sni_index
|
||||
.get(&Self::sni_lookup_hash(sni))
|
||||
.map(Vec::as_slice)
|
||||
}
|
||||
|
||||
pub(crate) fn sni_initial_candidates(&self, sni: &str) -> Option<&[u32]> {
|
||||
let initial = sni
|
||||
.as_bytes()
|
||||
.first()
|
||||
.map(|byte| byte.to_ascii_lowercase())?;
|
||||
self.sni_initial_index.get(&initial).map(Vec::as_slice)
|
||||
}
|
||||
|
||||
fn sni_lookup_hash(value: &str) -> u64 {
|
||||
let mut hasher = DefaultHasher::new();
|
||||
for byte in value.bytes() {
|
||||
hasher.write_u8(byte.to_ascii_lowercase());
|
||||
}
|
||||
hasher.finish()
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_config_path(path: &Path) -> PathBuf {
|
||||
path.canonicalize().unwrap_or_else(|_| {
|
||||
if path.is_absolute() {
|
||||
@@ -196,6 +309,10 @@ pub struct ProxyConfig {
|
||||
/// If not set, defaults to 2 (matching Telegram's official `default 2;` in proxy-multi.conf).
|
||||
#[serde(default)]
|
||||
pub default_dc: Option<u8>,
|
||||
|
||||
/// Precomputed authentication snapshot for handshake hot paths.
|
||||
#[serde(skip)]
|
||||
pub(crate) runtime_user_auth: Option<Arc<UserAuthSnapshot>>,
|
||||
}
|
||||
|
||||
impl ProxyConfig {
|
||||
@@ -514,18 +631,41 @@ impl ProxyConfig {
|
||||
"general.me_writer_cmd_channel_capacity must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
if config.general.me_writer_cmd_channel_capacity > MAX_ME_WRITER_CMD_CHANNEL_CAPACITY {
|
||||
return Err(ProxyError::Config(format!(
|
||||
"general.me_writer_cmd_channel_capacity must be within [1, {MAX_ME_WRITER_CMD_CHANNEL_CAPACITY}]"
|
||||
)));
|
||||
}
|
||||
|
||||
if config.general.me_route_channel_capacity == 0 {
|
||||
return Err(ProxyError::Config(
|
||||
"general.me_route_channel_capacity must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
if config.general.me_route_channel_capacity > MAX_ME_ROUTE_CHANNEL_CAPACITY {
|
||||
return Err(ProxyError::Config(format!(
|
||||
"general.me_route_channel_capacity must be within [1, {MAX_ME_ROUTE_CHANNEL_CAPACITY}]"
|
||||
)));
|
||||
}
|
||||
|
||||
if config.general.me_c2me_channel_capacity == 0 {
|
||||
return Err(ProxyError::Config(
|
||||
"general.me_c2me_channel_capacity must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
if config.general.me_c2me_channel_capacity > MAX_ME_C2ME_CHANNEL_CAPACITY {
|
||||
return Err(ProxyError::Config(format!(
|
||||
"general.me_c2me_channel_capacity must be within [1, {MAX_ME_C2ME_CHANNEL_CAPACITY}]"
|
||||
)));
|
||||
}
|
||||
|
||||
if !(MIN_MAX_CLIENT_FRAME_BYTES..=MAX_MAX_CLIENT_FRAME_BYTES)
|
||||
.contains(&config.general.max_client_frame)
|
||||
{
|
||||
return Err(ProxyError::Config(format!(
|
||||
"general.max_client_frame must be within [{MIN_MAX_CLIENT_FRAME_BYTES}, {MAX_MAX_CLIENT_FRAME_BYTES}]"
|
||||
)));
|
||||
}
|
||||
|
||||
if config.general.me_c2me_send_timeout_ms > 60_000 {
|
||||
return Err(ProxyError::Config(
|
||||
@@ -922,6 +1062,43 @@ impl ProxyConfig {
|
||||
));
|
||||
}
|
||||
|
||||
if config.server.conntrack_control.pressure_high_watermark_pct == 0
|
||||
|| config.server.conntrack_control.pressure_high_watermark_pct > 100
|
||||
{
|
||||
return Err(ProxyError::Config(
|
||||
"server.conntrack_control.pressure_high_watermark_pct must be within [1, 100]"
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if config.server.conntrack_control.pressure_low_watermark_pct
|
||||
>= config.server.conntrack_control.pressure_high_watermark_pct
|
||||
{
|
||||
return Err(ProxyError::Config(
|
||||
"server.conntrack_control.pressure_low_watermark_pct must be < pressure_high_watermark_pct"
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if config.server.conntrack_control.delete_budget_per_sec == 0 {
|
||||
return Err(ProxyError::Config(
|
||||
"server.conntrack_control.delete_budget_per_sec must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if matches!(config.server.conntrack_control.mode, ConntrackMode::Hybrid)
|
||||
&& config
|
||||
.server
|
||||
.conntrack_control
|
||||
.hybrid_listener_ips
|
||||
.is_empty()
|
||||
{
|
||||
return Err(ProxyError::Config(
|
||||
"server.conntrack_control.hybrid_listener_ips must be non-empty in mode=hybrid"
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if config.general.effective_me_pool_force_close_secs() > 0
|
||||
&& config.general.effective_me_pool_force_close_secs()
|
||||
< config.general.me_pool_drain_ttl_secs
|
||||
@@ -1127,6 +1304,7 @@ impl ProxyConfig {
|
||||
.or_insert_with(|| vec!["91.105.192.100:443".to_string()]);
|
||||
|
||||
validate_upstreams(&config)?;
|
||||
config.rebuild_runtime_user_auth()?;
|
||||
|
||||
Ok(LoadedConfig {
|
||||
config,
|
||||
@@ -1135,6 +1313,16 @@ impl ProxyConfig {
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn rebuild_runtime_user_auth(&mut self) -> Result<()> {
|
||||
let snapshot = UserAuthSnapshot::from_users(&self.access.users)?;
|
||||
self.runtime_user_auth = Some(Arc::new(snapshot));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn runtime_user_auth(&self) -> Option<&UserAuthSnapshot> {
|
||||
self.runtime_user_auth.as_deref()
|
||||
}
|
||||
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
if self.access.users.is_empty() {
|
||||
return Err(ProxyError::Config("No users configured".to_string()));
|
||||
@@ -1186,6 +1374,10 @@ mod load_mask_shape_security_tests;
|
||||
#[path = "tests/load_mask_classifier_prefetch_timeout_security_tests.rs"]
|
||||
mod load_mask_classifier_prefetch_timeout_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/load_memory_envelope_tests.rs"]
|
||||
mod load_memory_envelope_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -1327,6 +1519,31 @@ mod tests {
|
||||
cfg.server.api.runtime_edge_events_capacity,
|
||||
default_api_runtime_edge_events_capacity()
|
||||
);
|
||||
assert_eq!(
|
||||
cfg.server.conntrack_control.inline_conntrack_control,
|
||||
default_conntrack_control_enabled()
|
||||
);
|
||||
assert_eq!(cfg.server.conntrack_control.mode, ConntrackMode::default());
|
||||
assert_eq!(
|
||||
cfg.server.conntrack_control.backend,
|
||||
ConntrackBackend::default()
|
||||
);
|
||||
assert_eq!(
|
||||
cfg.server.conntrack_control.profile,
|
||||
ConntrackPressureProfile::default()
|
||||
);
|
||||
assert_eq!(
|
||||
cfg.server.conntrack_control.pressure_high_watermark_pct,
|
||||
default_conntrack_pressure_high_watermark_pct()
|
||||
);
|
||||
assert_eq!(
|
||||
cfg.server.conntrack_control.pressure_low_watermark_pct,
|
||||
default_conntrack_pressure_low_watermark_pct()
|
||||
);
|
||||
assert_eq!(
|
||||
cfg.server.conntrack_control.delete_budget_per_sec,
|
||||
default_conntrack_delete_budget_per_sec()
|
||||
);
|
||||
assert_eq!(cfg.access.users, default_access_users());
|
||||
assert_eq!(
|
||||
cfg.access.user_max_tcp_conns_global_each,
|
||||
@@ -1472,6 +1689,31 @@ mod tests {
|
||||
server.api.runtime_edge_events_capacity,
|
||||
default_api_runtime_edge_events_capacity()
|
||||
);
|
||||
assert_eq!(
|
||||
server.conntrack_control.inline_conntrack_control,
|
||||
default_conntrack_control_enabled()
|
||||
);
|
||||
assert_eq!(server.conntrack_control.mode, ConntrackMode::default());
|
||||
assert_eq!(
|
||||
server.conntrack_control.backend,
|
||||
ConntrackBackend::default()
|
||||
);
|
||||
assert_eq!(
|
||||
server.conntrack_control.profile,
|
||||
ConntrackPressureProfile::default()
|
||||
);
|
||||
assert_eq!(
|
||||
server.conntrack_control.pressure_high_watermark_pct,
|
||||
default_conntrack_pressure_high_watermark_pct()
|
||||
);
|
||||
assert_eq!(
|
||||
server.conntrack_control.pressure_low_watermark_pct,
|
||||
default_conntrack_pressure_low_watermark_pct()
|
||||
);
|
||||
assert_eq!(
|
||||
server.conntrack_control.delete_budget_per_sec,
|
||||
default_conntrack_delete_budget_per_sec()
|
||||
);
|
||||
|
||||
let access = AccessConfig::default();
|
||||
assert_eq!(access.users, default_access_users());
|
||||
@@ -1548,6 +1790,22 @@ mod tests {
|
||||
cfg_mask.censorship.unknown_sni_action,
|
||||
UnknownSniAction::Mask
|
||||
);
|
||||
|
||||
let cfg_accept: ProxyConfig = toml::from_str(
|
||||
r#"
|
||||
[server]
|
||||
[general]
|
||||
[network]
|
||||
[access]
|
||||
[censorship]
|
||||
unknown_sni_action = "accept"
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
cfg_accept.censorship.unknown_sni_action,
|
||||
UnknownSniAction::Accept
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -2404,6 +2662,118 @@ mod tests {
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conntrack_pressure_high_watermark_out_of_range_is_rejected() {
|
||||
let toml = r#"
|
||||
[server.conntrack_control]
|
||||
pressure_high_watermark_pct = 0
|
||||
|
||||
[censorship]
|
||||
tls_domain = "example.com"
|
||||
|
||||
[access.users]
|
||||
user = "00000000000000000000000000000000"
|
||||
"#;
|
||||
let dir = std::env::temp_dir();
|
||||
let path = dir.join("telemt_conntrack_high_watermark_invalid_test.toml");
|
||||
std::fs::write(&path, toml).unwrap();
|
||||
let err = ProxyConfig::load(&path).unwrap_err().to_string();
|
||||
assert!(err.contains(
|
||||
"server.conntrack_control.pressure_high_watermark_pct must be within [1, 100]"
|
||||
));
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conntrack_pressure_low_watermark_must_be_below_high() {
|
||||
let toml = r#"
|
||||
[server.conntrack_control]
|
||||
pressure_high_watermark_pct = 50
|
||||
pressure_low_watermark_pct = 50
|
||||
|
||||
[censorship]
|
||||
tls_domain = "example.com"
|
||||
|
||||
[access.users]
|
||||
user = "00000000000000000000000000000000"
|
||||
"#;
|
||||
let dir = std::env::temp_dir();
|
||||
let path = dir.join("telemt_conntrack_low_watermark_invalid_test.toml");
|
||||
std::fs::write(&path, toml).unwrap();
|
||||
let err = ProxyConfig::load(&path).unwrap_err().to_string();
|
||||
assert!(
|
||||
err.contains(
|
||||
"server.conntrack_control.pressure_low_watermark_pct must be < pressure_high_watermark_pct"
|
||||
)
|
||||
);
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conntrack_delete_budget_zero_is_rejected() {
|
||||
let toml = r#"
|
||||
[server.conntrack_control]
|
||||
delete_budget_per_sec = 0
|
||||
|
||||
[censorship]
|
||||
tls_domain = "example.com"
|
||||
|
||||
[access.users]
|
||||
user = "00000000000000000000000000000000"
|
||||
"#;
|
||||
let dir = std::env::temp_dir();
|
||||
let path = dir.join("telemt_conntrack_delete_budget_invalid_test.toml");
|
||||
std::fs::write(&path, toml).unwrap();
|
||||
let err = ProxyConfig::load(&path).unwrap_err().to_string();
|
||||
assert!(err.contains("server.conntrack_control.delete_budget_per_sec must be > 0"));
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conntrack_hybrid_mode_requires_listener_allow_list() {
|
||||
let toml = r#"
|
||||
[server.conntrack_control]
|
||||
mode = "hybrid"
|
||||
|
||||
[censorship]
|
||||
tls_domain = "example.com"
|
||||
|
||||
[access.users]
|
||||
user = "00000000000000000000000000000000"
|
||||
"#;
|
||||
let dir = std::env::temp_dir();
|
||||
let path = dir.join("telemt_conntrack_hybrid_requires_ips_test.toml");
|
||||
std::fs::write(&path, toml).unwrap();
|
||||
let err = ProxyConfig::load(&path).unwrap_err().to_string();
|
||||
assert!(err.contains(
|
||||
"server.conntrack_control.hybrid_listener_ips must be non-empty in mode=hybrid"
|
||||
));
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conntrack_profile_is_loaded_from_config() {
|
||||
let toml = r#"
|
||||
[server.conntrack_control]
|
||||
profile = "aggressive"
|
||||
|
||||
[censorship]
|
||||
tls_domain = "example.com"
|
||||
|
||||
[access.users]
|
||||
user = "00000000000000000000000000000000"
|
||||
"#;
|
||||
let dir = std::env::temp_dir();
|
||||
let path = dir.join("telemt_conntrack_profile_parse_test.toml");
|
||||
std::fs::write(&path, toml).unwrap();
|
||||
let cfg = ProxyConfig::load(&path).unwrap();
|
||||
assert_eq!(
|
||||
cfg.server.conntrack_control.profile,
|
||||
ConntrackPressureProfile::Aggressive
|
||||
);
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn force_close_default_matches_drain_ttl() {
|
||||
let toml = r#"
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
use super::*;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
fn write_temp_config(contents: &str) -> PathBuf {
|
||||
let nonce = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.expect("system time must be after unix epoch")
|
||||
.as_nanos();
|
||||
let path = std::env::temp_dir().join(format!("telemt-load-memory-envelope-{nonce}.toml"));
|
||||
fs::write(&path, contents).expect("temp config write must succeed");
|
||||
path
|
||||
}
|
||||
|
||||
fn remove_temp_config(path: &PathBuf) {
|
||||
let _ = fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_rejects_writer_cmd_capacity_above_upper_bound() {
|
||||
let path = write_temp_config(
|
||||
r#"
|
||||
[general]
|
||||
me_writer_cmd_channel_capacity = 16385
|
||||
"#,
|
||||
);
|
||||
|
||||
let err = ProxyConfig::load(&path).expect_err("writer command capacity above hard cap must fail");
|
||||
let msg = err.to_string();
|
||||
assert!(
|
||||
msg.contains("general.me_writer_cmd_channel_capacity must be within [1, 16384]"),
|
||||
"error must explain writer command capacity hard cap, got: {msg}"
|
||||
);
|
||||
|
||||
remove_temp_config(&path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_rejects_route_channel_capacity_above_upper_bound() {
|
||||
let path = write_temp_config(
|
||||
r#"
|
||||
[general]
|
||||
me_route_channel_capacity = 8193
|
||||
"#,
|
||||
);
|
||||
|
||||
let err = ProxyConfig::load(&path).expect_err("route channel capacity above hard cap must fail");
|
||||
let msg = err.to_string();
|
||||
assert!(
|
||||
msg.contains("general.me_route_channel_capacity must be within [1, 8192]"),
|
||||
"error must explain route channel hard cap, got: {msg}"
|
||||
);
|
||||
|
||||
remove_temp_config(&path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_rejects_c2me_channel_capacity_above_upper_bound() {
|
||||
let path = write_temp_config(
|
||||
r#"
|
||||
[general]
|
||||
me_c2me_channel_capacity = 8193
|
||||
"#,
|
||||
);
|
||||
|
||||
let err = ProxyConfig::load(&path).expect_err("c2me channel capacity above hard cap must fail");
|
||||
let msg = err.to_string();
|
||||
assert!(
|
||||
msg.contains("general.me_c2me_channel_capacity must be within [1, 8192]"),
|
||||
"error must explain c2me channel hard cap, got: {msg}"
|
||||
);
|
||||
|
||||
remove_temp_config(&path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_rejects_max_client_frame_above_upper_bound() {
|
||||
let path = write_temp_config(
|
||||
r#"
|
||||
[general]
|
||||
max_client_frame = 16777217
|
||||
"#,
|
||||
);
|
||||
|
||||
let err = ProxyConfig::load(&path).expect_err("max_client_frame above hard cap must fail");
|
||||
let msg = err.to_string();
|
||||
assert!(
|
||||
msg.contains("general.max_client_frame must be within [4096, 16777216]"),
|
||||
"error must explain max_client_frame hard cap, got: {msg}"
|
||||
);
|
||||
|
||||
remove_temp_config(&path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_accepts_memory_limits_at_hard_upper_bounds() {
|
||||
let path = write_temp_config(
|
||||
r#"
|
||||
[general]
|
||||
me_writer_cmd_channel_capacity = 16384
|
||||
me_route_channel_capacity = 8192
|
||||
me_c2me_channel_capacity = 8192
|
||||
max_client_frame = 16777216
|
||||
"#,
|
||||
);
|
||||
|
||||
let cfg = ProxyConfig::load(&path).expect("hard upper bound values must be accepted");
|
||||
assert_eq!(cfg.general.me_writer_cmd_channel_capacity, 16384);
|
||||
assert_eq!(cfg.general.me_route_channel_capacity, 8192);
|
||||
assert_eq!(cfg.general.me_c2me_channel_capacity, 8192);
|
||||
assert_eq!(cfg.general.max_client_frame, 16 * 1024 * 1024);
|
||||
|
||||
remove_temp_config(&path);
|
||||
}
|
||||
@@ -1216,6 +1216,118 @@ impl Default for ApiConfig {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ConntrackMode {
|
||||
#[default]
|
||||
Tracked,
|
||||
Notrack,
|
||||
Hybrid,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ConntrackBackend {
|
||||
#[default]
|
||||
Auto,
|
||||
Nftables,
|
||||
Iptables,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ConntrackPressureProfile {
|
||||
Conservative,
|
||||
#[default]
|
||||
Balanced,
|
||||
Aggressive,
|
||||
}
|
||||
|
||||
impl ConntrackPressureProfile {
|
||||
pub fn client_first_byte_idle_cap_secs(self) -> u64 {
|
||||
match self {
|
||||
Self::Conservative => 30,
|
||||
Self::Balanced => 20,
|
||||
Self::Aggressive => 10,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn direct_activity_timeout_secs(self) -> u64 {
|
||||
match self {
|
||||
Self::Conservative => 180,
|
||||
Self::Balanced => 120,
|
||||
Self::Aggressive => 60,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn middle_soft_idle_cap_secs(self) -> u64 {
|
||||
match self {
|
||||
Self::Conservative => 60,
|
||||
Self::Balanced => 30,
|
||||
Self::Aggressive => 20,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn middle_hard_idle_cap_secs(self) -> u64 {
|
||||
match self {
|
||||
Self::Conservative => 180,
|
||||
Self::Balanced => 90,
|
||||
Self::Aggressive => 60,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ConntrackControlConfig {
|
||||
/// Enables runtime conntrack-control worker for pressure mitigation.
|
||||
#[serde(default = "default_conntrack_control_enabled")]
|
||||
pub inline_conntrack_control: bool,
|
||||
|
||||
/// Conntrack mode for listener ingress traffic.
|
||||
#[serde(default)]
|
||||
pub mode: ConntrackMode,
|
||||
|
||||
/// Netfilter backend used to reconcile notrack rules.
|
||||
#[serde(default)]
|
||||
pub backend: ConntrackBackend,
|
||||
|
||||
/// Pressure profile for timeout caps under resource saturation.
|
||||
#[serde(default)]
|
||||
pub profile: ConntrackPressureProfile,
|
||||
|
||||
/// Listener IP allow-list for hybrid mode.
|
||||
/// Ignored in tracked/notrack mode.
|
||||
#[serde(default)]
|
||||
pub hybrid_listener_ips: Vec<IpAddr>,
|
||||
|
||||
/// Pressure high watermark as percentage.
|
||||
#[serde(default = "default_conntrack_pressure_high_watermark_pct")]
|
||||
pub pressure_high_watermark_pct: u8,
|
||||
|
||||
/// Pressure low watermark as percentage.
|
||||
#[serde(default = "default_conntrack_pressure_low_watermark_pct")]
|
||||
pub pressure_low_watermark_pct: u8,
|
||||
|
||||
/// Maximum conntrack delete operations per second.
|
||||
#[serde(default = "default_conntrack_delete_budget_per_sec")]
|
||||
pub delete_budget_per_sec: u64,
|
||||
}
|
||||
|
||||
impl Default for ConntrackControlConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
inline_conntrack_control: default_conntrack_control_enabled(),
|
||||
mode: ConntrackMode::default(),
|
||||
backend: ConntrackBackend::default(),
|
||||
profile: ConntrackPressureProfile::default(),
|
||||
hybrid_listener_ips: Vec::new(),
|
||||
pressure_high_watermark_pct: default_conntrack_pressure_high_watermark_pct(),
|
||||
pressure_low_watermark_pct: default_conntrack_pressure_low_watermark_pct(),
|
||||
delete_budget_per_sec: default_conntrack_delete_budget_per_sec(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ServerConfig {
|
||||
#[serde(default = "default_port")]
|
||||
@@ -1291,6 +1403,10 @@ pub struct ServerConfig {
|
||||
/// `0` keeps legacy unbounded wait behavior.
|
||||
#[serde(default = "default_accept_permit_timeout_ms")]
|
||||
pub accept_permit_timeout_ms: u64,
|
||||
|
||||
/// Runtime conntrack control and pressure policy.
|
||||
#[serde(default)]
|
||||
pub conntrack_control: ConntrackControlConfig,
|
||||
}
|
||||
|
||||
impl Default for ServerConfig {
|
||||
@@ -1313,6 +1429,7 @@ impl Default for ServerConfig {
|
||||
listen_backlog: default_listen_backlog(),
|
||||
max_connections: default_server_max_connections(),
|
||||
accept_permit_timeout_ms: default_accept_permit_timeout_ms(),
|
||||
conntrack_control: ConntrackControlConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1385,6 +1502,7 @@ pub enum UnknownSniAction {
|
||||
#[default]
|
||||
Drop,
|
||||
Mask,
|
||||
Accept,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
|
||||
@@ -0,0 +1,755 @@
|
||||
use std::collections::BTreeSet;
|
||||
use std::net::IpAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::process::Command;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::config::{ConntrackBackend, ConntrackMode, ProxyConfig};
|
||||
use crate::proxy::middle_relay::note_global_relay_pressure;
|
||||
use crate::proxy::shared_state::{ConntrackCloseEvent, ConntrackCloseReason, ProxySharedState};
|
||||
use crate::stats::Stats;
|
||||
|
||||
const CONNTRACK_EVENT_QUEUE_CAPACITY: usize = 32_768;
|
||||
const PRESSURE_RELEASE_TICKS: u8 = 3;
|
||||
const PRESSURE_SAMPLE_INTERVAL: Duration = Duration::from_secs(1);
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
enum NetfilterBackend {
|
||||
Nftables,
|
||||
Iptables,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
struct PressureSample {
|
||||
conn_pct: Option<u8>,
|
||||
fd_pct: Option<u8>,
|
||||
accept_timeout_delta: u64,
|
||||
me_queue_pressure_delta: u64,
|
||||
}
|
||||
|
||||
struct PressureState {
|
||||
active: bool,
|
||||
low_streak: u8,
|
||||
prev_accept_timeout_total: u64,
|
||||
prev_me_queue_pressure_total: u64,
|
||||
}
|
||||
|
||||
impl PressureState {
|
||||
fn new(stats: &Stats) -> Self {
|
||||
Self {
|
||||
active: false,
|
||||
low_streak: 0,
|
||||
prev_accept_timeout_total: stats.get_accept_permit_timeout_total(),
|
||||
prev_me_queue_pressure_total: stats.get_me_c2me_send_full_total(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn spawn_conntrack_controller(
|
||||
config_rx: watch::Receiver<Arc<ProxyConfig>>,
|
||||
stats: Arc<Stats>,
|
||||
shared: Arc<ProxySharedState>,
|
||||
) {
|
||||
if !cfg!(target_os = "linux") {
|
||||
let enabled = config_rx
|
||||
.borrow()
|
||||
.server
|
||||
.conntrack_control
|
||||
.inline_conntrack_control;
|
||||
stats.set_conntrack_control_enabled(enabled);
|
||||
stats.set_conntrack_control_available(false);
|
||||
stats.set_conntrack_pressure_active(false);
|
||||
stats.set_conntrack_event_queue_depth(0);
|
||||
stats.set_conntrack_rule_apply_ok(false);
|
||||
shared.disable_conntrack_close_sender();
|
||||
shared.set_conntrack_pressure_active(false);
|
||||
if enabled {
|
||||
warn!(
|
||||
"conntrack control is configured but unsupported on this OS; disabling runtime worker"
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let (tx, rx) = mpsc::channel(CONNTRACK_EVENT_QUEUE_CAPACITY);
|
||||
shared.set_conntrack_close_sender(tx);
|
||||
tokio::spawn(async move {
|
||||
run_conntrack_controller(config_rx, stats, shared, rx).await;
|
||||
});
|
||||
}
|
||||
|
||||
async fn run_conntrack_controller(
|
||||
mut config_rx: watch::Receiver<Arc<ProxyConfig>>,
|
||||
stats: Arc<Stats>,
|
||||
shared: Arc<ProxySharedState>,
|
||||
mut close_rx: mpsc::Receiver<ConntrackCloseEvent>,
|
||||
) {
|
||||
let mut cfg = config_rx.borrow().clone();
|
||||
let mut pressure_state = PressureState::new(stats.as_ref());
|
||||
let mut delete_budget_tokens = cfg.server.conntrack_control.delete_budget_per_sec;
|
||||
let mut backend = pick_backend(cfg.server.conntrack_control.backend);
|
||||
|
||||
apply_runtime_state(
|
||||
stats.as_ref(),
|
||||
shared.as_ref(),
|
||||
&cfg,
|
||||
backend.is_some(),
|
||||
false,
|
||||
);
|
||||
reconcile_rules(&cfg, backend, stats.as_ref()).await;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
changed = config_rx.changed() => {
|
||||
if changed.is_err() {
|
||||
break;
|
||||
}
|
||||
cfg = config_rx.borrow_and_update().clone();
|
||||
backend = pick_backend(cfg.server.conntrack_control.backend);
|
||||
delete_budget_tokens = cfg.server.conntrack_control.delete_budget_per_sec;
|
||||
apply_runtime_state(stats.as_ref(), shared.as_ref(), &cfg, backend.is_some(), pressure_state.active);
|
||||
reconcile_rules(&cfg, backend, stats.as_ref()).await;
|
||||
}
|
||||
event = close_rx.recv() => {
|
||||
let Some(event) = event else {
|
||||
break;
|
||||
};
|
||||
stats.set_conntrack_event_queue_depth(close_rx.len() as u64);
|
||||
if !cfg.server.conntrack_control.inline_conntrack_control {
|
||||
continue;
|
||||
}
|
||||
if !pressure_state.active {
|
||||
continue;
|
||||
}
|
||||
if !matches!(event.reason, ConntrackCloseReason::Timeout | ConntrackCloseReason::Pressure | ConntrackCloseReason::Reset) {
|
||||
continue;
|
||||
}
|
||||
if delete_budget_tokens == 0 {
|
||||
continue;
|
||||
}
|
||||
stats.increment_conntrack_delete_attempt_total();
|
||||
match delete_conntrack_entry(event).await {
|
||||
DeleteOutcome::Deleted => {
|
||||
delete_budget_tokens = delete_budget_tokens.saturating_sub(1);
|
||||
stats.increment_conntrack_delete_success_total();
|
||||
}
|
||||
DeleteOutcome::NotFound => {
|
||||
delete_budget_tokens = delete_budget_tokens.saturating_sub(1);
|
||||
stats.increment_conntrack_delete_not_found_total();
|
||||
}
|
||||
DeleteOutcome::Error => {
|
||||
delete_budget_tokens = delete_budget_tokens.saturating_sub(1);
|
||||
stats.increment_conntrack_delete_error_total();
|
||||
}
|
||||
}
|
||||
}
|
||||
_ = tokio::time::sleep(PRESSURE_SAMPLE_INTERVAL) => {
|
||||
delete_budget_tokens = cfg.server.conntrack_control.delete_budget_per_sec;
|
||||
stats.set_conntrack_event_queue_depth(close_rx.len() as u64);
|
||||
let sample = collect_pressure_sample(stats.as_ref(), &cfg, &mut pressure_state);
|
||||
update_pressure_state(
|
||||
stats.as_ref(),
|
||||
shared.as_ref(),
|
||||
&cfg,
|
||||
&sample,
|
||||
&mut pressure_state,
|
||||
);
|
||||
if pressure_state.active {
|
||||
note_global_relay_pressure(shared.as_ref());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
shared.disable_conntrack_close_sender();
|
||||
shared.set_conntrack_pressure_active(false);
|
||||
stats.set_conntrack_pressure_active(false);
|
||||
}
|
||||
|
||||
fn apply_runtime_state(
|
||||
stats: &Stats,
|
||||
shared: &ProxySharedState,
|
||||
cfg: &ProxyConfig,
|
||||
backend_available: bool,
|
||||
pressure_active: bool,
|
||||
) {
|
||||
let enabled = cfg.server.conntrack_control.inline_conntrack_control;
|
||||
let available = enabled && backend_available && has_cap_net_admin();
|
||||
if enabled && !available {
|
||||
warn!(
|
||||
"conntrack control enabled but unavailable (missing CAP_NET_ADMIN or backend binaries)"
|
||||
);
|
||||
}
|
||||
stats.set_conntrack_control_enabled(enabled);
|
||||
stats.set_conntrack_control_available(available);
|
||||
shared.set_conntrack_pressure_active(enabled && pressure_active);
|
||||
stats.set_conntrack_pressure_active(enabled && pressure_active);
|
||||
}
|
||||
|
||||
fn collect_pressure_sample(
|
||||
stats: &Stats,
|
||||
cfg: &ProxyConfig,
|
||||
state: &mut PressureState,
|
||||
) -> PressureSample {
|
||||
let current_connections = stats.get_current_connections_total();
|
||||
let conn_pct = if cfg.server.max_connections == 0 {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
((current_connections.saturating_mul(100)) / u64::from(cfg.server.max_connections))
|
||||
.min(100) as u8,
|
||||
)
|
||||
};
|
||||
|
||||
let fd_pct = fd_usage_pct();
|
||||
|
||||
let accept_total = stats.get_accept_permit_timeout_total();
|
||||
let accept_delta = accept_total.saturating_sub(state.prev_accept_timeout_total);
|
||||
state.prev_accept_timeout_total = accept_total;
|
||||
|
||||
let me_total = stats.get_me_c2me_send_full_total();
|
||||
let me_delta = me_total.saturating_sub(state.prev_me_queue_pressure_total);
|
||||
state.prev_me_queue_pressure_total = me_total;
|
||||
|
||||
PressureSample {
|
||||
conn_pct,
|
||||
fd_pct,
|
||||
accept_timeout_delta: accept_delta,
|
||||
me_queue_pressure_delta: me_delta,
|
||||
}
|
||||
}
|
||||
|
||||
fn update_pressure_state(
|
||||
stats: &Stats,
|
||||
shared: &ProxySharedState,
|
||||
cfg: &ProxyConfig,
|
||||
sample: &PressureSample,
|
||||
state: &mut PressureState,
|
||||
) {
|
||||
if !cfg.server.conntrack_control.inline_conntrack_control {
|
||||
if state.active {
|
||||
state.active = false;
|
||||
state.low_streak = 0;
|
||||
shared.set_conntrack_pressure_active(false);
|
||||
stats.set_conntrack_pressure_active(false);
|
||||
info!("Conntrack pressure mode deactivated (feature disabled)");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let high = cfg.server.conntrack_control.pressure_high_watermark_pct;
|
||||
let low = cfg.server.conntrack_control.pressure_low_watermark_pct;
|
||||
|
||||
let high_hit = sample.conn_pct.is_some_and(|v| v >= high)
|
||||
|| sample.fd_pct.is_some_and(|v| v >= high)
|
||||
|| sample.accept_timeout_delta > 0
|
||||
|| sample.me_queue_pressure_delta > 0;
|
||||
|
||||
let low_clear = sample.conn_pct.is_none_or(|v| v <= low)
|
||||
&& sample.fd_pct.is_none_or(|v| v <= low)
|
||||
&& sample.accept_timeout_delta == 0
|
||||
&& sample.me_queue_pressure_delta == 0;
|
||||
|
||||
if !state.active && high_hit {
|
||||
state.active = true;
|
||||
state.low_streak = 0;
|
||||
shared.set_conntrack_pressure_active(true);
|
||||
stats.set_conntrack_pressure_active(true);
|
||||
info!(
|
||||
conn_pct = ?sample.conn_pct,
|
||||
fd_pct = ?sample.fd_pct,
|
||||
accept_timeout_delta = sample.accept_timeout_delta,
|
||||
me_queue_pressure_delta = sample.me_queue_pressure_delta,
|
||||
"Conntrack pressure mode activated"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if state.active && low_clear {
|
||||
state.low_streak = state.low_streak.saturating_add(1);
|
||||
if state.low_streak >= PRESSURE_RELEASE_TICKS {
|
||||
state.active = false;
|
||||
state.low_streak = 0;
|
||||
shared.set_conntrack_pressure_active(false);
|
||||
stats.set_conntrack_pressure_active(false);
|
||||
info!("Conntrack pressure mode deactivated");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
state.low_streak = 0;
|
||||
}
|
||||
|
||||
async fn reconcile_rules(cfg: &ProxyConfig, backend: Option<NetfilterBackend>, stats: &Stats) {
|
||||
if !cfg.server.conntrack_control.inline_conntrack_control {
|
||||
clear_notrack_rules_all_backends().await;
|
||||
stats.set_conntrack_rule_apply_ok(true);
|
||||
return;
|
||||
}
|
||||
|
||||
if !has_cap_net_admin() {
|
||||
stats.set_conntrack_rule_apply_ok(false);
|
||||
return;
|
||||
}
|
||||
|
||||
let Some(backend) = backend else {
|
||||
stats.set_conntrack_rule_apply_ok(false);
|
||||
return;
|
||||
};
|
||||
|
||||
let apply_result = match backend {
|
||||
NetfilterBackend::Nftables => apply_nft_rules(cfg).await,
|
||||
NetfilterBackend::Iptables => apply_iptables_rules(cfg).await,
|
||||
};
|
||||
|
||||
if let Err(error) = apply_result {
|
||||
warn!(error = %error, "Failed to reconcile conntrack/notrack rules");
|
||||
stats.set_conntrack_rule_apply_ok(false);
|
||||
} else {
|
||||
stats.set_conntrack_rule_apply_ok(true);
|
||||
}
|
||||
}
|
||||
|
||||
fn pick_backend(configured: ConntrackBackend) -> Option<NetfilterBackend> {
|
||||
match configured {
|
||||
ConntrackBackend::Auto => {
|
||||
if command_exists("nft") {
|
||||
Some(NetfilterBackend::Nftables)
|
||||
} else if command_exists("iptables") {
|
||||
Some(NetfilterBackend::Iptables)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
ConntrackBackend::Nftables => command_exists("nft").then_some(NetfilterBackend::Nftables),
|
||||
ConntrackBackend::Iptables => {
|
||||
command_exists("iptables").then_some(NetfilterBackend::Iptables)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn command_exists(binary: &str) -> bool {
|
||||
let Some(path_var) = std::env::var_os("PATH") else {
|
||||
return false;
|
||||
};
|
||||
std::env::split_paths(&path_var).any(|dir| {
|
||||
let candidate: PathBuf = dir.join(binary);
|
||||
candidate.exists() && candidate.is_file()
|
||||
})
|
||||
}
|
||||
|
||||
fn notrack_targets(cfg: &ProxyConfig) -> (Vec<Option<IpAddr>>, Vec<Option<IpAddr>>) {
|
||||
let mode = cfg.server.conntrack_control.mode;
|
||||
let mut v4_targets: BTreeSet<Option<IpAddr>> = BTreeSet::new();
|
||||
let mut v6_targets: BTreeSet<Option<IpAddr>> = BTreeSet::new();
|
||||
|
||||
match mode {
|
||||
ConntrackMode::Tracked => {}
|
||||
ConntrackMode::Notrack => {
|
||||
if cfg.server.listeners.is_empty() {
|
||||
if let Some(ipv4) = cfg
|
||||
.server
|
||||
.listen_addr_ipv4
|
||||
.as_ref()
|
||||
.and_then(|s| s.parse::<IpAddr>().ok())
|
||||
{
|
||||
if ipv4.is_unspecified() {
|
||||
v4_targets.insert(None);
|
||||
} else {
|
||||
v4_targets.insert(Some(ipv4));
|
||||
}
|
||||
}
|
||||
if let Some(ipv6) = cfg
|
||||
.server
|
||||
.listen_addr_ipv6
|
||||
.as_ref()
|
||||
.and_then(|s| s.parse::<IpAddr>().ok())
|
||||
{
|
||||
if ipv6.is_unspecified() {
|
||||
v6_targets.insert(None);
|
||||
} else {
|
||||
v6_targets.insert(Some(ipv6));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for listener in &cfg.server.listeners {
|
||||
if listener.ip.is_ipv4() {
|
||||
if listener.ip.is_unspecified() {
|
||||
v4_targets.insert(None);
|
||||
} else {
|
||||
v4_targets.insert(Some(listener.ip));
|
||||
}
|
||||
} else if listener.ip.is_unspecified() {
|
||||
v6_targets.insert(None);
|
||||
} else {
|
||||
v6_targets.insert(Some(listener.ip));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ConntrackMode::Hybrid => {
|
||||
for ip in &cfg.server.conntrack_control.hybrid_listener_ips {
|
||||
if ip.is_ipv4() {
|
||||
v4_targets.insert(Some(*ip));
|
||||
} else {
|
||||
v6_targets.insert(Some(*ip));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(
|
||||
v4_targets.into_iter().collect(),
|
||||
v6_targets.into_iter().collect(),
|
||||
)
|
||||
}
|
||||
|
||||
async fn apply_nft_rules(cfg: &ProxyConfig) -> Result<(), String> {
|
||||
let _ = run_command(
|
||||
"nft",
|
||||
&["delete", "table", "inet", "telemt_conntrack"],
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
if matches!(cfg.server.conntrack_control.mode, ConntrackMode::Tracked) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let (v4_targets, v6_targets) = notrack_targets(cfg);
|
||||
let mut rules = Vec::new();
|
||||
for ip in v4_targets {
|
||||
let rule = if let Some(ip) = ip {
|
||||
format!("tcp dport {} ip daddr {} notrack", cfg.server.port, ip)
|
||||
} else {
|
||||
format!("tcp dport {} notrack", cfg.server.port)
|
||||
};
|
||||
rules.push(rule);
|
||||
}
|
||||
for ip in v6_targets {
|
||||
let rule = if let Some(ip) = ip {
|
||||
format!("tcp dport {} ip6 daddr {} notrack", cfg.server.port, ip)
|
||||
} else {
|
||||
format!("tcp dport {} notrack", cfg.server.port)
|
||||
};
|
||||
rules.push(rule);
|
||||
}
|
||||
|
||||
let rule_blob = if rules.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" {}\n", rules.join("\n "))
|
||||
};
|
||||
let script = format!(
|
||||
"table inet telemt_conntrack {{\n chain preraw {{\n type filter hook prerouting priority raw; policy accept;\n{rule_blob} }}\n}}\n"
|
||||
);
|
||||
run_command("nft", &["-f", "-"], Some(script)).await
|
||||
}
|
||||
|
||||
async fn apply_iptables_rules(cfg: &ProxyConfig) -> Result<(), String> {
|
||||
apply_iptables_rules_for_binary("iptables", cfg, true).await?;
|
||||
apply_iptables_rules_for_binary("ip6tables", cfg, false).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn apply_iptables_rules_for_binary(
|
||||
binary: &str,
|
||||
cfg: &ProxyConfig,
|
||||
ipv4: bool,
|
||||
) -> Result<(), String> {
|
||||
if !command_exists(binary) {
|
||||
return Ok(());
|
||||
}
|
||||
let chain = "TELEMT_NOTRACK";
|
||||
let _ = run_command(
|
||||
binary,
|
||||
&["-t", "raw", "-D", "PREROUTING", "-j", chain],
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
let _ = run_command(binary, &["-t", "raw", "-F", chain], None).await;
|
||||
let _ = run_command(binary, &["-t", "raw", "-X", chain], None).await;
|
||||
|
||||
if matches!(cfg.server.conntrack_control.mode, ConntrackMode::Tracked) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
run_command(binary, &["-t", "raw", "-N", chain], None).await?;
|
||||
run_command(binary, &["-t", "raw", "-F", chain], None).await?;
|
||||
if run_command(
|
||||
binary,
|
||||
&["-t", "raw", "-C", "PREROUTING", "-j", chain],
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
run_command(
|
||||
binary,
|
||||
&["-t", "raw", "-I", "PREROUTING", "1", "-j", chain],
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let (v4_targets, v6_targets) = notrack_targets(cfg);
|
||||
let selected = if ipv4 { v4_targets } else { v6_targets };
|
||||
for ip in selected {
|
||||
let mut args = vec![
|
||||
"-t".to_string(),
|
||||
"raw".to_string(),
|
||||
"-A".to_string(),
|
||||
chain.to_string(),
|
||||
"-p".to_string(),
|
||||
"tcp".to_string(),
|
||||
"--dport".to_string(),
|
||||
cfg.server.port.to_string(),
|
||||
];
|
||||
if let Some(ip) = ip {
|
||||
args.push("-d".to_string());
|
||||
args.push(ip.to_string());
|
||||
}
|
||||
args.push("-j".to_string());
|
||||
args.push("CT".to_string());
|
||||
args.push("--notrack".to_string());
|
||||
let arg_refs: Vec<&str> = args.iter().map(String::as_str).collect();
|
||||
run_command(binary, &arg_refs, None).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn clear_notrack_rules_all_backends() {
|
||||
let _ = run_command(
|
||||
"nft",
|
||||
&["delete", "table", "inet", "telemt_conntrack"],
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
let _ = run_command(
|
||||
"iptables",
|
||||
&["-t", "raw", "-D", "PREROUTING", "-j", "TELEMT_NOTRACK"],
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
let _ = run_command("iptables", &["-t", "raw", "-F", "TELEMT_NOTRACK"], None).await;
|
||||
let _ = run_command("iptables", &["-t", "raw", "-X", "TELEMT_NOTRACK"], None).await;
|
||||
let _ = run_command(
|
||||
"ip6tables",
|
||||
&["-t", "raw", "-D", "PREROUTING", "-j", "TELEMT_NOTRACK"],
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
let _ = run_command("ip6tables", &["-t", "raw", "-F", "TELEMT_NOTRACK"], None).await;
|
||||
let _ = run_command("ip6tables", &["-t", "raw", "-X", "TELEMT_NOTRACK"], None).await;
|
||||
}
|
||||
|
||||
enum DeleteOutcome {
|
||||
Deleted,
|
||||
NotFound,
|
||||
Error,
|
||||
}
|
||||
|
||||
async fn delete_conntrack_entry(event: ConntrackCloseEvent) -> DeleteOutcome {
|
||||
if !command_exists("conntrack") {
|
||||
return DeleteOutcome::Error;
|
||||
}
|
||||
let args = vec![
|
||||
"-D".to_string(),
|
||||
"-p".to_string(),
|
||||
"tcp".to_string(),
|
||||
"-s".to_string(),
|
||||
event.src.ip().to_string(),
|
||||
"--sport".to_string(),
|
||||
event.src.port().to_string(),
|
||||
"-d".to_string(),
|
||||
event.dst.ip().to_string(),
|
||||
"--dport".to_string(),
|
||||
event.dst.port().to_string(),
|
||||
];
|
||||
let arg_refs: Vec<&str> = args.iter().map(String::as_str).collect();
|
||||
match run_command("conntrack", &arg_refs, None).await {
|
||||
Ok(()) => DeleteOutcome::Deleted,
|
||||
Err(error) => {
|
||||
if error.contains("0 flow entries have been deleted") {
|
||||
DeleteOutcome::NotFound
|
||||
} else {
|
||||
debug!(error = %error, "conntrack delete failed");
|
||||
DeleteOutcome::Error
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_command(binary: &str, args: &[&str], stdin: Option<String>) -> Result<(), String> {
|
||||
if !command_exists(binary) {
|
||||
return Err(format!("{binary} is not available"));
|
||||
}
|
||||
let mut command = Command::new(binary);
|
||||
command.args(args);
|
||||
if stdin.is_some() {
|
||||
command.stdin(std::process::Stdio::piped());
|
||||
}
|
||||
command.stdout(std::process::Stdio::null());
|
||||
command.stderr(std::process::Stdio::piped());
|
||||
let mut child = command
|
||||
.spawn()
|
||||
.map_err(|e| format!("spawn {binary} failed: {e}"))?;
|
||||
if let Some(blob) = stdin
|
||||
&& let Some(mut writer) = child.stdin.take()
|
||||
{
|
||||
writer
|
||||
.write_all(blob.as_bytes())
|
||||
.await
|
||||
.map_err(|e| format!("stdin write {binary} failed: {e}"))?;
|
||||
}
|
||||
let output = child
|
||||
.wait_with_output()
|
||||
.await
|
||||
.map_err(|e| format!("wait {binary} failed: {e}"))?;
|
||||
if output.status.success() {
|
||||
return Ok(());
|
||||
}
|
||||
let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string();
|
||||
Err(if stderr.is_empty() {
|
||||
format!("{binary} exited with status {}", output.status)
|
||||
} else {
|
||||
stderr
|
||||
})
|
||||
}
|
||||
|
||||
fn fd_usage_pct() -> Option<u8> {
|
||||
let soft_limit = nofile_soft_limit()?;
|
||||
if soft_limit == 0 {
|
||||
return None;
|
||||
}
|
||||
let fd_count = std::fs::read_dir("/proc/self/fd").ok()?.count() as u64;
|
||||
Some(((fd_count.saturating_mul(100)) / soft_limit).min(100) as u8)
|
||||
}
|
||||
|
||||
fn nofile_soft_limit() -> Option<u64> {
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
let mut lim = libc::rlimit {
|
||||
rlim_cur: 0,
|
||||
rlim_max: 0,
|
||||
};
|
||||
let rc = unsafe { libc::getrlimit(libc::RLIMIT_NOFILE, &mut lim) };
|
||||
if rc != 0 {
|
||||
return None;
|
||||
}
|
||||
return Some(lim.rlim_cur);
|
||||
}
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
{
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn has_cap_net_admin() -> bool {
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
let Ok(status) = std::fs::read_to_string("/proc/self/status") else {
|
||||
return false;
|
||||
};
|
||||
for line in status.lines() {
|
||||
if let Some(raw) = line.strip_prefix("CapEff:") {
|
||||
let caps = raw.trim();
|
||||
if let Ok(bits) = u64::from_str_radix(caps, 16) {
|
||||
const CAP_NET_ADMIN_BIT: u64 = 12;
|
||||
return (bits & (1u64 << CAP_NET_ADMIN_BIT)) != 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
{
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::ProxyConfig;
|
||||
|
||||
#[test]
|
||||
fn pressure_activates_on_accept_timeout_spike() {
|
||||
let stats = Stats::new();
|
||||
let shared = ProxySharedState::new();
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.server.conntrack_control.inline_conntrack_control = true;
|
||||
let mut state = PressureState::new(&stats);
|
||||
let sample = PressureSample {
|
||||
conn_pct: Some(10),
|
||||
fd_pct: Some(10),
|
||||
accept_timeout_delta: 1,
|
||||
me_queue_pressure_delta: 0,
|
||||
};
|
||||
|
||||
update_pressure_state(&stats, shared.as_ref(), &cfg, &sample, &mut state);
|
||||
|
||||
assert!(state.active);
|
||||
assert!(shared.conntrack_pressure_active());
|
||||
assert!(stats.get_conntrack_pressure_active());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pressure_releases_after_hysteresis_window() {
|
||||
let stats = Stats::new();
|
||||
let shared = ProxySharedState::new();
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.server.conntrack_control.inline_conntrack_control = true;
|
||||
let mut state = PressureState::new(&stats);
|
||||
|
||||
let high_sample = PressureSample {
|
||||
conn_pct: Some(95),
|
||||
fd_pct: Some(95),
|
||||
accept_timeout_delta: 0,
|
||||
me_queue_pressure_delta: 0,
|
||||
};
|
||||
update_pressure_state(&stats, shared.as_ref(), &cfg, &high_sample, &mut state);
|
||||
assert!(state.active);
|
||||
|
||||
let low_sample = PressureSample {
|
||||
conn_pct: Some(10),
|
||||
fd_pct: Some(10),
|
||||
accept_timeout_delta: 0,
|
||||
me_queue_pressure_delta: 0,
|
||||
};
|
||||
update_pressure_state(&stats, shared.as_ref(), &cfg, &low_sample, &mut state);
|
||||
assert!(state.active);
|
||||
update_pressure_state(&stats, shared.as_ref(), &cfg, &low_sample, &mut state);
|
||||
assert!(state.active);
|
||||
update_pressure_state(&stats, shared.as_ref(), &cfg, &low_sample, &mut state);
|
||||
|
||||
assert!(!state.active);
|
||||
assert!(!shared.conntrack_pressure_active());
|
||||
assert!(!stats.get_conntrack_pressure_active());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pressure_does_not_activate_when_disabled() {
|
||||
let stats = Stats::new();
|
||||
let shared = ProxySharedState::new();
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.server.conntrack_control.inline_conntrack_control = false;
|
||||
let mut state = PressureState::new(&stats);
|
||||
let sample = PressureSample {
|
||||
conn_pct: Some(100),
|
||||
fd_pct: Some(100),
|
||||
accept_timeout_delta: 10,
|
||||
me_queue_pressure_delta: 10,
|
||||
};
|
||||
|
||||
update_pressure_state(&stats, shared.as_ref(), &cfg, &sample, &mut state);
|
||||
|
||||
assert!(!state.active);
|
||||
assert!(!shared.conntrack_pressure_active());
|
||||
assert!(!stats.get_conntrack_pressure_active());
|
||||
}
|
||||
}
|
||||
+45
-9
@@ -339,31 +339,35 @@ fn is_process_running(pid: i32) -> bool {
|
||||
|
||||
/// Drops privileges to the specified user and group.
|
||||
///
|
||||
/// This should be called after binding privileged ports but before
|
||||
/// entering the main event loop.
|
||||
pub fn drop_privileges(user: Option<&str>, group: Option<&str>) -> Result<(), DaemonError> {
|
||||
// Look up group first (need to do this while still root)
|
||||
/// This should be called after binding privileged ports but before entering
|
||||
/// the main event loop.
|
||||
pub fn drop_privileges(
|
||||
user: Option<&str>,
|
||||
group: Option<&str>,
|
||||
pid_file: Option<&PidFile>,
|
||||
) -> Result<(), DaemonError> {
|
||||
let target_gid = if let Some(group_name) = group {
|
||||
Some(lookup_group(group_name)?)
|
||||
} else if let Some(user_name) = user {
|
||||
// If no group specified but user is, use user's primary group
|
||||
Some(lookup_user_primary_gid(user_name)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Look up user
|
||||
let target_uid = if let Some(user_name) = user {
|
||||
Some(lookup_user(user_name)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Drop privileges: set GID first, then UID
|
||||
// (Setting UID first would prevent us from setting GID)
|
||||
if (target_uid.is_some() || target_gid.is_some())
|
||||
&& let Some(file) = pid_file.and_then(|pid| pid.file.as_ref())
|
||||
{
|
||||
unistd::fchown(file, target_uid, target_gid).map_err(DaemonError::PrivilegeDrop)?;
|
||||
}
|
||||
|
||||
if let Some(gid) = target_gid {
|
||||
unistd::setgid(gid).map_err(DaemonError::PrivilegeDrop)?;
|
||||
// Also set supplementary groups to just this one
|
||||
unistd::setgroups(&[gid]).map_err(DaemonError::PrivilegeDrop)?;
|
||||
info!(gid = gid.as_raw(), "Dropped group privileges");
|
||||
}
|
||||
@@ -371,6 +375,38 @@ pub fn drop_privileges(user: Option<&str>, group: Option<&str>) -> Result<(), Da
|
||||
if let Some(uid) = target_uid {
|
||||
unistd::setuid(uid).map_err(DaemonError::PrivilegeDrop)?;
|
||||
info!(uid = uid.as_raw(), "Dropped user privileges");
|
||||
|
||||
if uid.as_raw() != 0
|
||||
&& let Some(pid) = pid_file
|
||||
{
|
||||
let parent = pid.path.parent().unwrap_or(Path::new("."));
|
||||
let probe_path = parent.join(format!(
|
||||
".telemt_pid_probe_{}_{}",
|
||||
std::process::id(),
|
||||
getpid().as_raw()
|
||||
));
|
||||
OpenOptions::new()
|
||||
.write(true)
|
||||
.create_new(true)
|
||||
.mode(0o600)
|
||||
.open(&probe_path)
|
||||
.map_err(|e| {
|
||||
DaemonError::PidFile(format!(
|
||||
"cannot create probe in PID directory {} as uid {} (pid cleanup will fail): {}",
|
||||
parent.display(),
|
||||
uid.as_raw(),
|
||||
e
|
||||
))
|
||||
})?;
|
||||
fs::remove_file(&probe_path).map_err(|e| {
|
||||
DaemonError::PidFile(format!(
|
||||
"cannot remove probe in PID directory {} as uid {} (pid cleanup will fail): {}",
|
||||
parent.display(),
|
||||
uid.as_raw(),
|
||||
e
|
||||
))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -26,6 +26,15 @@ pub struct UserIpTracker {
|
||||
cleanup_drain_lock: Arc<AsyncMutex<()>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct UserIpTrackerMemoryStats {
|
||||
pub active_users: usize,
|
||||
pub recent_users: usize,
|
||||
pub active_entries: usize,
|
||||
pub recent_entries: usize,
|
||||
pub cleanup_queue_len: usize,
|
||||
}
|
||||
|
||||
impl UserIpTracker {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
@@ -141,6 +150,13 @@ impl UserIpTracker {
|
||||
|
||||
let mut active_ips = self.active_ips.write().await;
|
||||
let mut recent_ips = self.recent_ips.write().await;
|
||||
let window = *self.limit_window.read().await;
|
||||
let now = Instant::now();
|
||||
|
||||
for user_recent in recent_ips.values_mut() {
|
||||
Self::prune_recent(user_recent, now, window);
|
||||
}
|
||||
|
||||
let mut users =
|
||||
Vec::<String>::with_capacity(active_ips.len().saturating_add(recent_ips.len()));
|
||||
users.extend(active_ips.keys().cloned());
|
||||
@@ -166,6 +182,26 @@ impl UserIpTracker {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn memory_stats(&self) -> UserIpTrackerMemoryStats {
|
||||
let cleanup_queue_len = self
|
||||
.cleanup_queue
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
.len();
|
||||
let active_ips = self.active_ips.read().await;
|
||||
let recent_ips = self.recent_ips.read().await;
|
||||
let active_entries = active_ips.values().map(HashMap::len).sum();
|
||||
let recent_entries = recent_ips.values().map(HashMap::len).sum();
|
||||
|
||||
UserIpTrackerMemoryStats {
|
||||
active_users: active_ips.len(),
|
||||
recent_users: recent_ips.len(),
|
||||
active_entries,
|
||||
recent_entries,
|
||||
cleanup_queue_len,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn set_limit_policy(&self, mode: UserMaxUniqueIpsMode, window_secs: u64) {
|
||||
{
|
||||
let mut current_mode = self.limit_mode.write().await;
|
||||
@@ -451,6 +487,7 @@ impl Default for UserIpTracker {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
fn test_ipv4(oct1: u8, oct2: u8, oct3: u8, oct4: u8) -> IpAddr {
|
||||
IpAddr::V4(Ipv4Addr::new(oct1, oct2, oct3, oct4))
|
||||
@@ -764,4 +801,54 @@ mod tests {
|
||||
tokio::time::sleep(Duration::from_millis(1100)).await;
|
||||
assert!(tracker.check_and_add("test_user", ip2).await.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_memory_stats_reports_queue_and_entry_counts() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("test_user", 4).await;
|
||||
let ip1 = test_ipv4(10, 2, 0, 1);
|
||||
let ip2 = test_ipv4(10, 2, 0, 2);
|
||||
|
||||
tracker.check_and_add("test_user", ip1).await.unwrap();
|
||||
tracker.check_and_add("test_user", ip2).await.unwrap();
|
||||
tracker.enqueue_cleanup("test_user".to_string(), ip1);
|
||||
|
||||
let snapshot = tracker.memory_stats().await;
|
||||
assert_eq!(snapshot.active_users, 1);
|
||||
assert_eq!(snapshot.recent_users, 1);
|
||||
assert_eq!(snapshot.active_entries, 2);
|
||||
assert_eq!(snapshot.recent_entries, 2);
|
||||
assert_eq!(snapshot.cleanup_queue_len, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_compact_prunes_stale_recent_entries() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker
|
||||
.set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1)
|
||||
.await;
|
||||
|
||||
let stale_user = "stale-user".to_string();
|
||||
let stale_ip = test_ipv4(10, 3, 0, 1);
|
||||
{
|
||||
let mut recent_ips = tracker.recent_ips.write().await;
|
||||
recent_ips
|
||||
.entry(stale_user.clone())
|
||||
.or_insert_with(HashMap::new)
|
||||
.insert(stale_ip, Instant::now() - Duration::from_secs(5));
|
||||
}
|
||||
|
||||
tracker.last_compact_epoch_secs.store(0, Ordering::Relaxed);
|
||||
tracker
|
||||
.check_and_add("trigger-user", test_ipv4(10, 3, 0, 2))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let recent_ips = tracker.recent_ips.read().await;
|
||||
let stale_exists = recent_ips
|
||||
.get(&stale_user)
|
||||
.map(|ips| ips.contains_key(&stale_ip))
|
||||
.unwrap_or(false);
|
||||
assert!(!stale_exists);
|
||||
}
|
||||
}
|
||||
|
||||
+59
-21
@@ -88,8 +88,10 @@ pub fn init_logging(
|
||||
// Use a custom fmt layer that writes to syslog
|
||||
let fmt_layer = fmt::Layer::default()
|
||||
.with_ansi(false)
|
||||
.with_target(true)
|
||||
.with_writer(SyslogWriter::new);
|
||||
.with_target(false)
|
||||
.with_level(false)
|
||||
.without_time()
|
||||
.with_writer(SyslogMakeWriter::new());
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(filter_layer)
|
||||
@@ -137,12 +139,17 @@ pub fn init_logging(
|
||||
|
||||
/// Syslog writer for tracing.
|
||||
#[cfg(unix)]
|
||||
#[derive(Clone, Copy)]
|
||||
struct SyslogMakeWriter;
|
||||
|
||||
#[cfg(unix)]
|
||||
#[derive(Clone, Copy)]
|
||||
struct SyslogWriter {
|
||||
_private: (),
|
||||
priority: libc::c_int,
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl SyslogWriter {
|
||||
impl SyslogMakeWriter {
|
||||
fn new() -> Self {
|
||||
// Open syslog connection on first use
|
||||
static INIT: std::sync::Once = std::sync::Once::new();
|
||||
@@ -153,7 +160,18 @@ impl SyslogWriter {
|
||||
libc::openlog(ident, libc::LOG_PID | libc::LOG_NDELAY, libc::LOG_DAEMON);
|
||||
}
|
||||
});
|
||||
Self { _private: () }
|
||||
Self
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn syslog_priority_for_level(level: &tracing::Level) -> libc::c_int {
|
||||
match *level {
|
||||
tracing::Level::ERROR => libc::LOG_ERR,
|
||||
tracing::Level::WARN => libc::LOG_WARNING,
|
||||
tracing::Level::INFO => libc::LOG_INFO,
|
||||
tracing::Level::DEBUG => libc::LOG_DEBUG,
|
||||
tracing::Level::TRACE => libc::LOG_DEBUG,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -168,26 +186,13 @@ impl std::io::Write for SyslogWriter {
|
||||
return Ok(buf.len());
|
||||
}
|
||||
|
||||
// Determine priority based on log level in the message
|
||||
let priority = if msg.contains(" ERROR ") || msg.contains(" error ") {
|
||||
libc::LOG_ERR
|
||||
} else if msg.contains(" WARN ") || msg.contains(" warn ") {
|
||||
libc::LOG_WARNING
|
||||
} else if msg.contains(" INFO ") || msg.contains(" info ") {
|
||||
libc::LOG_INFO
|
||||
} else if msg.contains(" DEBUG ") || msg.contains(" debug ") {
|
||||
libc::LOG_DEBUG
|
||||
} else {
|
||||
libc::LOG_INFO
|
||||
};
|
||||
|
||||
// Write to syslog
|
||||
let c_msg = std::ffi::CString::new(msg.as_bytes())
|
||||
.unwrap_or_else(|_| std::ffi::CString::new("(invalid utf8)").unwrap());
|
||||
|
||||
unsafe {
|
||||
libc::syslog(
|
||||
priority,
|
||||
self.priority,
|
||||
b"%s\0".as_ptr() as *const libc::c_char,
|
||||
c_msg.as_ptr(),
|
||||
);
|
||||
@@ -202,11 +207,19 @@ impl std::io::Write for SyslogWriter {
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for SyslogWriter {
|
||||
impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for SyslogMakeWriter {
|
||||
type Writer = SyslogWriter;
|
||||
|
||||
fn make_writer(&'a self) -> Self::Writer {
|
||||
SyslogWriter::new()
|
||||
SyslogWriter {
|
||||
priority: libc::LOG_INFO,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_writer_for(&'a self, meta: &tracing::Metadata<'_>) -> Self::Writer {
|
||||
SyslogWriter {
|
||||
priority: syslog_priority_for_level(meta.level()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -302,4 +315,29 @@ mod tests {
|
||||
LogDestination::Syslog
|
||||
));
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[test]
|
||||
fn test_syslog_priority_for_level_mapping() {
|
||||
assert_eq!(
|
||||
syslog_priority_for_level(&tracing::Level::ERROR),
|
||||
libc::LOG_ERR
|
||||
);
|
||||
assert_eq!(
|
||||
syslog_priority_for_level(&tracing::Level::WARN),
|
||||
libc::LOG_WARNING
|
||||
);
|
||||
assert_eq!(
|
||||
syslog_priority_for_level(&tracing::Level::INFO),
|
||||
libc::LOG_INFO
|
||||
);
|
||||
assert_eq!(
|
||||
syslog_priority_for_level(&tracing::Level::DEBUG),
|
||||
libc::LOG_DEBUG
|
||||
);
|
||||
assert_eq!(
|
||||
syslog_priority_for_level(&tracing::Level::TRACE),
|
||||
libc::LOG_DEBUG
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
+83
-15
@@ -18,19 +18,38 @@ use crate::transport::middle_proxy::{
|
||||
pub(crate) fn resolve_runtime_config_path(
|
||||
config_path_cli: &str,
|
||||
startup_cwd: &std::path::Path,
|
||||
config_path_explicit: bool,
|
||||
) -> PathBuf {
|
||||
let raw = PathBuf::from(config_path_cli);
|
||||
let absolute = if raw.is_absolute() {
|
||||
raw
|
||||
} else {
|
||||
startup_cwd.join(raw)
|
||||
};
|
||||
absolute.canonicalize().unwrap_or(absolute)
|
||||
if config_path_explicit {
|
||||
let raw = PathBuf::from(config_path_cli);
|
||||
let absolute = if raw.is_absolute() {
|
||||
raw
|
||||
} else {
|
||||
startup_cwd.join(raw)
|
||||
};
|
||||
return absolute.canonicalize().unwrap_or(absolute);
|
||||
}
|
||||
|
||||
let etc_telemt = std::path::Path::new("/etc/telemt");
|
||||
let candidates = [
|
||||
startup_cwd.join("config.toml"),
|
||||
startup_cwd.join("telemt.toml"),
|
||||
etc_telemt.join("telemt.toml"),
|
||||
etc_telemt.join("config.toml"),
|
||||
];
|
||||
for candidate in candidates {
|
||||
if candidate.is_file() {
|
||||
return candidate.canonicalize().unwrap_or(candidate);
|
||||
}
|
||||
}
|
||||
|
||||
startup_cwd.join("config.toml")
|
||||
}
|
||||
|
||||
/// Parsed CLI arguments.
|
||||
pub(crate) struct CliArgs {
|
||||
pub config_path: String,
|
||||
pub config_path_explicit: bool,
|
||||
pub data_path: Option<PathBuf>,
|
||||
pub silent: bool,
|
||||
pub log_level: Option<String>,
|
||||
@@ -39,6 +58,7 @@ pub(crate) struct CliArgs {
|
||||
|
||||
pub(crate) fn parse_cli() -> CliArgs {
|
||||
let mut config_path = "config.toml".to_string();
|
||||
let mut config_path_explicit = false;
|
||||
let mut data_path: Option<PathBuf> = None;
|
||||
let mut silent = false;
|
||||
let mut log_level: Option<String> = None;
|
||||
@@ -74,6 +94,20 @@ pub(crate) fn parse_cli() -> CliArgs {
|
||||
s.trim_start_matches("--data-path=").to_string(),
|
||||
));
|
||||
}
|
||||
"--working-dir" => {
|
||||
i += 1;
|
||||
if i < args.len() {
|
||||
data_path = Some(PathBuf::from(args[i].clone()));
|
||||
} else {
|
||||
eprintln!("Missing value for --working-dir");
|
||||
std::process::exit(0);
|
||||
}
|
||||
}
|
||||
s if s.starts_with("--working-dir=") => {
|
||||
data_path = Some(PathBuf::from(
|
||||
s.trim_start_matches("--working-dir=").to_string(),
|
||||
));
|
||||
}
|
||||
"--silent" | "-s" => {
|
||||
silent = true;
|
||||
}
|
||||
@@ -111,13 +145,11 @@ pub(crate) fn parse_cli() -> CliArgs {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
s if s.starts_with("--working-dir") => {
|
||||
if !s.contains('=') {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
s if !s.starts_with('-') => {
|
||||
config_path = s.to_string();
|
||||
if !matches!(s, "run" | "start" | "stop" | "reload" | "status") {
|
||||
config_path = s.to_string();
|
||||
config_path_explicit = true;
|
||||
}
|
||||
}
|
||||
other => {
|
||||
eprintln!("Unknown option: {}", other);
|
||||
@@ -128,6 +160,7 @@ pub(crate) fn parse_cli() -> CliArgs {
|
||||
|
||||
CliArgs {
|
||||
config_path,
|
||||
config_path_explicit,
|
||||
data_path,
|
||||
silent,
|
||||
log_level,
|
||||
@@ -152,6 +185,7 @@ fn print_help() {
|
||||
eprintln!(
|
||||
" --data-path <DIR> Set data directory (absolute path; overrides config value)"
|
||||
);
|
||||
eprintln!(" --working-dir <DIR> Alias for --data-path");
|
||||
eprintln!(" --silent, -s Suppress info logs");
|
||||
eprintln!(" --log-level <LEVEL> debug|verbose|normal|silent");
|
||||
eprintln!(" --help, -h Show this help");
|
||||
@@ -210,7 +244,7 @@ mod tests {
|
||||
let target = startup_cwd.join("config.toml");
|
||||
std::fs::write(&target, " ").unwrap();
|
||||
|
||||
let resolved = resolve_runtime_config_path("config.toml", &startup_cwd);
|
||||
let resolved = resolve_runtime_config_path("config.toml", &startup_cwd, true);
|
||||
assert_eq!(resolved, target.canonicalize().unwrap());
|
||||
|
||||
let _ = std::fs::remove_file(&target);
|
||||
@@ -226,11 +260,45 @@ mod tests {
|
||||
let startup_cwd = std::env::temp_dir().join(format!("telemt_cfg_path_missing_{nonce}"));
|
||||
std::fs::create_dir_all(&startup_cwd).unwrap();
|
||||
|
||||
let resolved = resolve_runtime_config_path("missing.toml", &startup_cwd);
|
||||
let resolved = resolve_runtime_config_path("missing.toml", &startup_cwd, true);
|
||||
assert_eq!(resolved, startup_cwd.join("missing.toml"));
|
||||
|
||||
let _ = std::fs::remove_dir(&startup_cwd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_runtime_config_path_uses_startup_candidates_when_not_explicit() {
|
||||
let nonce = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
let startup_cwd =
|
||||
std::env::temp_dir().join(format!("telemt_cfg_startup_candidates_{nonce}"));
|
||||
std::fs::create_dir_all(&startup_cwd).unwrap();
|
||||
let telemt = startup_cwd.join("telemt.toml");
|
||||
std::fs::write(&telemt, " ").unwrap();
|
||||
|
||||
let resolved = resolve_runtime_config_path("config.toml", &startup_cwd, false);
|
||||
assert_eq!(resolved, telemt.canonicalize().unwrap());
|
||||
|
||||
let _ = std::fs::remove_file(&telemt);
|
||||
let _ = std::fs::remove_dir(&startup_cwd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_runtime_config_path_defaults_to_startup_config_when_none_found() {
|
||||
let nonce = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
let startup_cwd = std::env::temp_dir().join(format!("telemt_cfg_startup_default_{nonce}"));
|
||||
std::fs::create_dir_all(&startup_cwd).unwrap();
|
||||
|
||||
let resolved = resolve_runtime_config_path("config.toml", &startup_cwd, false);
|
||||
assert_eq!(resolved, startup_cwd.join("config.toml"));
|
||||
|
||||
let _ = std::fs::remove_dir(&startup_cwd);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) {
|
||||
|
||||
@@ -14,6 +14,7 @@ use crate::crypto::SecureRandom;
|
||||
use crate::ip_tracker::UserIpTracker;
|
||||
use crate::proxy::ClientHandler;
|
||||
use crate::proxy::route_mode::{ROUTE_SWITCH_ERROR_MSG, RouteRuntimeController};
|
||||
use crate::proxy::shared_state::ProxySharedState;
|
||||
use crate::startup::{COMPONENT_LISTENERS_BIND, StartupTracker};
|
||||
use crate::stats::beobachten::BeobachtenStore;
|
||||
use crate::stats::{ReplayChecker, Stats};
|
||||
@@ -49,6 +50,7 @@ pub(crate) async fn bind_listeners(
|
||||
tls_cache: Option<Arc<TlsFrontCache>>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
beobachten: Arc<BeobachtenStore>,
|
||||
shared: Arc<ProxySharedState>,
|
||||
max_connections: Arc<Semaphore>,
|
||||
) -> Result<BoundListeners, Box<dyn Error>> {
|
||||
startup_tracker
|
||||
@@ -224,6 +226,7 @@ pub(crate) async fn bind_listeners(
|
||||
let tls_cache = tls_cache.clone();
|
||||
let ip_tracker = ip_tracker.clone();
|
||||
let beobachten = beobachten.clone();
|
||||
let shared = shared.clone();
|
||||
let max_connections_unix = max_connections.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
@@ -259,6 +262,7 @@ pub(crate) async fn bind_listeners(
|
||||
break;
|
||||
}
|
||||
Err(_) => {
|
||||
stats.increment_accept_permit_timeout_total();
|
||||
debug!(
|
||||
timeout_ms = accept_permit_timeout_ms,
|
||||
"Dropping accepted unix connection: permit wait timeout"
|
||||
@@ -284,11 +288,12 @@ pub(crate) async fn bind_listeners(
|
||||
let tls_cache = tls_cache.clone();
|
||||
let ip_tracker = ip_tracker.clone();
|
||||
let beobachten = beobachten.clone();
|
||||
let shared = shared.clone();
|
||||
let proxy_protocol_enabled = config.server.proxy_protocol;
|
||||
|
||||
tokio::spawn(async move {
|
||||
let _permit = permit;
|
||||
if let Err(e) = crate::proxy::client::handle_client_stream(
|
||||
if let Err(e) = crate::proxy::client::handle_client_stream_with_shared(
|
||||
stream,
|
||||
fake_peer,
|
||||
config,
|
||||
@@ -302,6 +307,7 @@ pub(crate) async fn bind_listeners(
|
||||
tls_cache,
|
||||
ip_tracker,
|
||||
beobachten,
|
||||
shared,
|
||||
proxy_protocol_enabled,
|
||||
)
|
||||
.await
|
||||
@@ -351,6 +357,7 @@ pub(crate) fn spawn_tcp_accept_loops(
|
||||
tls_cache: Option<Arc<TlsFrontCache>>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
beobachten: Arc<BeobachtenStore>,
|
||||
shared: Arc<ProxySharedState>,
|
||||
max_connections: Arc<Semaphore>,
|
||||
) {
|
||||
for (listener, listener_proxy_protocol) in listeners {
|
||||
@@ -366,6 +373,7 @@ pub(crate) fn spawn_tcp_accept_loops(
|
||||
let tls_cache = tls_cache.clone();
|
||||
let ip_tracker = ip_tracker.clone();
|
||||
let beobachten = beobachten.clone();
|
||||
let shared = shared.clone();
|
||||
let max_connections_tcp = max_connections.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
@@ -400,6 +408,7 @@ pub(crate) fn spawn_tcp_accept_loops(
|
||||
break;
|
||||
}
|
||||
Err(_) => {
|
||||
stats.increment_accept_permit_timeout_total();
|
||||
debug!(
|
||||
peer = %peer_addr,
|
||||
timeout_ms = accept_permit_timeout_ms,
|
||||
@@ -421,13 +430,14 @@ pub(crate) fn spawn_tcp_accept_loops(
|
||||
let tls_cache = tls_cache.clone();
|
||||
let ip_tracker = ip_tracker.clone();
|
||||
let beobachten = beobachten.clone();
|
||||
let shared = shared.clone();
|
||||
let proxy_protocol_enabled = listener_proxy_protocol;
|
||||
let real_peer_report = Arc::new(std::sync::Mutex::new(None));
|
||||
let real_peer_report_for_handler = real_peer_report.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let _permit = permit;
|
||||
if let Err(e) = ClientHandler::new(
|
||||
if let Err(e) = ClientHandler::new_with_shared(
|
||||
stream,
|
||||
peer_addr,
|
||||
config,
|
||||
@@ -441,6 +451,7 @@ pub(crate) fn spawn_tcp_accept_loops(
|
||||
tls_cache,
|
||||
ip_tracker,
|
||||
beobachten,
|
||||
shared,
|
||||
proxy_protocol_enabled,
|
||||
real_peer_report_for_handler,
|
||||
)
|
||||
|
||||
+112
-7
@@ -29,10 +29,12 @@ use tracing_subscriber::{EnvFilter, fmt, prelude::*, reload};
|
||||
|
||||
use crate::api;
|
||||
use crate::config::{LogLevel, ProxyConfig};
|
||||
use crate::conntrack_control;
|
||||
use crate::crypto::SecureRandom;
|
||||
use crate::ip_tracker::UserIpTracker;
|
||||
use crate::network::probe::{decide_network_capabilities, log_probe_result, run_probe};
|
||||
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
|
||||
use crate::proxy::shared_state::ProxySharedState;
|
||||
use crate::startup::{
|
||||
COMPONENT_API_BOOTSTRAP, COMPONENT_CONFIG_LOAD, COMPONENT_ME_POOL_CONSTRUCT,
|
||||
COMPONENT_ME_POOL_INIT_STAGE1, COMPONENT_ME_PROXY_CONFIG_V4, COMPONENT_ME_PROXY_CONFIG_V6,
|
||||
@@ -110,6 +112,7 @@ async fn run_inner(
|
||||
.await;
|
||||
let cli_args = parse_cli();
|
||||
let config_path_cli = cli_args.config_path;
|
||||
let config_path_explicit = cli_args.config_path_explicit;
|
||||
let data_path = cli_args.data_path;
|
||||
let cli_silent = cli_args.silent;
|
||||
let cli_log_level = cli_args.log_level;
|
||||
@@ -121,7 +124,8 @@ async fn run_inner(
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
let config_path = resolve_runtime_config_path(&config_path_cli, &startup_cwd);
|
||||
let mut config_path =
|
||||
resolve_runtime_config_path(&config_path_cli, &startup_cwd, config_path_explicit);
|
||||
|
||||
let mut config = match ProxyConfig::load(&config_path) {
|
||||
Ok(c) => c,
|
||||
@@ -131,11 +135,99 @@ async fn run_inner(
|
||||
std::process::exit(1);
|
||||
} else {
|
||||
let default = ProxyConfig::default();
|
||||
std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
|
||||
eprintln!(
|
||||
"[telemt] Created default config at {}",
|
||||
config_path.display()
|
||||
);
|
||||
|
||||
let serialized =
|
||||
match toml::to_string_pretty(&default).or_else(|_| toml::to_string(&default)) {
|
||||
Ok(value) => Some(value),
|
||||
Err(serialize_error) => {
|
||||
eprintln!(
|
||||
"[telemt] Warning: failed to serialize default config: {}",
|
||||
serialize_error
|
||||
);
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
if config_path_explicit {
|
||||
if let Some(serialized) = serialized.as_ref() {
|
||||
if let Err(write_error) = std::fs::write(&config_path, serialized) {
|
||||
eprintln!(
|
||||
"[telemt] Error: failed to create explicit config at {}: {}",
|
||||
config_path.display(),
|
||||
write_error
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
eprintln!(
|
||||
"[telemt] Created default config at {}",
|
||||
config_path.display()
|
||||
);
|
||||
} else {
|
||||
eprintln!(
|
||||
"[telemt] Warning: running with in-memory default config without writing to disk"
|
||||
);
|
||||
}
|
||||
} else {
|
||||
let system_dir = std::path::Path::new("/etc/telemt");
|
||||
let system_config_path = system_dir.join("telemt.toml");
|
||||
let startup_config_path = startup_cwd.join("config.toml");
|
||||
let mut persisted = false;
|
||||
|
||||
if let Some(serialized) = serialized.as_ref() {
|
||||
match std::fs::create_dir_all(system_dir) {
|
||||
Ok(()) => match std::fs::write(&system_config_path, serialized) {
|
||||
Ok(()) => {
|
||||
config_path = system_config_path;
|
||||
eprintln!(
|
||||
"[telemt] Created default config at {}",
|
||||
config_path.display()
|
||||
);
|
||||
persisted = true;
|
||||
}
|
||||
Err(write_error) => {
|
||||
eprintln!(
|
||||
"[telemt] Warning: failed to write default config at {}: {}",
|
||||
system_config_path.display(),
|
||||
write_error
|
||||
);
|
||||
}
|
||||
},
|
||||
Err(create_error) => {
|
||||
eprintln!(
|
||||
"[telemt] Warning: failed to create {}: {}",
|
||||
system_dir.display(),
|
||||
create_error
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if !persisted {
|
||||
match std::fs::write(&startup_config_path, serialized) {
|
||||
Ok(()) => {
|
||||
config_path = startup_config_path;
|
||||
eprintln!(
|
||||
"[telemt] Created default config at {}",
|
||||
config_path.display()
|
||||
);
|
||||
persisted = true;
|
||||
}
|
||||
Err(write_error) => {
|
||||
eprintln!(
|
||||
"[telemt] Warning: failed to write default config at {}: {}",
|
||||
startup_config_path.display(),
|
||||
write_error
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !persisted {
|
||||
eprintln!(
|
||||
"[telemt] Warning: running with in-memory default config without writing to disk"
|
||||
);
|
||||
}
|
||||
}
|
||||
default
|
||||
}
|
||||
}
|
||||
@@ -631,6 +723,12 @@ async fn run_inner(
|
||||
)
|
||||
.await;
|
||||
let _admission_tx_hold = admission_tx;
|
||||
let shared_state = ProxySharedState::new();
|
||||
conntrack_control::spawn_conntrack_controller(
|
||||
config_rx.clone(),
|
||||
stats.clone(),
|
||||
shared_state.clone(),
|
||||
);
|
||||
|
||||
let bound = listeners::bind_listeners(
|
||||
&config,
|
||||
@@ -651,6 +749,7 @@ async fn run_inner(
|
||||
tls_cache.clone(),
|
||||
ip_tracker.clone(),
|
||||
beobachten.clone(),
|
||||
shared_state.clone(),
|
||||
max_connections.clone(),
|
||||
)
|
||||
.await?;
|
||||
@@ -664,7 +763,11 @@ async fn run_inner(
|
||||
|
||||
// Drop privileges after binding sockets (which may require root for port < 1024)
|
||||
if daemon_opts.user.is_some() || daemon_opts.group.is_some() {
|
||||
if let Err(e) = drop_privileges(daemon_opts.user.as_deref(), daemon_opts.group.as_deref()) {
|
||||
if let Err(e) = drop_privileges(
|
||||
daemon_opts.user.as_deref(),
|
||||
daemon_opts.group.as_deref(),
|
||||
_pid_file.as_ref(),
|
||||
) {
|
||||
error!(error = %e, "Failed to drop privileges");
|
||||
std::process::exit(1);
|
||||
}
|
||||
@@ -683,6 +786,7 @@ async fn run_inner(
|
||||
&startup_tracker,
|
||||
stats.clone(),
|
||||
beobachten.clone(),
|
||||
shared_state.clone(),
|
||||
ip_tracker.clone(),
|
||||
config_rx.clone(),
|
||||
)
|
||||
@@ -707,6 +811,7 @@ async fn run_inner(
|
||||
tls_cache.clone(),
|
||||
ip_tracker.clone(),
|
||||
beobachten.clone(),
|
||||
shared_state,
|
||||
max_connections.clone(),
|
||||
);
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ use crate::crypto::SecureRandom;
|
||||
use crate::ip_tracker::UserIpTracker;
|
||||
use crate::metrics;
|
||||
use crate::network::probe::NetworkProbe;
|
||||
use crate::proxy::shared_state::ProxySharedState;
|
||||
use crate::startup::{
|
||||
COMPONENT_CONFIG_WATCHER_START, COMPONENT_METRICS_START, COMPONENT_RUNTIME_READY,
|
||||
StartupTracker,
|
||||
@@ -287,6 +288,7 @@ pub(crate) async fn spawn_metrics_if_configured(
|
||||
startup_tracker: &Arc<StartupTracker>,
|
||||
stats: Arc<Stats>,
|
||||
beobachten: Arc<BeobachtenStore>,
|
||||
shared_state: Arc<ProxySharedState>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
config_rx: watch::Receiver<Arc<ProxyConfig>>,
|
||||
) {
|
||||
@@ -320,6 +322,7 @@ pub(crate) async fn spawn_metrics_if_configured(
|
||||
.await;
|
||||
let stats = stats.clone();
|
||||
let beobachten = beobachten.clone();
|
||||
let shared_state = shared_state.clone();
|
||||
let config_rx_metrics = config_rx.clone();
|
||||
let ip_tracker_metrics = ip_tracker.clone();
|
||||
let whitelist = config.server.metrics_whitelist.clone();
|
||||
@@ -331,6 +334,7 @@ pub(crate) async fn spawn_metrics_if_configured(
|
||||
listen_backlog,
|
||||
stats,
|
||||
beobachten,
|
||||
shared_state,
|
||||
ip_tracker_metrics,
|
||||
config_rx_metrics,
|
||||
whitelist,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
mod api;
|
||||
mod cli;
|
||||
mod config;
|
||||
mod conntrack_control;
|
||||
mod crypto;
|
||||
#[cfg(unix)]
|
||||
mod daemon;
|
||||
|
||||
+345
-14
@@ -15,6 +15,7 @@ use tracing::{debug, info, warn};
|
||||
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::ip_tracker::UserIpTracker;
|
||||
use crate::proxy::shared_state::ProxySharedState;
|
||||
use crate::stats::Stats;
|
||||
use crate::stats::beobachten::BeobachtenStore;
|
||||
use crate::transport::{ListenOptions, create_listener};
|
||||
@@ -25,6 +26,7 @@ pub async fn serve(
|
||||
listen_backlog: u32,
|
||||
stats: Arc<Stats>,
|
||||
beobachten: Arc<BeobachtenStore>,
|
||||
shared_state: Arc<ProxySharedState>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
config_rx: tokio::sync::watch::Receiver<Arc<ProxyConfig>>,
|
||||
whitelist: Vec<IpNetwork>,
|
||||
@@ -45,7 +47,13 @@ pub async fn serve(
|
||||
Ok(listener) => {
|
||||
info!("Metrics endpoint: http://{}/metrics and /beobachten", addr);
|
||||
serve_listener(
|
||||
listener, stats, beobachten, ip_tracker, config_rx, whitelist,
|
||||
listener,
|
||||
stats,
|
||||
beobachten,
|
||||
shared_state,
|
||||
ip_tracker,
|
||||
config_rx,
|
||||
whitelist,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@@ -94,13 +102,20 @@ pub async fn serve(
|
||||
}
|
||||
(Some(listener), None) | (None, Some(listener)) => {
|
||||
serve_listener(
|
||||
listener, stats, beobachten, ip_tracker, config_rx, whitelist,
|
||||
listener,
|
||||
stats,
|
||||
beobachten,
|
||||
shared_state,
|
||||
ip_tracker,
|
||||
config_rx,
|
||||
whitelist,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
(Some(listener4), Some(listener6)) => {
|
||||
let stats_v6 = stats.clone();
|
||||
let beobachten_v6 = beobachten.clone();
|
||||
let shared_state_v6 = shared_state.clone();
|
||||
let ip_tracker_v6 = ip_tracker.clone();
|
||||
let config_rx_v6 = config_rx.clone();
|
||||
let whitelist_v6 = whitelist.clone();
|
||||
@@ -109,6 +124,7 @@ pub async fn serve(
|
||||
listener6,
|
||||
stats_v6,
|
||||
beobachten_v6,
|
||||
shared_state_v6,
|
||||
ip_tracker_v6,
|
||||
config_rx_v6,
|
||||
whitelist_v6,
|
||||
@@ -116,7 +132,13 @@ pub async fn serve(
|
||||
.await;
|
||||
});
|
||||
serve_listener(
|
||||
listener4, stats, beobachten, ip_tracker, config_rx, whitelist,
|
||||
listener4,
|
||||
stats,
|
||||
beobachten,
|
||||
shared_state,
|
||||
ip_tracker,
|
||||
config_rx,
|
||||
whitelist,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
@@ -142,6 +164,7 @@ async fn serve_listener(
|
||||
listener: TcpListener,
|
||||
stats: Arc<Stats>,
|
||||
beobachten: Arc<BeobachtenStore>,
|
||||
shared_state: Arc<ProxySharedState>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
config_rx: tokio::sync::watch::Receiver<Arc<ProxyConfig>>,
|
||||
whitelist: Arc<Vec<IpNetwork>>,
|
||||
@@ -162,15 +185,19 @@ async fn serve_listener(
|
||||
|
||||
let stats = stats.clone();
|
||||
let beobachten = beobachten.clone();
|
||||
let shared_state = shared_state.clone();
|
||||
let ip_tracker = ip_tracker.clone();
|
||||
let config_rx_conn = config_rx.clone();
|
||||
tokio::spawn(async move {
|
||||
let svc = service_fn(move |req| {
|
||||
let stats = stats.clone();
|
||||
let beobachten = beobachten.clone();
|
||||
let shared_state = shared_state.clone();
|
||||
let ip_tracker = ip_tracker.clone();
|
||||
let config = config_rx_conn.borrow().clone();
|
||||
async move { handle(req, &stats, &beobachten, &ip_tracker, &config).await }
|
||||
async move {
|
||||
handle(req, &stats, &beobachten, &shared_state, &ip_tracker, &config).await
|
||||
}
|
||||
});
|
||||
if let Err(e) = http1::Builder::new()
|
||||
.serve_connection(hyper_util::rt::TokioIo::new(stream), svc)
|
||||
@@ -186,11 +213,12 @@ async fn handle<B>(
|
||||
req: Request<B>,
|
||||
stats: &Stats,
|
||||
beobachten: &BeobachtenStore,
|
||||
shared_state: &ProxySharedState,
|
||||
ip_tracker: &UserIpTracker,
|
||||
config: &ProxyConfig,
|
||||
) -> Result<Response<Full<Bytes>>, Infallible> {
|
||||
if req.uri().path() == "/metrics" {
|
||||
let body = render_metrics(stats, config, ip_tracker).await;
|
||||
let body = render_metrics(stats, shared_state, config, ip_tracker).await;
|
||||
let resp = Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header("content-type", "text/plain; version=0.0.4; charset=utf-8")
|
||||
@@ -225,7 +253,12 @@ fn render_beobachten(beobachten: &BeobachtenStore, config: &ProxyConfig) -> Stri
|
||||
beobachten.snapshot_text(ttl)
|
||||
}
|
||||
|
||||
async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIpTracker) -> String {
|
||||
async fn render_metrics(
|
||||
stats: &Stats,
|
||||
shared_state: &ProxySharedState,
|
||||
config: &ProxyConfig,
|
||||
ip_tracker: &UserIpTracker,
|
||||
) -> String {
|
||||
use std::fmt::Write;
|
||||
let mut out = String::with_capacity(4096);
|
||||
let telemetry = stats.telemetry_policy();
|
||||
@@ -304,6 +337,27 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
|
||||
}
|
||||
);
|
||||
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"# HELP telemt_buffer_pool_buffers_total Snapshot of pooled and allocated buffers"
|
||||
);
|
||||
let _ = writeln!(out, "# TYPE telemt_buffer_pool_buffers_total gauge");
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_buffer_pool_buffers_total{{kind=\"pooled\"}} {}",
|
||||
stats.get_buffer_pool_pooled_gauge()
|
||||
);
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_buffer_pool_buffers_total{{kind=\"allocated\"}} {}",
|
||||
stats.get_buffer_pool_allocated_gauge()
|
||||
);
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_buffer_pool_buffers_total{{kind=\"in_use\"}} {}",
|
||||
stats.get_buffer_pool_in_use_gauge()
|
||||
);
|
||||
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"# HELP telemt_connections_total Total accepted connections"
|
||||
@@ -349,6 +403,170 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
|
||||
}
|
||||
);
|
||||
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"# HELP telemt_auth_expensive_checks_total Expensive authentication candidate checks executed during handshake validation"
|
||||
);
|
||||
let _ = writeln!(out, "# TYPE telemt_auth_expensive_checks_total counter");
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_auth_expensive_checks_total {}",
|
||||
if core_enabled {
|
||||
shared_state
|
||||
.handshake
|
||||
.auth_expensive_checks_total
|
||||
.load(std::sync::atomic::Ordering::Relaxed)
|
||||
} else {
|
||||
0
|
||||
}
|
||||
);
|
||||
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"# HELP telemt_auth_budget_exhausted_total Handshake validations that hit authentication candidate budget limits"
|
||||
);
|
||||
let _ = writeln!(out, "# TYPE telemt_auth_budget_exhausted_total counter");
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_auth_budget_exhausted_total {}",
|
||||
if core_enabled {
|
||||
shared_state
|
||||
.handshake
|
||||
.auth_budget_exhausted_total
|
||||
.load(std::sync::atomic::Ordering::Relaxed)
|
||||
} else {
|
||||
0
|
||||
}
|
||||
);
|
||||
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"# HELP telemt_accept_permit_timeout_total Accepted connections dropped due to permit wait timeout"
|
||||
);
|
||||
let _ = writeln!(out, "# TYPE telemt_accept_permit_timeout_total counter");
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_accept_permit_timeout_total {}",
|
||||
if core_enabled {
|
||||
stats.get_accept_permit_timeout_total()
|
||||
} else {
|
||||
0
|
||||
}
|
||||
);
|
||||
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"# HELP telemt_conntrack_control_state Runtime conntrack control state flags"
|
||||
);
|
||||
let _ = writeln!(out, "# TYPE telemt_conntrack_control_state gauge");
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_conntrack_control_state{{flag=\"enabled\"}} {}",
|
||||
if stats.get_conntrack_control_enabled() {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
);
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_conntrack_control_state{{flag=\"available\"}} {}",
|
||||
if stats.get_conntrack_control_available() {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
);
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_conntrack_control_state{{flag=\"pressure_active\"}} {}",
|
||||
if stats.get_conntrack_pressure_active() {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
);
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_conntrack_control_state{{flag=\"rule_apply_ok\"}} {}",
|
||||
if stats.get_conntrack_rule_apply_ok() {
|
||||
1
|
||||
} else {
|
||||
0
|
||||
}
|
||||
);
|
||||
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"# HELP telemt_conntrack_event_queue_depth Pending close events in conntrack control queue"
|
||||
);
|
||||
let _ = writeln!(out, "# TYPE telemt_conntrack_event_queue_depth gauge");
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_conntrack_event_queue_depth {}",
|
||||
stats.get_conntrack_event_queue_depth()
|
||||
);
|
||||
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"# HELP telemt_conntrack_delete_total Conntrack delete attempts by outcome"
|
||||
);
|
||||
let _ = writeln!(out, "# TYPE telemt_conntrack_delete_total counter");
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_conntrack_delete_total{{result=\"attempt\"}} {}",
|
||||
if core_enabled {
|
||||
stats.get_conntrack_delete_attempt_total()
|
||||
} else {
|
||||
0
|
||||
}
|
||||
);
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_conntrack_delete_total{{result=\"success\"}} {}",
|
||||
if core_enabled {
|
||||
stats.get_conntrack_delete_success_total()
|
||||
} else {
|
||||
0
|
||||
}
|
||||
);
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_conntrack_delete_total{{result=\"not_found\"}} {}",
|
||||
if core_enabled {
|
||||
stats.get_conntrack_delete_not_found_total()
|
||||
} else {
|
||||
0
|
||||
}
|
||||
);
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_conntrack_delete_total{{result=\"error\"}} {}",
|
||||
if core_enabled {
|
||||
stats.get_conntrack_delete_error_total()
|
||||
} else {
|
||||
0
|
||||
}
|
||||
);
|
||||
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"# HELP telemt_conntrack_close_event_drop_total Dropped conntrack close events due to queue pressure or unavailable sender"
|
||||
);
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"# TYPE telemt_conntrack_close_event_drop_total counter"
|
||||
);
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_conntrack_close_event_drop_total {}",
|
||||
if core_enabled {
|
||||
stats.get_conntrack_close_event_drop_total()
|
||||
} else {
|
||||
0
|
||||
}
|
||||
);
|
||||
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"# HELP telemt_upstream_connect_attempt_total Upstream connect attempts across all requests"
|
||||
@@ -952,6 +1170,39 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
|
||||
}
|
||||
);
|
||||
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"# HELP telemt_me_c2me_enqueue_events_total ME client->ME enqueue outcomes"
|
||||
);
|
||||
let _ = writeln!(out, "# TYPE telemt_me_c2me_enqueue_events_total counter");
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_me_c2me_enqueue_events_total{{event=\"full\"}} {}",
|
||||
if me_allows_normal {
|
||||
stats.get_me_c2me_send_full_total()
|
||||
} else {
|
||||
0
|
||||
}
|
||||
);
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_me_c2me_enqueue_events_total{{event=\"high_water\"}} {}",
|
||||
if me_allows_normal {
|
||||
stats.get_me_c2me_send_high_water_total()
|
||||
} else {
|
||||
0
|
||||
}
|
||||
);
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_me_c2me_enqueue_events_total{{event=\"timeout\"}} {}",
|
||||
if me_allows_normal {
|
||||
stats.get_me_c2me_send_timeout_total()
|
||||
} else {
|
||||
0
|
||||
}
|
||||
);
|
||||
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"# HELP telemt_me_d2c_batches_total Total DC->Client flush batches"
|
||||
@@ -2501,6 +2752,48 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
|
||||
if user_enabled { 0 } else { 1 }
|
||||
);
|
||||
|
||||
let ip_memory = ip_tracker.memory_stats().await;
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"# HELP telemt_ip_tracker_users Number of users tracked by IP limiter state"
|
||||
);
|
||||
let _ = writeln!(out, "# TYPE telemt_ip_tracker_users gauge");
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_ip_tracker_users{{scope=\"active\"}} {}",
|
||||
ip_memory.active_users
|
||||
);
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_ip_tracker_users{{scope=\"recent\"}} {}",
|
||||
ip_memory.recent_users
|
||||
);
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"# HELP telemt_ip_tracker_entries Number of IP entries tracked by limiter state"
|
||||
);
|
||||
let _ = writeln!(out, "# TYPE telemt_ip_tracker_entries gauge");
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_ip_tracker_entries{{scope=\"active\"}} {}",
|
||||
ip_memory.active_entries
|
||||
);
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_ip_tracker_entries{{scope=\"recent\"}} {}",
|
||||
ip_memory.recent_entries
|
||||
);
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"# HELP telemt_ip_tracker_cleanup_queue_len Deferred disconnect cleanup queue length"
|
||||
);
|
||||
let _ = writeln!(out, "# TYPE telemt_ip_tracker_cleanup_queue_len gauge");
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_ip_tracker_cleanup_queue_len {}",
|
||||
ip_memory.cleanup_queue_len
|
||||
);
|
||||
|
||||
if user_enabled {
|
||||
for entry in stats.iter_user_stats() {
|
||||
let user = entry.key();
|
||||
@@ -2634,6 +2927,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_render_metrics_format() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let shared_state = ProxySharedState::new();
|
||||
let tracker = UserIpTracker::new();
|
||||
let mut config = ProxyConfig::default();
|
||||
config
|
||||
@@ -2645,6 +2939,14 @@ mod tests {
|
||||
stats.increment_connects_all();
|
||||
stats.increment_connects_bad();
|
||||
stats.increment_handshake_timeouts();
|
||||
shared_state
|
||||
.handshake
|
||||
.auth_expensive_checks_total
|
||||
.fetch_add(9, std::sync::atomic::Ordering::Relaxed);
|
||||
shared_state
|
||||
.handshake
|
||||
.auth_budget_exhausted_total
|
||||
.fetch_add(2, std::sync::atomic::Ordering::Relaxed);
|
||||
stats.increment_upstream_connect_attempt_total();
|
||||
stats.increment_upstream_connect_attempt_total();
|
||||
stats.increment_upstream_connect_success_total();
|
||||
@@ -2688,7 +2990,7 @@ mod tests {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let output = render_metrics(&stats, &config, &tracker).await;
|
||||
let output = render_metrics(&stats, shared_state.as_ref(), &config, &tracker).await;
|
||||
|
||||
assert!(output.contains(&format!(
|
||||
"telemt_build_info{{version=\"{}\"}} 1",
|
||||
@@ -2697,6 +2999,8 @@ mod tests {
|
||||
assert!(output.contains("telemt_connections_total 2"));
|
||||
assert!(output.contains("telemt_connections_bad_total 1"));
|
||||
assert!(output.contains("telemt_handshake_timeouts_total 1"));
|
||||
assert!(output.contains("telemt_auth_expensive_checks_total 9"));
|
||||
assert!(output.contains("telemt_auth_budget_exhausted_total 2"));
|
||||
assert!(output.contains("telemt_upstream_connect_attempt_total 2"));
|
||||
assert!(output.contains("telemt_upstream_connect_success_total 1"));
|
||||
assert!(output.contains("telemt_upstream_connect_fail_total 1"));
|
||||
@@ -2743,17 +3047,23 @@ mod tests {
|
||||
assert!(output.contains("telemt_user_unique_ips_recent_window{user=\"alice\"} 1"));
|
||||
assert!(output.contains("telemt_user_unique_ips_limit{user=\"alice\"} 4"));
|
||||
assert!(output.contains("telemt_user_unique_ips_utilization{user=\"alice\"} 0.250000"));
|
||||
assert!(output.contains("telemt_ip_tracker_users{scope=\"active\"} 1"));
|
||||
assert!(output.contains("telemt_ip_tracker_entries{scope=\"active\"} 1"));
|
||||
assert!(output.contains("telemt_ip_tracker_cleanup_queue_len 0"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_render_empty_stats() {
|
||||
let stats = Stats::new();
|
||||
let shared_state = ProxySharedState::new();
|
||||
let tracker = UserIpTracker::new();
|
||||
let config = ProxyConfig::default();
|
||||
let output = render_metrics(&stats, &config, &tracker).await;
|
||||
let output = render_metrics(&stats, &shared_state, &config, &tracker).await;
|
||||
assert!(output.contains("telemt_connections_total 0"));
|
||||
assert!(output.contains("telemt_connections_bad_total 0"));
|
||||
assert!(output.contains("telemt_handshake_timeouts_total 0"));
|
||||
assert!(output.contains("telemt_auth_expensive_checks_total 0"));
|
||||
assert!(output.contains("telemt_auth_budget_exhausted_total 0"));
|
||||
assert!(output.contains("telemt_user_unique_ips_current{user="));
|
||||
assert!(output.contains("telemt_user_unique_ips_recent_window{user="));
|
||||
}
|
||||
@@ -2761,6 +3071,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_render_uses_global_each_unique_ip_limit() {
|
||||
let stats = Stats::new();
|
||||
let shared_state = ProxySharedState::new();
|
||||
stats.increment_user_connects("alice");
|
||||
stats.increment_user_curr_connects("alice");
|
||||
let tracker = UserIpTracker::new();
|
||||
@@ -2771,7 +3082,7 @@ mod tests {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.access.user_max_unique_ips_global_each = 2;
|
||||
|
||||
let output = render_metrics(&stats, &config, &tracker).await;
|
||||
let output = render_metrics(&stats, &shared_state, &config, &tracker).await;
|
||||
|
||||
assert!(output.contains("telemt_user_unique_ips_limit{user=\"alice\"} 2"));
|
||||
assert!(output.contains("telemt_user_unique_ips_utilization{user=\"alice\"} 0.500000"));
|
||||
@@ -2780,14 +3091,16 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_render_has_type_annotations() {
|
||||
let stats = Stats::new();
|
||||
let shared_state = ProxySharedState::new();
|
||||
let tracker = UserIpTracker::new();
|
||||
let config = ProxyConfig::default();
|
||||
let output = render_metrics(&stats, &config, &tracker).await;
|
||||
assert!(output.contains("# TYPE telemt_build_info gauge"));
|
||||
let output = render_metrics(&stats, &shared_state, &config, &tracker).await;
|
||||
assert!(output.contains("# TYPE telemt_uptime_seconds gauge"));
|
||||
assert!(output.contains("# TYPE telemt_connections_total counter"));
|
||||
assert!(output.contains("# TYPE telemt_connections_bad_total counter"));
|
||||
assert!(output.contains("# TYPE telemt_handshake_timeouts_total counter"));
|
||||
assert!(output.contains("# TYPE telemt_auth_expensive_checks_total counter"));
|
||||
assert!(output.contains("# TYPE telemt_auth_budget_exhausted_total counter"));
|
||||
assert!(output.contains("# TYPE telemt_upstream_connect_attempt_total counter"));
|
||||
assert!(output.contains("# TYPE telemt_me_rpc_proxy_req_signal_sent_total counter"));
|
||||
assert!(output.contains("# TYPE telemt_me_idle_close_by_peer_total counter"));
|
||||
@@ -2815,12 +3128,16 @@ mod tests {
|
||||
assert!(output.contains("# TYPE telemt_user_unique_ips_recent_window gauge"));
|
||||
assert!(output.contains("# TYPE telemt_user_unique_ips_limit gauge"));
|
||||
assert!(output.contains("# TYPE telemt_user_unique_ips_utilization gauge"));
|
||||
assert!(output.contains("# TYPE telemt_ip_tracker_users gauge"));
|
||||
assert!(output.contains("# TYPE telemt_ip_tracker_entries gauge"));
|
||||
assert!(output.contains("# TYPE telemt_ip_tracker_cleanup_queue_len gauge"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_endpoint_integration() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let beobachten = Arc::new(BeobachtenStore::new());
|
||||
let shared_state = ProxySharedState::new();
|
||||
let tracker = UserIpTracker::new();
|
||||
let mut config = ProxyConfig::default();
|
||||
stats.increment_connects_all();
|
||||
@@ -2828,7 +3145,7 @@ mod tests {
|
||||
stats.increment_connects_all();
|
||||
|
||||
let req = Request::builder().uri("/metrics").body(()).unwrap();
|
||||
let resp = handle(req, &stats, &beobachten, &tracker, &config)
|
||||
let resp = handle(req, &stats, &beobachten, shared_state.as_ref(), &tracker, &config)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), StatusCode::OK);
|
||||
@@ -2855,7 +3172,14 @@ mod tests {
|
||||
Duration::from_secs(600),
|
||||
);
|
||||
let req_beob = Request::builder().uri("/beobachten").body(()).unwrap();
|
||||
let resp_beob = handle(req_beob, &stats, &beobachten, &tracker, &config)
|
||||
let resp_beob = handle(
|
||||
req_beob,
|
||||
&stats,
|
||||
&beobachten,
|
||||
shared_state.as_ref(),
|
||||
&tracker,
|
||||
&config,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp_beob.status(), StatusCode::OK);
|
||||
@@ -2865,7 +3189,14 @@ mod tests {
|
||||
assert!(beob_text.contains("203.0.113.10-1"));
|
||||
|
||||
let req404 = Request::builder().uri("/other").body(()).unwrap();
|
||||
let resp404 = handle(req404, &stats, &beobachten, &tracker, &config)
|
||||
let resp404 = handle(
|
||||
req404,
|
||||
&stats,
|
||||
&beobachten,
|
||||
shared_state.as_ref(),
|
||||
&tracker,
|
||||
&config,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp404.status(), StatusCode::NOT_FOUND);
|
||||
|
||||
@@ -24,6 +24,8 @@ const DIRECT_S2C_CAP_BYTES: usize = 512 * 1024;
|
||||
const ME_FRAMES_CAP: usize = 96;
|
||||
const ME_BYTES_CAP: usize = 384 * 1024;
|
||||
const ME_DELAY_MIN_US: u64 = 150;
|
||||
const MAX_USER_PROFILES_ENTRIES: usize = 50_000;
|
||||
const MAX_USER_KEY_BYTES: usize = 512;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
|
||||
pub enum AdaptiveTier {
|
||||
@@ -234,32 +236,50 @@ fn profiles() -> &'static DashMap<String, UserAdaptiveProfile> {
|
||||
}
|
||||
|
||||
pub fn seed_tier_for_user(user: &str) -> AdaptiveTier {
|
||||
if user.len() > MAX_USER_KEY_BYTES {
|
||||
return AdaptiveTier::Base;
|
||||
}
|
||||
let now = Instant::now();
|
||||
if let Some(entry) = profiles().get(user) {
|
||||
let value = entry.value();
|
||||
if now.duration_since(value.seen_at) <= PROFILE_TTL {
|
||||
let value = *entry.value();
|
||||
drop(entry);
|
||||
if now.saturating_duration_since(value.seen_at) <= PROFILE_TTL {
|
||||
return value.tier;
|
||||
}
|
||||
profiles().remove_if(user, |_, v| {
|
||||
now.saturating_duration_since(v.seen_at) > PROFILE_TTL
|
||||
});
|
||||
}
|
||||
AdaptiveTier::Base
|
||||
}
|
||||
|
||||
pub fn record_user_tier(user: &str, tier: AdaptiveTier) {
|
||||
let now = Instant::now();
|
||||
if let Some(mut entry) = profiles().get_mut(user) {
|
||||
let existing = *entry;
|
||||
let effective = if now.duration_since(existing.seen_at) > PROFILE_TTL {
|
||||
tier
|
||||
} else {
|
||||
max(existing.tier, tier)
|
||||
};
|
||||
*entry = UserAdaptiveProfile {
|
||||
tier: effective,
|
||||
seen_at: now,
|
||||
};
|
||||
if user.len() > MAX_USER_KEY_BYTES {
|
||||
return;
|
||||
}
|
||||
profiles().insert(user.to_string(), UserAdaptiveProfile { tier, seen_at: now });
|
||||
let now = Instant::now();
|
||||
let mut was_vacant = false;
|
||||
match profiles().entry(user.to_string()) {
|
||||
dashmap::mapref::entry::Entry::Occupied(mut entry) => {
|
||||
let existing = *entry.get();
|
||||
let effective = if now.saturating_duration_since(existing.seen_at) > PROFILE_TTL {
|
||||
tier
|
||||
} else {
|
||||
max(existing.tier, tier)
|
||||
};
|
||||
entry.insert(UserAdaptiveProfile {
|
||||
tier: effective,
|
||||
seen_at: now,
|
||||
});
|
||||
}
|
||||
dashmap::mapref::entry::Entry::Vacant(slot) => {
|
||||
slot.insert(UserAdaptiveProfile { tier, seen_at: now });
|
||||
was_vacant = true;
|
||||
}
|
||||
}
|
||||
if was_vacant && profiles().len() > MAX_USER_PROFILES_ENTRIES {
|
||||
profiles().retain(|_, v| now.saturating_duration_since(v.seen_at) <= PROFILE_TTL);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn direct_copy_buffers_for_tier(
|
||||
@@ -310,6 +330,14 @@ fn scale(base: usize, numerator: usize, denominator: usize, cap: usize) -> usize
|
||||
scaled.min(cap).max(1)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/adaptive_buffers_security_tests.rs"]
|
||||
mod adaptive_buffers_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/adaptive_buffers_record_race_security_tests.rs"]
|
||||
mod adaptive_buffers_record_race_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
+185
-21
@@ -80,11 +80,16 @@ use crate::transport::middle_proxy::MePool;
|
||||
use crate::transport::socket::normalize_ip;
|
||||
use crate::transport::{UpstreamManager, configure_client_socket, parse_proxy_protocol};
|
||||
|
||||
use crate::proxy::direct_relay::handle_via_direct;
|
||||
use crate::proxy::handshake::{HandshakeSuccess, handle_mtproto_handshake, handle_tls_handshake};
|
||||
use crate::proxy::direct_relay::handle_via_direct_with_shared;
|
||||
use crate::proxy::handshake::{
|
||||
HandshakeSuccess, handle_mtproto_handshake_with_shared, handle_tls_handshake_with_shared,
|
||||
};
|
||||
#[cfg(test)]
|
||||
use crate::proxy::handshake::{handle_mtproto_handshake, handle_tls_handshake};
|
||||
use crate::proxy::masking::handle_bad_client;
|
||||
use crate::proxy::middle_relay::handle_via_middle_proxy;
|
||||
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
|
||||
use crate::proxy::shared_state::ProxySharedState;
|
||||
|
||||
fn beobachten_ttl(config: &ProxyConfig) -> Duration {
|
||||
const BEOBACHTEN_TTL_MAX_MINUTES: u64 = 24 * 60;
|
||||
@@ -186,6 +191,24 @@ fn handshake_timeout_with_mask_grace(config: &ProxyConfig) -> Duration {
|
||||
}
|
||||
}
|
||||
|
||||
fn effective_client_first_byte_idle_secs(config: &ProxyConfig, shared: &ProxySharedState) -> u64 {
|
||||
let idle_secs = config.timeouts.client_first_byte_idle_secs;
|
||||
if idle_secs == 0 {
|
||||
return 0;
|
||||
}
|
||||
if shared.conntrack_pressure_active() {
|
||||
idle_secs.min(
|
||||
config
|
||||
.server
|
||||
.conntrack_control
|
||||
.profile
|
||||
.client_first_byte_idle_cap_secs(),
|
||||
)
|
||||
} else {
|
||||
idle_secs
|
||||
}
|
||||
}
|
||||
|
||||
const MASK_CLASSIFIER_PREFETCH_WINDOW: usize = 16;
|
||||
#[cfg(test)]
|
||||
const MASK_CLASSIFIER_PREFETCH_TIMEOUT: Duration = Duration::from_millis(5);
|
||||
@@ -342,7 +365,48 @@ fn synthetic_local_addr(port: u16) -> SocketAddr {
|
||||
SocketAddr::from(([0, 0, 0, 0], port))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub async fn handle_client_stream<S>(
|
||||
stream: S,
|
||||
peer: SocketAddr,
|
||||
config: Arc<ProxyConfig>,
|
||||
stats: Arc<Stats>,
|
||||
upstream_manager: Arc<UpstreamManager>,
|
||||
replay_checker: Arc<ReplayChecker>,
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
me_pool: Option<Arc<MePool>>,
|
||||
route_runtime: Arc<RouteRuntimeController>,
|
||||
tls_cache: Option<Arc<TlsFrontCache>>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
beobachten: Arc<BeobachtenStore>,
|
||||
proxy_protocol_enabled: bool,
|
||||
) -> Result<()>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
handle_client_stream_with_shared(
|
||||
stream,
|
||||
peer,
|
||||
config,
|
||||
stats,
|
||||
upstream_manager,
|
||||
replay_checker,
|
||||
buffer_pool,
|
||||
rng,
|
||||
me_pool,
|
||||
route_runtime,
|
||||
tls_cache,
|
||||
ip_tracker,
|
||||
beobachten,
|
||||
ProxySharedState::new(),
|
||||
proxy_protocol_enabled,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn handle_client_stream_with_shared<S>(
|
||||
mut stream: S,
|
||||
peer: SocketAddr,
|
||||
config: Arc<ProxyConfig>,
|
||||
@@ -356,6 +420,7 @@ pub async fn handle_client_stream<S>(
|
||||
tls_cache: Option<Arc<TlsFrontCache>>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
beobachten: Arc<BeobachtenStore>,
|
||||
shared: Arc<ProxySharedState>,
|
||||
proxy_protocol_enabled: bool,
|
||||
) -> Result<()>
|
||||
where
|
||||
@@ -416,10 +481,11 @@ where
|
||||
|
||||
debug!(peer = %real_peer, "New connection (generic stream)");
|
||||
|
||||
let first_byte = if config.timeouts.client_first_byte_idle_secs == 0 {
|
||||
let first_byte_idle_secs = effective_client_first_byte_idle_secs(&config, shared.as_ref());
|
||||
let first_byte = if first_byte_idle_secs == 0 {
|
||||
None
|
||||
} else {
|
||||
let idle_timeout = Duration::from_secs(config.timeouts.client_first_byte_idle_secs);
|
||||
let idle_timeout = Duration::from_secs(first_byte_idle_secs);
|
||||
let mut first_byte = [0u8; 1];
|
||||
match timeout(idle_timeout, stream.read(&mut first_byte)).await {
|
||||
Ok(Ok(0)) => {
|
||||
@@ -455,7 +521,7 @@ where
|
||||
Err(_) => {
|
||||
debug!(
|
||||
peer = %real_peer,
|
||||
idle_secs = config.timeouts.client_first_byte_idle_secs,
|
||||
idle_secs = first_byte_idle_secs,
|
||||
"Closing idle pooled connection before first client byte"
|
||||
);
|
||||
return Ok(());
|
||||
@@ -550,9 +616,10 @@ where
|
||||
|
||||
let (read_half, write_half) = tokio::io::split(stream);
|
||||
|
||||
let (mut tls_reader, tls_writer, tls_user) = match handle_tls_handshake(
|
||||
let (mut tls_reader, tls_writer, tls_user) = match handle_tls_handshake_with_shared(
|
||||
&handshake, read_half, write_half, real_peer,
|
||||
&config, &replay_checker, &rng, tls_cache.clone(),
|
||||
shared.as_ref(),
|
||||
).await {
|
||||
HandshakeResult::Success(result) => result,
|
||||
HandshakeResult::BadClient { reader, writer } => {
|
||||
@@ -578,9 +645,10 @@ where
|
||||
let mtproto_handshake: [u8; HANDSHAKE_LEN] = mtproto_data[..].try_into()
|
||||
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
|
||||
|
||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake_with_shared(
|
||||
&mtproto_handshake, tls_reader, tls_writer, real_peer,
|
||||
&config, &replay_checker, true, Some(tls_user.as_str()),
|
||||
shared.as_ref(),
|
||||
).await {
|
||||
HandshakeResult::Success(result) => result,
|
||||
HandshakeResult::BadClient { reader, writer } => {
|
||||
@@ -614,11 +682,12 @@ where
|
||||
};
|
||||
|
||||
Ok(HandshakeOutcome::NeedsRelay(Box::pin(
|
||||
RunningClientHandler::handle_authenticated_static(
|
||||
RunningClientHandler::handle_authenticated_static_with_shared(
|
||||
crypto_reader, crypto_writer, success,
|
||||
upstream_manager, stats, config, buffer_pool, rng, me_pool,
|
||||
route_runtime.clone(),
|
||||
local_addr, real_peer, ip_tracker.clone(),
|
||||
shared.clone(),
|
||||
),
|
||||
)))
|
||||
} else {
|
||||
@@ -644,9 +713,10 @@ where
|
||||
|
||||
let (read_half, write_half) = tokio::io::split(stream);
|
||||
|
||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake_with_shared(
|
||||
&handshake, read_half, write_half, real_peer,
|
||||
&config, &replay_checker, false, None,
|
||||
shared.as_ref(),
|
||||
).await {
|
||||
HandshakeResult::Success(result) => result,
|
||||
HandshakeResult::BadClient { reader, writer } => {
|
||||
@@ -665,7 +735,7 @@ where
|
||||
};
|
||||
|
||||
Ok(HandshakeOutcome::NeedsRelay(Box::pin(
|
||||
RunningClientHandler::handle_authenticated_static(
|
||||
RunningClientHandler::handle_authenticated_static_with_shared(
|
||||
crypto_reader,
|
||||
crypto_writer,
|
||||
success,
|
||||
@@ -679,6 +749,7 @@ where
|
||||
local_addr,
|
||||
real_peer,
|
||||
ip_tracker.clone(),
|
||||
shared.clone(),
|
||||
)
|
||||
)))
|
||||
}
|
||||
@@ -731,10 +802,12 @@ pub struct RunningClientHandler {
|
||||
tls_cache: Option<Arc<TlsFrontCache>>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
beobachten: Arc<BeobachtenStore>,
|
||||
shared: Arc<ProxySharedState>,
|
||||
proxy_protocol_enabled: bool,
|
||||
}
|
||||
|
||||
impl ClientHandler {
|
||||
#[cfg(test)]
|
||||
pub fn new(
|
||||
stream: TcpStream,
|
||||
peer: SocketAddr,
|
||||
@@ -751,6 +824,45 @@ impl ClientHandler {
|
||||
beobachten: Arc<BeobachtenStore>,
|
||||
proxy_protocol_enabled: bool,
|
||||
real_peer_report: Arc<std::sync::Mutex<Option<SocketAddr>>>,
|
||||
) -> RunningClientHandler {
|
||||
Self::new_with_shared(
|
||||
stream,
|
||||
peer,
|
||||
config,
|
||||
stats,
|
||||
upstream_manager,
|
||||
replay_checker,
|
||||
buffer_pool,
|
||||
rng,
|
||||
me_pool,
|
||||
route_runtime,
|
||||
tls_cache,
|
||||
ip_tracker,
|
||||
beobachten,
|
||||
ProxySharedState::new(),
|
||||
proxy_protocol_enabled,
|
||||
real_peer_report,
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new_with_shared(
|
||||
stream: TcpStream,
|
||||
peer: SocketAddr,
|
||||
config: Arc<ProxyConfig>,
|
||||
stats: Arc<Stats>,
|
||||
upstream_manager: Arc<UpstreamManager>,
|
||||
replay_checker: Arc<ReplayChecker>,
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
me_pool: Option<Arc<MePool>>,
|
||||
route_runtime: Arc<RouteRuntimeController>,
|
||||
tls_cache: Option<Arc<TlsFrontCache>>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
beobachten: Arc<BeobachtenStore>,
|
||||
shared: Arc<ProxySharedState>,
|
||||
proxy_protocol_enabled: bool,
|
||||
real_peer_report: Arc<std::sync::Mutex<Option<SocketAddr>>>,
|
||||
) -> RunningClientHandler {
|
||||
let normalized_peer = normalize_ip(peer);
|
||||
RunningClientHandler {
|
||||
@@ -769,6 +881,7 @@ impl ClientHandler {
|
||||
tls_cache,
|
||||
ip_tracker,
|
||||
beobachten,
|
||||
shared,
|
||||
proxy_protocol_enabled,
|
||||
}
|
||||
}
|
||||
@@ -874,11 +987,12 @@ impl RunningClientHandler {
|
||||
}
|
||||
}
|
||||
|
||||
let first_byte = if self.config.timeouts.client_first_byte_idle_secs == 0 {
|
||||
let first_byte_idle_secs =
|
||||
effective_client_first_byte_idle_secs(&self.config, self.shared.as_ref());
|
||||
let first_byte = if first_byte_idle_secs == 0 {
|
||||
None
|
||||
} else {
|
||||
let idle_timeout =
|
||||
Duration::from_secs(self.config.timeouts.client_first_byte_idle_secs);
|
||||
let idle_timeout = Duration::from_secs(first_byte_idle_secs);
|
||||
let mut first_byte = [0u8; 1];
|
||||
match timeout(idle_timeout, self.stream.read(&mut first_byte)).await {
|
||||
Ok(Ok(0)) => {
|
||||
@@ -914,7 +1028,7 @@ impl RunningClientHandler {
|
||||
Err(_) => {
|
||||
debug!(
|
||||
peer = %self.peer,
|
||||
idle_secs = self.config.timeouts.client_first_byte_idle_secs,
|
||||
idle_secs = first_byte_idle_secs,
|
||||
"Closing idle pooled connection before first client byte"
|
||||
);
|
||||
return Ok(None);
|
||||
@@ -1058,7 +1172,7 @@ impl RunningClientHandler {
|
||||
|
||||
let (read_half, write_half) = self.stream.into_split();
|
||||
|
||||
let (mut tls_reader, tls_writer, tls_user) = match handle_tls_handshake(
|
||||
let (mut tls_reader, tls_writer, tls_user) = match handle_tls_handshake_with_shared(
|
||||
&handshake,
|
||||
read_half,
|
||||
write_half,
|
||||
@@ -1067,6 +1181,7 @@ impl RunningClientHandler {
|
||||
&replay_checker,
|
||||
&self.rng,
|
||||
self.tls_cache.clone(),
|
||||
self.shared.as_ref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -1095,7 +1210,7 @@ impl RunningClientHandler {
|
||||
.try_into()
|
||||
.map_err(|_| ProxyError::InvalidHandshake("Short MTProto handshake".into()))?;
|
||||
|
||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake_with_shared(
|
||||
&mtproto_handshake,
|
||||
tls_reader,
|
||||
tls_writer,
|
||||
@@ -1104,6 +1219,7 @@ impl RunningClientHandler {
|
||||
&replay_checker,
|
||||
true,
|
||||
Some(tls_user.as_str()),
|
||||
self.shared.as_ref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -1140,7 +1256,7 @@ impl RunningClientHandler {
|
||||
};
|
||||
|
||||
Ok(HandshakeOutcome::NeedsRelay(Box::pin(
|
||||
Self::handle_authenticated_static(
|
||||
Self::handle_authenticated_static_with_shared(
|
||||
crypto_reader,
|
||||
crypto_writer,
|
||||
success,
|
||||
@@ -1154,6 +1270,7 @@ impl RunningClientHandler {
|
||||
local_addr,
|
||||
peer,
|
||||
self.ip_tracker,
|
||||
self.shared,
|
||||
),
|
||||
)))
|
||||
}
|
||||
@@ -1192,7 +1309,7 @@ impl RunningClientHandler {
|
||||
|
||||
let (read_half, write_half) = self.stream.into_split();
|
||||
|
||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake(
|
||||
let (crypto_reader, crypto_writer, success) = match handle_mtproto_handshake_with_shared(
|
||||
&handshake,
|
||||
read_half,
|
||||
write_half,
|
||||
@@ -1201,6 +1318,7 @@ impl RunningClientHandler {
|
||||
&replay_checker,
|
||||
false,
|
||||
None,
|
||||
self.shared.as_ref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -1221,7 +1339,7 @@ impl RunningClientHandler {
|
||||
};
|
||||
|
||||
Ok(HandshakeOutcome::NeedsRelay(Box::pin(
|
||||
Self::handle_authenticated_static(
|
||||
Self::handle_authenticated_static_with_shared(
|
||||
crypto_reader,
|
||||
crypto_writer,
|
||||
success,
|
||||
@@ -1235,6 +1353,7 @@ impl RunningClientHandler {
|
||||
local_addr,
|
||||
peer,
|
||||
self.ip_tracker,
|
||||
self.shared,
|
||||
),
|
||||
)))
|
||||
}
|
||||
@@ -1243,6 +1362,7 @@ impl RunningClientHandler {
|
||||
/// Two modes:
|
||||
/// - Direct: TCP relay to TG DC (existing behavior)
|
||||
/// - Middle Proxy: RPC multiplex through ME pool (new — supports CDN DCs)
|
||||
#[cfg(test)]
|
||||
async fn handle_authenticated_static<R, W>(
|
||||
client_reader: CryptoReader<R>,
|
||||
client_writer: CryptoWriter<W>,
|
||||
@@ -1258,6 +1378,45 @@ impl RunningClientHandler {
|
||||
peer_addr: SocketAddr,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
) -> Result<()>
|
||||
where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
Self::handle_authenticated_static_with_shared(
|
||||
client_reader,
|
||||
client_writer,
|
||||
success,
|
||||
upstream_manager,
|
||||
stats,
|
||||
config,
|
||||
buffer_pool,
|
||||
rng,
|
||||
me_pool,
|
||||
route_runtime,
|
||||
local_addr,
|
||||
peer_addr,
|
||||
ip_tracker,
|
||||
ProxySharedState::new(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn handle_authenticated_static_with_shared<R, W>(
|
||||
client_reader: CryptoReader<R>,
|
||||
client_writer: CryptoWriter<W>,
|
||||
success: HandshakeSuccess,
|
||||
upstream_manager: Arc<UpstreamManager>,
|
||||
stats: Arc<Stats>,
|
||||
config: Arc<ProxyConfig>,
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
me_pool: Option<Arc<MePool>>,
|
||||
route_runtime: Arc<RouteRuntimeController>,
|
||||
local_addr: SocketAddr,
|
||||
peer_addr: SocketAddr,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
shared: Arc<ProxySharedState>,
|
||||
) -> Result<()>
|
||||
where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
@@ -1299,11 +1458,12 @@ impl RunningClientHandler {
|
||||
route_runtime.subscribe(),
|
||||
route_snapshot,
|
||||
session_id,
|
||||
shared.clone(),
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
warn!("use_middle_proxy=true but MePool not initialized, falling back to direct");
|
||||
handle_via_direct(
|
||||
handle_via_direct_with_shared(
|
||||
client_reader,
|
||||
client_writer,
|
||||
success,
|
||||
@@ -1315,12 +1475,14 @@ impl RunningClientHandler {
|
||||
route_runtime.subscribe(),
|
||||
route_snapshot,
|
||||
session_id,
|
||||
local_addr,
|
||||
shared.clone(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
} else {
|
||||
// Direct mode (original behavior)
|
||||
handle_via_direct(
|
||||
handle_via_direct_with_shared(
|
||||
client_reader,
|
||||
client_writer,
|
||||
success,
|
||||
@@ -1332,6 +1494,8 @@ impl RunningClientHandler {
|
||||
route_runtime.subscribe(),
|
||||
route_snapshot,
|
||||
session_id,
|
||||
local_addr,
|
||||
shared.clone(),
|
||||
)
|
||||
.await
|
||||
};
|
||||
|
||||
+106
-2
@@ -6,6 +6,7 @@ use std::net::SocketAddr;
|
||||
use std::path::{Component, Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf, split};
|
||||
use tokio::sync::watch;
|
||||
@@ -16,11 +17,13 @@ use crate::crypto::SecureRandom;
|
||||
use crate::error::{ProxyError, Result};
|
||||
use crate::protocol::constants::*;
|
||||
use crate::proxy::handshake::{HandshakeSuccess, encrypt_tg_nonce_with_ciphers, generate_tg_nonce};
|
||||
use crate::proxy::relay::relay_bidirectional;
|
||||
use crate::proxy::route_mode::{
|
||||
ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state,
|
||||
cutover_stagger_delay,
|
||||
};
|
||||
use crate::proxy::shared_state::{
|
||||
ConntrackCloseEvent, ConntrackClosePublishResult, ConntrackCloseReason, ProxySharedState,
|
||||
};
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
||||
use crate::transport::UpstreamManager;
|
||||
@@ -225,7 +228,43 @@ fn unknown_dc_test_lock() -> &'static Mutex<()> {
|
||||
TEST_LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) async fn handle_via_direct<R, W>(
|
||||
client_reader: CryptoReader<R>,
|
||||
client_writer: CryptoWriter<W>,
|
||||
success: HandshakeSuccess,
|
||||
upstream_manager: Arc<UpstreamManager>,
|
||||
stats: Arc<Stats>,
|
||||
config: Arc<ProxyConfig>,
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
route_rx: watch::Receiver<RouteCutoverState>,
|
||||
route_snapshot: RouteCutoverState,
|
||||
session_id: u64,
|
||||
) -> Result<()>
|
||||
where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
handle_via_direct_with_shared(
|
||||
client_reader,
|
||||
client_writer,
|
||||
success,
|
||||
upstream_manager,
|
||||
stats,
|
||||
config.clone(),
|
||||
buffer_pool,
|
||||
rng,
|
||||
route_rx,
|
||||
route_snapshot,
|
||||
session_id,
|
||||
SocketAddr::from(([0, 0, 0, 0], config.server.port)),
|
||||
ProxySharedState::new(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_via_direct_with_shared<R, W>(
|
||||
client_reader: CryptoReader<R>,
|
||||
client_writer: CryptoWriter<W>,
|
||||
success: HandshakeSuccess,
|
||||
@@ -237,6 +276,8 @@ pub(crate) async fn handle_via_direct<R, W>(
|
||||
mut route_rx: watch::Receiver<RouteCutoverState>,
|
||||
route_snapshot: RouteCutoverState,
|
||||
session_id: u64,
|
||||
local_addr: SocketAddr,
|
||||
shared: Arc<ProxySharedState>,
|
||||
) -> Result<()>
|
||||
where
|
||||
R: AsyncRead + Unpin + Send + 'static,
|
||||
@@ -276,7 +317,19 @@ where
|
||||
stats.increment_user_connects(user);
|
||||
let _direct_connection_lease = stats.acquire_direct_connection_lease();
|
||||
|
||||
let relay_result = relay_bidirectional(
|
||||
let buffer_pool_trim = Arc::clone(&buffer_pool);
|
||||
let relay_activity_timeout = if shared.conntrack_pressure_active() {
|
||||
Duration::from_secs(
|
||||
config
|
||||
.server
|
||||
.conntrack_control
|
||||
.profile
|
||||
.direct_activity_timeout_secs(),
|
||||
)
|
||||
} else {
|
||||
Duration::from_secs(1800)
|
||||
};
|
||||
let relay_result = crate::proxy::relay::relay_bidirectional_with_activity_timeout(
|
||||
client_reader,
|
||||
client_writer,
|
||||
tg_reader,
|
||||
@@ -287,6 +340,7 @@ where
|
||||
Arc::clone(&stats),
|
||||
config.access.user_data_quota.get(user).copied(),
|
||||
buffer_pool,
|
||||
relay_activity_timeout,
|
||||
);
|
||||
tokio::pin!(relay_result);
|
||||
let relay_result = loop {
|
||||
@@ -321,9 +375,59 @@ where
|
||||
Err(e) => debug!(user = %user, error = %e, "Direct relay ended with error"),
|
||||
}
|
||||
|
||||
buffer_pool_trim.trim_to(buffer_pool_trim.max_buffers().min(64));
|
||||
let pool_snapshot = buffer_pool_trim.stats();
|
||||
stats.set_buffer_pool_gauges(
|
||||
pool_snapshot.pooled,
|
||||
pool_snapshot.allocated,
|
||||
pool_snapshot.allocated.saturating_sub(pool_snapshot.pooled),
|
||||
);
|
||||
|
||||
let close_reason = classify_conntrack_close_reason(&relay_result);
|
||||
let publish_result = shared.publish_conntrack_close_event(ConntrackCloseEvent {
|
||||
src: success.peer,
|
||||
dst: local_addr,
|
||||
reason: close_reason,
|
||||
});
|
||||
if !matches!(
|
||||
publish_result,
|
||||
ConntrackClosePublishResult::Sent | ConntrackClosePublishResult::Disabled
|
||||
) {
|
||||
stats.increment_conntrack_close_event_drop_total();
|
||||
}
|
||||
|
||||
relay_result
|
||||
}
|
||||
|
||||
fn classify_conntrack_close_reason(result: &Result<()>) -> ConntrackCloseReason {
|
||||
match result {
|
||||
Ok(()) => ConntrackCloseReason::NormalEof,
|
||||
Err(crate::error::ProxyError::Io(error))
|
||||
if matches!(error.kind(), std::io::ErrorKind::TimedOut) =>
|
||||
{
|
||||
ConntrackCloseReason::Timeout
|
||||
}
|
||||
Err(crate::error::ProxyError::Io(error))
|
||||
if matches!(
|
||||
error.kind(),
|
||||
std::io::ErrorKind::ConnectionReset
|
||||
| std::io::ErrorKind::ConnectionAborted
|
||||
| std::io::ErrorKind::BrokenPipe
|
||||
| std::io::ErrorKind::NotConnected
|
||||
| std::io::ErrorKind::UnexpectedEof
|
||||
) =>
|
||||
{
|
||||
ConntrackCloseReason::Reset
|
||||
}
|
||||
Err(crate::error::ProxyError::Proxy(message))
|
||||
if message.contains("pressure") || message.contains("evicted") =>
|
||||
{
|
||||
ConntrackCloseReason::Pressure
|
||||
}
|
||||
Err(_) => ConntrackCloseReason::Other,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
|
||||
let prefer_v6 = config.network.prefer == 6 && config.network.ipv6.unwrap_or(true);
|
||||
let datacenters = if prefer_v6 {
|
||||
|
||||
+1128
-244
File diff suppressed because it is too large
Load Diff
+51
-2
@@ -249,6 +249,43 @@ async fn wait_mask_connect_budget(started: Instant) {
|
||||
}
|
||||
}
|
||||
|
||||
// Log-normal sample bounded to [floor, ceiling]. Median = sqrt(floor * ceiling).
|
||||
// Implements Box-Muller transform for standard normal sampling — no external
|
||||
// dependency on rand_distr (which is incompatible with rand 0.10).
|
||||
// sigma is chosen so ~99% of raw samples land inside [floor, ceiling] before clamp.
|
||||
// When floor > ceiling (misconfiguration), returns ceiling (the smaller value).
|
||||
// When floor == ceiling, returns that value. When both are 0, returns 0.
|
||||
pub(crate) fn sample_lognormal_percentile_bounded(
|
||||
floor: u64,
|
||||
ceiling: u64,
|
||||
rng: &mut impl Rng,
|
||||
) -> u64 {
|
||||
if ceiling == 0 && floor == 0 {
|
||||
return 0;
|
||||
}
|
||||
if floor > ceiling {
|
||||
return ceiling;
|
||||
}
|
||||
if floor == ceiling {
|
||||
return floor;
|
||||
}
|
||||
let floor_f = floor.max(1) as f64;
|
||||
let ceiling_f = ceiling.max(1) as f64;
|
||||
let mu = (floor_f.ln() + ceiling_f.ln()) / 2.0;
|
||||
// 4.65 ≈ 2 * 2.326 (double-sided z-score for 99th percentile)
|
||||
let sigma = ((ceiling_f / floor_f).ln() / 4.65).max(0.01);
|
||||
// Box-Muller transform: two uniform samples → one standard normal sample
|
||||
let u1: f64 = rng.random_range(f64::MIN_POSITIVE..1.0);
|
||||
let u2: f64 = rng.random_range(0.0_f64..std::f64::consts::TAU);
|
||||
let normal_sample = (-2.0_f64 * u1.ln()).sqrt() * u2.cos();
|
||||
let raw = (mu + sigma * normal_sample).exp();
|
||||
if raw.is_finite() {
|
||||
(raw as u64).clamp(floor, ceiling)
|
||||
} else {
|
||||
((floor_f * ceiling_f).sqrt()) as u64
|
||||
}
|
||||
}
|
||||
|
||||
fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration {
|
||||
if config.censorship.mask_timing_normalization_enabled {
|
||||
let floor = config.censorship.mask_timing_normalization_floor_ms;
|
||||
@@ -257,14 +294,18 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration {
|
||||
if ceiling == 0 {
|
||||
return Duration::from_millis(0);
|
||||
}
|
||||
// floor=0 stays uniform: log-normal cannot model distribution anchored at zero
|
||||
let mut rng = rand::rng();
|
||||
return Duration::from_millis(rng.random_range(0..=ceiling));
|
||||
}
|
||||
if ceiling > floor {
|
||||
let mut rng = rand::rng();
|
||||
return Duration::from_millis(rng.random_range(floor..=ceiling));
|
||||
return Duration::from_millis(sample_lognormal_percentile_bounded(
|
||||
floor, ceiling, &mut rng,
|
||||
));
|
||||
}
|
||||
return Duration::from_millis(floor);
|
||||
// ceiling <= floor: use the larger value (fail-closed: preserve longer delay)
|
||||
return Duration::from_millis(floor.max(ceiling));
|
||||
}
|
||||
|
||||
MASK_TIMEOUT
|
||||
@@ -1003,3 +1044,11 @@ mod masking_padding_timeout_adversarial_tests;
|
||||
#[cfg(all(test, feature = "redteam_offline_expected_fail"))]
|
||||
#[path = "tests/masking_offline_target_redteam_expected_fail_tests.rs"]
|
||||
mod masking_offline_target_redteam_expected_fail_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_baseline_invariant_tests.rs"]
|
||||
mod masking_baseline_invariant_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/masking_lognormal_timing_security_tests.rs"]
|
||||
mod masking_lognormal_timing_security_tests;
|
||||
|
||||
+577
-157
File diff suppressed because it is too large
Load Diff
@@ -67,6 +67,7 @@ pub mod middle_relay;
|
||||
pub mod relay;
|
||||
pub mod route_mode;
|
||||
pub mod session_eviction;
|
||||
pub mod shared_state;
|
||||
|
||||
pub use client::ClientHandler;
|
||||
#[allow(unused_imports)]
|
||||
@@ -75,3 +76,15 @@ pub use handshake::*;
|
||||
pub use masking::*;
|
||||
#[allow(unused_imports)]
|
||||
pub use relay::*;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/test_harness_common.rs"]
|
||||
mod test_harness_common;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/proxy_shared_state_isolation_tests.rs"]
|
||||
mod proxy_shared_state_isolation_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/proxy_shared_state_parallel_execution_tests.rs"]
|
||||
mod proxy_shared_state_parallel_execution_tests;
|
||||
|
||||
+186
-32
@@ -70,6 +70,7 @@ use tracing::{debug, trace, warn};
|
||||
///
|
||||
/// iOS keeps Telegram connections alive in background for up to 30 minutes.
|
||||
/// Closing earlier causes unnecessary reconnects and handshake overhead.
|
||||
#[allow(dead_code)]
|
||||
const ACTIVITY_TIMEOUT: Duration = Duration::from_secs(1800);
|
||||
|
||||
/// Watchdog check interval — also used for periodic rate logging.
|
||||
@@ -269,6 +270,7 @@ const QUOTA_NEAR_LIMIT_BYTES: u64 = 64 * 1024;
|
||||
const QUOTA_LARGE_CHARGE_BYTES: u64 = 16 * 1024;
|
||||
const QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES: u64 = 4 * 1024;
|
||||
const QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES: u64 = 64 * 1024;
|
||||
const QUOTA_RESERVE_SPIN_RETRIES: usize = 64;
|
||||
|
||||
#[inline]
|
||||
fn quota_adaptive_interval_bytes(remaining_before: u64) -> u64 {
|
||||
@@ -313,6 +315,50 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
|
||||
if n > 0 {
|
||||
let n_to_charge = n as u64;
|
||||
|
||||
if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) {
|
||||
let mut reserved_total = None;
|
||||
let mut reserve_rounds = 0usize;
|
||||
while reserved_total.is_none() {
|
||||
for _ in 0..QUOTA_RESERVE_SPIN_RETRIES {
|
||||
match this.user_stats.quota_try_reserve(n_to_charge, limit) {
|
||||
Ok(total) => {
|
||||
reserved_total = Some(total);
|
||||
break;
|
||||
}
|
||||
Err(crate::stats::QuotaReserveError::LimitExceeded) => {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
buf.set_filled(before);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
Err(crate::stats::QuotaReserveError::Contended) => {
|
||||
std::hint::spin_loop();
|
||||
}
|
||||
}
|
||||
}
|
||||
reserve_rounds = reserve_rounds.saturating_add(1);
|
||||
if reserved_total.is_none() && reserve_rounds >= 8 {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
buf.set_filled(before);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
}
|
||||
|
||||
if should_immediate_quota_check(remaining, n_to_charge) {
|
||||
this.quota_bytes_since_check = 0;
|
||||
} else {
|
||||
this.quota_bytes_since_check =
|
||||
this.quota_bytes_since_check.saturating_add(n_to_charge);
|
||||
let interval = quota_adaptive_interval_bytes(remaining);
|
||||
if this.quota_bytes_since_check >= interval {
|
||||
this.quota_bytes_since_check = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if reserved_total.unwrap_or(0) >= limit {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
}
|
||||
}
|
||||
|
||||
// C→S: client sent data
|
||||
this.counters
|
||||
.c2s_bytes
|
||||
@@ -325,27 +371,6 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
|
||||
this.stats
|
||||
.increment_user_msgs_from_handle(this.user_stats.as_ref());
|
||||
|
||||
if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) {
|
||||
this.stats
|
||||
.quota_charge_post_write(this.user_stats.as_ref(), n_to_charge);
|
||||
if should_immediate_quota_check(remaining, n_to_charge) {
|
||||
this.quota_bytes_since_check = 0;
|
||||
if this.user_stats.quota_used() >= limit {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
}
|
||||
} else {
|
||||
this.quota_bytes_since_check =
|
||||
this.quota_bytes_since_check.saturating_add(n_to_charge);
|
||||
let interval = quota_adaptive_interval_bytes(remaining);
|
||||
if this.quota_bytes_since_check >= interval {
|
||||
this.quota_bytes_since_check = 0;
|
||||
if this.user_stats.quota_used() >= limit {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trace!(user = %this.user, bytes = n, "C->S");
|
||||
}
|
||||
Poll::Ready(Ok(()))
|
||||
@@ -367,18 +392,73 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
|
||||
}
|
||||
|
||||
let mut remaining_before = None;
|
||||
let mut reserved_bytes = 0u64;
|
||||
let mut write_buf = buf;
|
||||
if let Some(limit) = this.quota_limit {
|
||||
let used_before = this.user_stats.quota_used();
|
||||
let remaining = limit.saturating_sub(used_before);
|
||||
if remaining == 0 {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
if !buf.is_empty() {
|
||||
let mut reserve_rounds = 0usize;
|
||||
while reserved_bytes == 0 {
|
||||
let used_before = this.user_stats.quota_used();
|
||||
let remaining = limit.saturating_sub(used_before);
|
||||
if remaining == 0 {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
remaining_before = Some(remaining);
|
||||
|
||||
let desired = remaining.min(buf.len() as u64);
|
||||
for _ in 0..QUOTA_RESERVE_SPIN_RETRIES {
|
||||
match this.user_stats.quota_try_reserve(desired, limit) {
|
||||
Ok(_) => {
|
||||
reserved_bytes = desired;
|
||||
write_buf = &buf[..desired as usize];
|
||||
break;
|
||||
}
|
||||
Err(crate::stats::QuotaReserveError::LimitExceeded) => {
|
||||
break;
|
||||
}
|
||||
Err(crate::stats::QuotaReserveError::Contended) => {
|
||||
std::hint::spin_loop();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
reserve_rounds = reserve_rounds.saturating_add(1);
|
||||
if reserved_bytes == 0 && reserve_rounds >= 8 {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let used_before = this.user_stats.quota_used();
|
||||
let remaining = limit.saturating_sub(used_before);
|
||||
if remaining == 0 {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
remaining_before = Some(remaining);
|
||||
}
|
||||
remaining_before = Some(remaining);
|
||||
}
|
||||
|
||||
match Pin::new(&mut this.inner).poll_write(cx, buf) {
|
||||
match Pin::new(&mut this.inner).poll_write(cx, write_buf) {
|
||||
Poll::Ready(Ok(n)) => {
|
||||
if reserved_bytes > n as u64 {
|
||||
let refund = reserved_bytes - n as u64;
|
||||
let mut current = this.user_stats.quota_used.load(Ordering::Relaxed);
|
||||
loop {
|
||||
let next = current.saturating_sub(refund);
|
||||
match this.user_stats.quota_used.compare_exchange_weak(
|
||||
current,
|
||||
next,
|
||||
Ordering::Relaxed,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
Ok(_) => break,
|
||||
Err(observed) => current = observed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
let n_to_charge = n as u64;
|
||||
|
||||
@@ -395,8 +475,6 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
|
||||
.increment_user_msgs_to_handle(this.user_stats.as_ref());
|
||||
|
||||
if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) {
|
||||
this.stats
|
||||
.quota_charge_post_write(this.user_stats.as_ref(), n_to_charge);
|
||||
if should_immediate_quota_check(remaining, n_to_charge) {
|
||||
this.quota_bytes_since_check = 0;
|
||||
if this.user_stats.quota_used() >= limit {
|
||||
@@ -419,7 +497,42 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
|
||||
}
|
||||
Poll::Ready(Ok(n))
|
||||
}
|
||||
other => other,
|
||||
Poll::Ready(Err(err)) => {
|
||||
if reserved_bytes > 0 {
|
||||
let mut current = this.user_stats.quota_used.load(Ordering::Relaxed);
|
||||
loop {
|
||||
let next = current.saturating_sub(reserved_bytes);
|
||||
match this.user_stats.quota_used.compare_exchange_weak(
|
||||
current,
|
||||
next,
|
||||
Ordering::Relaxed,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
Ok(_) => break,
|
||||
Err(observed) => current = observed,
|
||||
}
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(err))
|
||||
}
|
||||
Poll::Pending => {
|
||||
if reserved_bytes > 0 {
|
||||
let mut current = this.user_stats.quota_used.load(Ordering::Relaxed);
|
||||
loop {
|
||||
let next = current.saturating_sub(reserved_bytes);
|
||||
match this.user_stats.quota_used.compare_exchange_weak(
|
||||
current,
|
||||
next,
|
||||
Ordering::Relaxed,
|
||||
Ordering::Relaxed,
|
||||
) {
|
||||
Ok(_) => break,
|
||||
Err(observed) => current = observed,
|
||||
}
|
||||
}
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -453,6 +566,7 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
|
||||
/// - Clean shutdown: both write sides are shut down on exit
|
||||
/// - Error propagation: quota exits return `ProxyError::DataQuotaExceeded`,
|
||||
/// other I/O failures are returned as `ProxyError::Io`
|
||||
#[allow(dead_code)]
|
||||
pub async fn relay_bidirectional<CR, CW, SR, SW>(
|
||||
client_reader: CR,
|
||||
client_writer: CW,
|
||||
@@ -471,6 +585,42 @@ where
|
||||
SR: AsyncRead + Unpin + Send + 'static,
|
||||
SW: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
relay_bidirectional_with_activity_timeout(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
c2s_buf_size,
|
||||
s2c_buf_size,
|
||||
user,
|
||||
stats,
|
||||
quota_limit,
|
||||
_buffer_pool,
|
||||
ACTIVITY_TIMEOUT,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn relay_bidirectional_with_activity_timeout<CR, CW, SR, SW>(
|
||||
client_reader: CR,
|
||||
client_writer: CW,
|
||||
server_reader: SR,
|
||||
server_writer: SW,
|
||||
c2s_buf_size: usize,
|
||||
s2c_buf_size: usize,
|
||||
user: &str,
|
||||
stats: Arc<Stats>,
|
||||
quota_limit: Option<u64>,
|
||||
_buffer_pool: Arc<BufferPool>,
|
||||
activity_timeout: Duration,
|
||||
) -> Result<()>
|
||||
where
|
||||
CR: AsyncRead + Unpin + Send + 'static,
|
||||
CW: AsyncWrite + Unpin + Send + 'static,
|
||||
SR: AsyncRead + Unpin + Send + 'static,
|
||||
SW: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
let activity_timeout = activity_timeout.max(Duration::from_secs(1));
|
||||
let epoch = Instant::now();
|
||||
let counters = Arc::new(SharedCounters::new());
|
||||
let quota_exceeded = Arc::new(AtomicBool::new(false));
|
||||
@@ -512,7 +662,7 @@ where
|
||||
}
|
||||
|
||||
// ── Activity timeout ────────────────────────────────────
|
||||
if idle >= ACTIVITY_TIMEOUT {
|
||||
if idle >= activity_timeout {
|
||||
let c2s = wd_counters.c2s_bytes.load(Ordering::Relaxed);
|
||||
let s2c = wd_counters.s2c_bytes.load(Ordering::Relaxed);
|
||||
warn!(
|
||||
@@ -671,3 +821,7 @@ mod relay_watchdog_delta_security_tests;
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_atomic_quota_invariant_tests.rs"]
|
||||
mod relay_atomic_quota_invariant_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_baseline_invariant_tests.rs"]
|
||||
mod relay_baseline_invariant_tests;
|
||||
|
||||
@@ -0,0 +1,165 @@
|
||||
use std::collections::HashSet;
|
||||
use std::collections::hash_map::RandomState;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Instant;
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::proxy::handshake::{AuthProbeSaturationState, AuthProbeState};
|
||||
use crate::proxy::middle_relay::{DesyncDedupRotationState, RelayIdleCandidateRegistry};
|
||||
|
||||
const HANDSHAKE_RECENT_USER_RING_LEN: usize = 64;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum ConntrackCloseReason {
|
||||
NormalEof,
|
||||
Timeout,
|
||||
Pressure,
|
||||
Reset,
|
||||
Other,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub(crate) struct ConntrackCloseEvent {
|
||||
pub(crate) src: SocketAddr,
|
||||
pub(crate) dst: SocketAddr,
|
||||
pub(crate) reason: ConntrackCloseReason,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum ConntrackClosePublishResult {
|
||||
Sent,
|
||||
Disabled,
|
||||
QueueFull,
|
||||
QueueClosed,
|
||||
}
|
||||
|
||||
pub(crate) struct HandshakeSharedState {
|
||||
pub(crate) auth_probe: DashMap<IpAddr, AuthProbeState>,
|
||||
pub(crate) auth_probe_saturation: Mutex<Option<AuthProbeSaturationState>>,
|
||||
pub(crate) auth_probe_eviction_hasher: RandomState,
|
||||
pub(crate) invalid_secret_warned: Mutex<HashSet<(String, String)>>,
|
||||
pub(crate) unknown_sni_warn_next_allowed: Mutex<Option<Instant>>,
|
||||
pub(crate) sticky_user_by_ip: DashMap<IpAddr, u32>,
|
||||
pub(crate) sticky_user_by_ip_prefix: DashMap<u64, u32>,
|
||||
pub(crate) sticky_user_by_sni_hash: DashMap<u64, u32>,
|
||||
pub(crate) recent_user_ring: Box<[AtomicU32]>,
|
||||
pub(crate) recent_user_ring_seq: AtomicU64,
|
||||
pub(crate) auth_expensive_checks_total: AtomicU64,
|
||||
pub(crate) auth_budget_exhausted_total: AtomicU64,
|
||||
}
|
||||
|
||||
pub(crate) struct MiddleRelaySharedState {
|
||||
pub(crate) desync_dedup: DashMap<u64, Instant>,
|
||||
pub(crate) desync_dedup_previous: DashMap<u64, Instant>,
|
||||
pub(crate) desync_hasher: RandomState,
|
||||
pub(crate) desync_full_cache_last_emit_at: Mutex<Option<Instant>>,
|
||||
pub(crate) desync_dedup_rotation_state: Mutex<DesyncDedupRotationState>,
|
||||
pub(crate) relay_idle_registry: Mutex<RelayIdleCandidateRegistry>,
|
||||
pub(crate) relay_idle_mark_seq: AtomicU64,
|
||||
}
|
||||
|
||||
pub(crate) struct ProxySharedState {
|
||||
pub(crate) handshake: HandshakeSharedState,
|
||||
pub(crate) middle_relay: MiddleRelaySharedState,
|
||||
pub(crate) conntrack_pressure_active: AtomicBool,
|
||||
pub(crate) conntrack_close_tx: Mutex<Option<mpsc::Sender<ConntrackCloseEvent>>>,
|
||||
}
|
||||
|
||||
impl ProxySharedState {
|
||||
pub(crate) fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
handshake: HandshakeSharedState {
|
||||
auth_probe: DashMap::new(),
|
||||
auth_probe_saturation: Mutex::new(None),
|
||||
auth_probe_eviction_hasher: RandomState::new(),
|
||||
invalid_secret_warned: Mutex::new(HashSet::new()),
|
||||
unknown_sni_warn_next_allowed: Mutex::new(None),
|
||||
sticky_user_by_ip: DashMap::new(),
|
||||
sticky_user_by_ip_prefix: DashMap::new(),
|
||||
sticky_user_by_sni_hash: DashMap::new(),
|
||||
recent_user_ring: std::iter::repeat_with(|| AtomicU32::new(0))
|
||||
.take(HANDSHAKE_RECENT_USER_RING_LEN)
|
||||
.collect::<Vec<_>>()
|
||||
.into_boxed_slice(),
|
||||
recent_user_ring_seq: AtomicU64::new(0),
|
||||
auth_expensive_checks_total: AtomicU64::new(0),
|
||||
auth_budget_exhausted_total: AtomicU64::new(0),
|
||||
},
|
||||
middle_relay: MiddleRelaySharedState {
|
||||
desync_dedup: DashMap::new(),
|
||||
desync_dedup_previous: DashMap::new(),
|
||||
desync_hasher: RandomState::new(),
|
||||
desync_full_cache_last_emit_at: Mutex::new(None),
|
||||
desync_dedup_rotation_state: Mutex::new(DesyncDedupRotationState::default()),
|
||||
relay_idle_registry: Mutex::new(RelayIdleCandidateRegistry::default()),
|
||||
relay_idle_mark_seq: AtomicU64::new(0),
|
||||
},
|
||||
conntrack_pressure_active: AtomicBool::new(false),
|
||||
conntrack_close_tx: Mutex::new(None),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn set_conntrack_close_sender(&self, tx: mpsc::Sender<ConntrackCloseEvent>) {
|
||||
match self.conntrack_close_tx.lock() {
|
||||
Ok(mut guard) => {
|
||||
*guard = Some(tx);
|
||||
}
|
||||
Err(poisoned) => {
|
||||
let mut guard = poisoned.into_inner();
|
||||
*guard = Some(tx);
|
||||
self.conntrack_close_tx.clear_poison();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn disable_conntrack_close_sender(&self) {
|
||||
match self.conntrack_close_tx.lock() {
|
||||
Ok(mut guard) => {
|
||||
*guard = None;
|
||||
}
|
||||
Err(poisoned) => {
|
||||
let mut guard = poisoned.into_inner();
|
||||
*guard = None;
|
||||
self.conntrack_close_tx.clear_poison();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn publish_conntrack_close_event(
|
||||
&self,
|
||||
event: ConntrackCloseEvent,
|
||||
) -> ConntrackClosePublishResult {
|
||||
let tx = match self.conntrack_close_tx.lock() {
|
||||
Ok(guard) => guard.clone(),
|
||||
Err(poisoned) => {
|
||||
let guard = poisoned.into_inner();
|
||||
let cloned = guard.clone();
|
||||
self.conntrack_close_tx.clear_poison();
|
||||
cloned
|
||||
}
|
||||
};
|
||||
|
||||
let Some(tx) = tx else {
|
||||
return ConntrackClosePublishResult::Disabled;
|
||||
};
|
||||
|
||||
match tx.try_send(event) {
|
||||
Ok(()) => ConntrackClosePublishResult::Sent,
|
||||
Err(mpsc::error::TrySendError::Full(_)) => ConntrackClosePublishResult::QueueFull,
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => ConntrackClosePublishResult::QueueClosed,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn set_conntrack_pressure_active(&self, active: bool) {
|
||||
self.conntrack_pressure_active
|
||||
.store(active, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub(crate) fn conntrack_pressure_active(&self) -> bool {
|
||||
self.conntrack_pressure_active.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,260 @@
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
static RACE_TEST_KEY_COUNTER: AtomicUsize = AtomicUsize::new(1_000_000);
|
||||
|
||||
fn race_unique_key(prefix: &str) -> String {
|
||||
let id = RACE_TEST_KEY_COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||
format!("{}_{}", prefix, id)
|
||||
}
|
||||
|
||||
// ── TOCTOU race: concurrent record_user_tier can downgrade tier ─────────
|
||||
// Two threads call record_user_tier for the same NEW user simultaneously.
|
||||
// Thread A records Tier1, Thread B records Base. Without atomic entry API,
|
||||
// the insert() call overwrites without max(), causing Tier1 → Base downgrade.
|
||||
|
||||
#[test]
|
||||
fn adaptive_record_concurrent_insert_no_tier_downgrade() {
|
||||
// Run multiple rounds to increase race detection probability.
|
||||
for round in 0..50 {
|
||||
let key = race_unique_key(&format!("race_downgrade_{}", round));
|
||||
let key_a = key.clone();
|
||||
let key_b = key.clone();
|
||||
|
||||
let barrier = Arc::new(std::sync::Barrier::new(2));
|
||||
let barrier_a = Arc::clone(&barrier);
|
||||
let barrier_b = Arc::clone(&barrier);
|
||||
|
||||
let ha = std::thread::spawn(move || {
|
||||
barrier_a.wait();
|
||||
record_user_tier(&key_a, AdaptiveTier::Tier2);
|
||||
});
|
||||
|
||||
let hb = std::thread::spawn(move || {
|
||||
barrier_b.wait();
|
||||
record_user_tier(&key_b, AdaptiveTier::Base);
|
||||
});
|
||||
|
||||
ha.join().expect("thread A panicked");
|
||||
hb.join().expect("thread B panicked");
|
||||
|
||||
let result = seed_tier_for_user(&key);
|
||||
profiles().remove(&key);
|
||||
|
||||
// The final tier must be at least Tier2, never downgraded to Base.
|
||||
// With correct max() semantics: max(Tier2, Base) = Tier2.
|
||||
assert!(
|
||||
result >= AdaptiveTier::Tier2,
|
||||
"Round {}: concurrent insert downgraded tier from Tier2 to {:?}",
|
||||
round,
|
||||
result,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── TOCTOU race: three threads write three tiers, highest must survive ──
|
||||
|
||||
#[test]
|
||||
fn adaptive_record_triple_concurrent_insert_highest_tier_survives() {
|
||||
for round in 0..30 {
|
||||
let key = race_unique_key(&format!("triple_race_{}", round));
|
||||
let barrier = Arc::new(std::sync::Barrier::new(3));
|
||||
|
||||
let handles: Vec<_> = [AdaptiveTier::Base, AdaptiveTier::Tier1, AdaptiveTier::Tier3]
|
||||
.into_iter()
|
||||
.map(|tier| {
|
||||
let k = key.clone();
|
||||
let b = Arc::clone(&barrier);
|
||||
std::thread::spawn(move || {
|
||||
b.wait();
|
||||
record_user_tier(&k, tier);
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
for h in handles {
|
||||
h.join().expect("thread panicked");
|
||||
}
|
||||
|
||||
let result = seed_tier_for_user(&key);
|
||||
profiles().remove(&key);
|
||||
|
||||
assert!(
|
||||
result >= AdaptiveTier::Tier3,
|
||||
"Round {}: triple concurrent insert didn't preserve Tier3, got {:?}",
|
||||
round,
|
||||
result,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Stress: 20 threads writing different tiers to same key ──────────────
|
||||
|
||||
#[test]
|
||||
fn adaptive_record_20_concurrent_writers_no_panic_no_downgrade() {
|
||||
let key = race_unique_key("stress_20");
|
||||
let barrier = Arc::new(std::sync::Barrier::new(20));
|
||||
|
||||
let handles: Vec<_> = (0..20u32)
|
||||
.map(|i| {
|
||||
let k = key.clone();
|
||||
let b = Arc::clone(&barrier);
|
||||
std::thread::spawn(move || {
|
||||
b.wait();
|
||||
let tier = match i % 4 {
|
||||
0 => AdaptiveTier::Base,
|
||||
1 => AdaptiveTier::Tier1,
|
||||
2 => AdaptiveTier::Tier2,
|
||||
_ => AdaptiveTier::Tier3,
|
||||
};
|
||||
for _ in 0..100 {
|
||||
record_user_tier(&k, tier);
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
for h in handles {
|
||||
h.join().expect("thread panicked");
|
||||
}
|
||||
|
||||
let result = seed_tier_for_user(&key);
|
||||
profiles().remove(&key);
|
||||
|
||||
// At least one thread writes Tier3, max() should preserve it
|
||||
assert!(
|
||||
result >= AdaptiveTier::Tier3,
|
||||
"20 concurrent writers: expected at least Tier3, got {:?}",
|
||||
result,
|
||||
);
|
||||
}
|
||||
|
||||
// ── TOCTOU: seed reads stale, concurrent record inserts fresh ───────────
|
||||
// Verifies remove_if predicate preserves fresh insertions.
|
||||
|
||||
#[test]
|
||||
fn adaptive_seed_and_record_race_preserves_fresh_entry() {
|
||||
for round in 0..30 {
|
||||
let key = race_unique_key(&format!("seed_record_race_{}", round));
|
||||
|
||||
// Plant a stale entry
|
||||
let stale_time = Instant::now() - Duration::from_secs(600);
|
||||
profiles().insert(
|
||||
key.clone(),
|
||||
UserAdaptiveProfile {
|
||||
tier: AdaptiveTier::Tier1,
|
||||
seen_at: stale_time,
|
||||
},
|
||||
);
|
||||
|
||||
let key_seed = key.clone();
|
||||
let key_record = key.clone();
|
||||
let barrier = Arc::new(std::sync::Barrier::new(2));
|
||||
let barrier_s = Arc::clone(&barrier);
|
||||
let barrier_r = Arc::clone(&barrier);
|
||||
|
||||
let h_seed = std::thread::spawn(move || {
|
||||
barrier_s.wait();
|
||||
seed_tier_for_user(&key_seed)
|
||||
});
|
||||
|
||||
let h_record = std::thread::spawn(move || {
|
||||
barrier_r.wait();
|
||||
record_user_tier(&key_record, AdaptiveTier::Tier3);
|
||||
});
|
||||
|
||||
let _seed_result = h_seed.join().expect("seed thread panicked");
|
||||
h_record.join().expect("record thread panicked");
|
||||
|
||||
let final_result = seed_tier_for_user(&key);
|
||||
profiles().remove(&key);
|
||||
|
||||
// Fresh Tier3 entry should survive the stale-removal race.
|
||||
// Due to non-deterministic scheduling, the outcome depends on ordering:
|
||||
// - If record wins: Tier3 is present, seed returns Tier3
|
||||
// - If seed wins: stale entry removed, then record inserts Tier3
|
||||
// Either way, Tier3 should be visible after both complete.
|
||||
assert!(
|
||||
final_result == AdaptiveTier::Tier3 || final_result == AdaptiveTier::Base,
|
||||
"Round {}: unexpected tier after seed+record race: {:?}",
|
||||
round,
|
||||
final_result,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Eviction safety: retain() during concurrent inserts ─────────────────
|
||||
|
||||
#[test]
|
||||
fn adaptive_eviction_during_concurrent_inserts_no_panic() {
|
||||
let prefix = race_unique_key("evict_conc");
|
||||
let stale_time = Instant::now() - Duration::from_secs(600);
|
||||
|
||||
// Pre-fill with stale entries to push past the eviction threshold
|
||||
for i in 0..100 {
|
||||
let k = format!("{}_{}", prefix, i);
|
||||
profiles().insert(
|
||||
k,
|
||||
UserAdaptiveProfile {
|
||||
tier: AdaptiveTier::Base,
|
||||
seen_at: stale_time,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
let barrier = Arc::new(std::sync::Barrier::new(10));
|
||||
let handles: Vec<_> = (0..10)
|
||||
.map(|t| {
|
||||
let b = Arc::clone(&barrier);
|
||||
let pfx = prefix.clone();
|
||||
std::thread::spawn(move || {
|
||||
b.wait();
|
||||
for i in 0..50 {
|
||||
let k = format!("{}_t{}_{}", pfx, t, i);
|
||||
record_user_tier(&k, AdaptiveTier::Tier1);
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
for h in handles {
|
||||
h.join().expect("eviction thread panicked");
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
profiles().retain(|k, _| !k.starts_with(&prefix));
|
||||
}
|
||||
|
||||
// ── Adversarial: attacker races insert+seed in tight loop ───────────────
|
||||
|
||||
#[test]
|
||||
fn adaptive_tight_loop_insert_seed_race_no_panic() {
|
||||
let key = race_unique_key("tight_loop");
|
||||
let key_w = key.clone();
|
||||
let key_r = key.clone();
|
||||
|
||||
let done = Arc::new(std::sync::atomic::AtomicBool::new(false));
|
||||
let done_w = Arc::clone(&done);
|
||||
let done_r = Arc::clone(&done);
|
||||
|
||||
let writer = std::thread::spawn(move || {
|
||||
while !done_w.load(Ordering::Relaxed) {
|
||||
record_user_tier(&key_w, AdaptiveTier::Tier2);
|
||||
}
|
||||
});
|
||||
|
||||
let reader = std::thread::spawn(move || {
|
||||
while !done_r.load(Ordering::Relaxed) {
|
||||
let _ = seed_tier_for_user(&key_r);
|
||||
}
|
||||
});
|
||||
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
done.store(true, Ordering::Relaxed);
|
||||
|
||||
writer.join().expect("writer panicked");
|
||||
reader.join().expect("reader panicked");
|
||||
profiles().remove(&key);
|
||||
}
|
||||
@@ -0,0 +1,453 @@
|
||||
use super::*;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
// Unique key generator to avoid test interference through the global DashMap.
|
||||
static TEST_KEY_COUNTER: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
fn unique_key(prefix: &str) -> String {
|
||||
let id = TEST_KEY_COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||
format!("{}_{}", prefix, id)
|
||||
}
|
||||
|
||||
// ── Positive / Lifecycle ────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn adaptive_seed_unknown_user_returns_base() {
|
||||
let key = unique_key("seed_unknown");
|
||||
assert_eq!(seed_tier_for_user(&key), AdaptiveTier::Base);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_record_then_seed_returns_recorded_tier() {
|
||||
let key = unique_key("record_seed");
|
||||
record_user_tier(&key, AdaptiveTier::Tier1);
|
||||
assert_eq!(seed_tier_for_user(&key), AdaptiveTier::Tier1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_separate_users_have_independent_tiers() {
|
||||
let key_a = unique_key("indep_a");
|
||||
let key_b = unique_key("indep_b");
|
||||
record_user_tier(&key_a, AdaptiveTier::Tier1);
|
||||
record_user_tier(&key_b, AdaptiveTier::Tier2);
|
||||
assert_eq!(seed_tier_for_user(&key_a), AdaptiveTier::Tier1);
|
||||
assert_eq!(seed_tier_for_user(&key_b), AdaptiveTier::Tier2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_record_upgrades_tier_within_ttl() {
|
||||
let key = unique_key("upgrade");
|
||||
record_user_tier(&key, AdaptiveTier::Base);
|
||||
record_user_tier(&key, AdaptiveTier::Tier1);
|
||||
assert_eq!(seed_tier_for_user(&key), AdaptiveTier::Tier1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_record_does_not_downgrade_within_ttl() {
|
||||
let key = unique_key("no_downgrade");
|
||||
record_user_tier(&key, AdaptiveTier::Tier2);
|
||||
record_user_tier(&key, AdaptiveTier::Base);
|
||||
// max(Tier2, Base) = Tier2 — within TTL the higher tier is retained
|
||||
assert_eq!(seed_tier_for_user(&key), AdaptiveTier::Tier2);
|
||||
}
|
||||
|
||||
// ── Edge Cases ──────────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn adaptive_base_tier_buffers_unchanged() {
|
||||
let (c2s, s2c) = direct_copy_buffers_for_tier(AdaptiveTier::Base, 65536, 262144);
|
||||
assert_eq!(c2s, 65536);
|
||||
assert_eq!(s2c, 262144);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_tier1_buffers_within_caps() {
|
||||
let (c2s, s2c) = direct_copy_buffers_for_tier(AdaptiveTier::Tier1, 65536, 262144);
|
||||
assert!(c2s > 65536, "Tier1 c2s should exceed Base");
|
||||
assert!(
|
||||
c2s <= 128 * 1024,
|
||||
"Tier1 c2s should not exceed DIRECT_C2S_CAP_BYTES"
|
||||
);
|
||||
assert!(s2c > 262144, "Tier1 s2c should exceed Base");
|
||||
assert!(
|
||||
s2c <= 512 * 1024,
|
||||
"Tier1 s2c should not exceed DIRECT_S2C_CAP_BYTES"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_tier3_buffers_capped() {
|
||||
let (c2s, s2c) = direct_copy_buffers_for_tier(AdaptiveTier::Tier3, 65536, 262144);
|
||||
assert!(c2s <= 128 * 1024, "Tier3 c2s must not exceed cap");
|
||||
assert!(s2c <= 512 * 1024, "Tier3 s2c must not exceed cap");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_scale_zero_base_returns_at_least_one() {
|
||||
// scale(0, num, den, cap) should return at least 1 (the .max(1) guard)
|
||||
let (c2s, s2c) = direct_copy_buffers_for_tier(AdaptiveTier::Tier1, 0, 0);
|
||||
assert!(c2s >= 1);
|
||||
assert!(s2c >= 1);
|
||||
}
|
||||
|
||||
// ── Stale Entry Handling ────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn adaptive_stale_profile_returns_base_tier() {
|
||||
let key = unique_key("stale_base");
|
||||
// Manually insert a stale entry with seen_at in the far past.
|
||||
// PROFILE_TTL = 300s, so 600s ago is well past expiry.
|
||||
let stale_time = Instant::now() - Duration::from_secs(600);
|
||||
profiles().insert(
|
||||
key.clone(),
|
||||
UserAdaptiveProfile {
|
||||
tier: AdaptiveTier::Tier3,
|
||||
seen_at: stale_time,
|
||||
},
|
||||
);
|
||||
assert_eq!(
|
||||
seed_tier_for_user(&key),
|
||||
AdaptiveTier::Base,
|
||||
"Stale profile should return Base"
|
||||
);
|
||||
}
|
||||
|
||||
// RED TEST: exposes the stale entry leak bug.
|
||||
// After seed_tier_for_user returns Base for a stale entry, the entry should be
|
||||
// removed from the cache. Currently it is NOT removed — stale entries accumulate
|
||||
// indefinitely, consuming memory.
|
||||
#[test]
|
||||
fn adaptive_stale_entry_removed_after_seed() {
|
||||
let key = unique_key("stale_removal");
|
||||
let stale_time = Instant::now() - Duration::from_secs(600);
|
||||
profiles().insert(
|
||||
key.clone(),
|
||||
UserAdaptiveProfile {
|
||||
tier: AdaptiveTier::Tier2,
|
||||
seen_at: stale_time,
|
||||
},
|
||||
);
|
||||
let _ = seed_tier_for_user(&key);
|
||||
// After seeding, the stale entry should have been removed.
|
||||
assert!(
|
||||
!profiles().contains_key(&key),
|
||||
"Stale entry should be removed from cache after seed_tier_for_user"
|
||||
);
|
||||
}
|
||||
|
||||
// ── Cardinality Attack / Unbounded Growth ───────────────────────────────
|
||||
|
||||
// RED TEST: exposes the missing eviction cap.
|
||||
// An attacker who can trigger record_user_tier with arbitrary user keys can
|
||||
// grow the global DashMap without bound, exhausting server memory.
|
||||
// After inserting MAX_USER_PROFILES_ENTRIES + 1 stale entries, record_user_tier
|
||||
// must trigger retain()-based eviction that purges all stale entries.
|
||||
#[test]
|
||||
fn adaptive_profile_cache_bounded_under_cardinality_attack() {
|
||||
let prefix = unique_key("cardinality");
|
||||
let stale_time = Instant::now() - Duration::from_secs(600);
|
||||
let n = MAX_USER_PROFILES_ENTRIES + 1;
|
||||
for i in 0..n {
|
||||
let key = format!("{}_{}", prefix, i);
|
||||
profiles().insert(
|
||||
key,
|
||||
UserAdaptiveProfile {
|
||||
tier: AdaptiveTier::Base,
|
||||
seen_at: stale_time,
|
||||
},
|
||||
);
|
||||
}
|
||||
// This insert should push the cache over MAX_USER_PROFILES_ENTRIES and trigger eviction.
|
||||
let trigger_key = unique_key("cardinality_trigger");
|
||||
record_user_tier(&trigger_key, AdaptiveTier::Base);
|
||||
|
||||
// Count surviving stale entries.
|
||||
let mut surviving_stale = 0;
|
||||
for i in 0..n {
|
||||
let key = format!("{}_{}", prefix, i);
|
||||
if profiles().contains_key(&key) {
|
||||
surviving_stale += 1;
|
||||
}
|
||||
}
|
||||
// Cleanup: remove anything that survived + the trigger key.
|
||||
for i in 0..n {
|
||||
let key = format!("{}_{}", prefix, i);
|
||||
profiles().remove(&key);
|
||||
}
|
||||
profiles().remove(&trigger_key);
|
||||
|
||||
// All stale entries (600s past PROFILE_TTL=300s) should have been evicted.
|
||||
assert_eq!(
|
||||
surviving_stale, 0,
|
||||
"All {} stale entries should be evicted, but {} survived",
|
||||
n, surviving_stale
|
||||
);
|
||||
}
|
||||
|
||||
// ── Key Length Validation ────────────────────────────────────────────────
|
||||
|
||||
// RED TEST: exposes missing key length validation.
|
||||
// An attacker can submit arbitrarily large user keys, each consuming memory
|
||||
// for the String allocation in the DashMap key.
|
||||
#[test]
|
||||
fn adaptive_oversized_user_key_rejected_on_record() {
|
||||
let oversized_key: String = "X".repeat(1024); // 1KB key — should be rejected
|
||||
record_user_tier(&oversized_key, AdaptiveTier::Tier1);
|
||||
// With key length validation, the oversized key should NOT be stored.
|
||||
let stored = profiles().contains_key(&oversized_key);
|
||||
// Cleanup regardless
|
||||
profiles().remove(&oversized_key);
|
||||
assert!(
|
||||
!stored,
|
||||
"Oversized user key (1024 bytes) should be rejected by record_user_tier"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_oversized_user_key_rejected_on_seed() {
|
||||
let oversized_key: String = "X".repeat(1024);
|
||||
// Insert it directly to test seed behavior
|
||||
profiles().insert(
|
||||
oversized_key.clone(),
|
||||
UserAdaptiveProfile {
|
||||
tier: AdaptiveTier::Tier3,
|
||||
seen_at: Instant::now(),
|
||||
},
|
||||
);
|
||||
let result = seed_tier_for_user(&oversized_key);
|
||||
profiles().remove(&oversized_key);
|
||||
assert_eq!(
|
||||
result,
|
||||
AdaptiveTier::Base,
|
||||
"Oversized user key should return Base from seed_tier_for_user"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_empty_user_key_safe() {
|
||||
// Empty string is a valid (if unusual) key — should not panic
|
||||
record_user_tier("", AdaptiveTier::Tier1);
|
||||
let tier = seed_tier_for_user("");
|
||||
profiles().remove("");
|
||||
assert_eq!(tier, AdaptiveTier::Tier1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_max_length_key_accepted() {
|
||||
// A key at exactly 512 bytes should be accepted
|
||||
let key: String = "K".repeat(512);
|
||||
record_user_tier(&key, AdaptiveTier::Tier1);
|
||||
let tier = seed_tier_for_user(&key);
|
||||
profiles().remove(&key);
|
||||
assert_eq!(tier, AdaptiveTier::Tier1);
|
||||
}
|
||||
|
||||
// ── Concurrent Access Safety ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn adaptive_concurrent_record_and_seed_no_torn_read() {
|
||||
let key = unique_key("concurrent_rw");
|
||||
let key_clone = key.clone();
|
||||
|
||||
// Record from multiple threads simultaneously
|
||||
let handles: Vec<_> = (0..10)
|
||||
.map(|i| {
|
||||
let k = key_clone.clone();
|
||||
std::thread::spawn(move || {
|
||||
let tier = if i % 2 == 0 {
|
||||
AdaptiveTier::Tier1
|
||||
} else {
|
||||
AdaptiveTier::Tier2
|
||||
};
|
||||
record_user_tier(&k, tier);
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
for h in handles {
|
||||
h.join().expect("thread panicked");
|
||||
}
|
||||
|
||||
let result = seed_tier_for_user(&key);
|
||||
profiles().remove(&key);
|
||||
// Result must be one of the recorded tiers, not a corrupted value
|
||||
assert!(
|
||||
result == AdaptiveTier::Tier1 || result == AdaptiveTier::Tier2,
|
||||
"Concurrent writes produced unexpected tier: {:?}",
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_concurrent_seed_does_not_panic() {
|
||||
let key = unique_key("concurrent_seed");
|
||||
record_user_tier(&key, AdaptiveTier::Tier1);
|
||||
let key_clone = key.clone();
|
||||
|
||||
let handles: Vec<_> = (0..20)
|
||||
.map(|_| {
|
||||
let k = key_clone.clone();
|
||||
std::thread::spawn(move || {
|
||||
for _ in 0..100 {
|
||||
let _ = seed_tier_for_user(&k);
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
for h in handles {
|
||||
h.join().expect("concurrent seed panicked");
|
||||
}
|
||||
profiles().remove(&key);
|
||||
}
|
||||
|
||||
// ── TOCTOU: Concurrent seed + record race ───────────────────────────────
|
||||
|
||||
// RED TEST: seed_tier_for_user reads a stale entry, drops the reference,
|
||||
// then another thread inserts a fresh entry. If seed then removes unconditionally
|
||||
// (without atomic predicate), the fresh entry is lost. With remove_if, the
|
||||
// fresh entry survives.
|
||||
#[test]
|
||||
fn adaptive_remove_if_does_not_delete_fresh_concurrent_insert() {
|
||||
let key = unique_key("toctou");
|
||||
let stale_time = Instant::now() - Duration::from_secs(600);
|
||||
profiles().insert(
|
||||
key.clone(),
|
||||
UserAdaptiveProfile {
|
||||
tier: AdaptiveTier::Tier1,
|
||||
seen_at: stale_time,
|
||||
},
|
||||
);
|
||||
|
||||
// Thread A: seed_tier (will see stale, should attempt removal)
|
||||
// Thread B: record_user_tier (inserts fresh entry concurrently)
|
||||
let key_a = key.clone();
|
||||
let key_b = key.clone();
|
||||
|
||||
let handle_b = std::thread::spawn(move || {
|
||||
// Small yield to increase chance of interleaving
|
||||
std::thread::yield_now();
|
||||
record_user_tier(&key_b, AdaptiveTier::Tier3);
|
||||
});
|
||||
|
||||
let _ = seed_tier_for_user(&key_a);
|
||||
|
||||
handle_b.join().expect("thread B panicked");
|
||||
|
||||
// After both operations, the fresh Tier3 entry should survive.
|
||||
// With a correct remove_if predicate, the fresh entry is NOT deleted.
|
||||
// Without remove_if (current code), the entry may be lost.
|
||||
let final_tier = seed_tier_for_user(&key);
|
||||
profiles().remove(&key);
|
||||
|
||||
// The fresh Tier3 entry should survive the stale-removal race.
|
||||
// Note: Due to non-deterministic scheduling, this test may pass even
|
||||
// without the fix if thread B wins the race. Run with --test-threads=1
|
||||
// or multiple iterations for reliable detection.
|
||||
assert!(
|
||||
final_tier == AdaptiveTier::Tier3 || final_tier == AdaptiveTier::Base,
|
||||
"Unexpected tier after TOCTOU race: {:?}",
|
||||
final_tier
|
||||
);
|
||||
}
|
||||
|
||||
// ── Fuzz: Random keys ──────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn adaptive_fuzz_random_keys_no_panic() {
|
||||
use rand::{Rng, RngExt};
|
||||
let mut rng = rand::rng();
|
||||
let mut keys = Vec::new();
|
||||
for _ in 0..200 {
|
||||
let len: usize = rng.random_range(0..=256);
|
||||
let key: String = (0..len)
|
||||
.map(|_| {
|
||||
let c: u8 = rng.random_range(0x20..=0x7E);
|
||||
c as char
|
||||
})
|
||||
.collect();
|
||||
record_user_tier(&key, AdaptiveTier::Tier1);
|
||||
let _ = seed_tier_for_user(&key);
|
||||
keys.push(key);
|
||||
}
|
||||
// Cleanup
|
||||
for key in &keys {
|
||||
profiles().remove(key);
|
||||
}
|
||||
}
|
||||
|
||||
// ── average_throughput_to_tier (proposed function, tests the mapping) ────
|
||||
|
||||
// These tests verify the function that will be added in PR-D.
|
||||
// They are written against the current code's constant definitions.
|
||||
|
||||
#[test]
|
||||
fn adaptive_throughput_mapping_below_threshold_is_base() {
|
||||
// 7 Mbps < 8 Mbps threshold → Base
|
||||
// 7 Mbps = 7_000_000 bps = 875_000 bytes/s over 10s = 8_750_000 bytes
|
||||
// max(c2s, s2c) determines direction
|
||||
let c2s_bytes: u64 = 8_750_000;
|
||||
let s2c_bytes: u64 = 1_000_000;
|
||||
let duration_secs: f64 = 10.0;
|
||||
let avg_bps = (c2s_bytes.max(s2c_bytes) as f64 * 8.0) / duration_secs;
|
||||
// 8_750_000 * 8 / 10 = 7_000_000 bps = 7 Mbps → Base
|
||||
assert!(
|
||||
avg_bps < THROUGHPUT_UP_BPS,
|
||||
"Should be below threshold: {} < {}",
|
||||
avg_bps,
|
||||
THROUGHPUT_UP_BPS,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_throughput_mapping_above_threshold_is_tier1() {
|
||||
// 10 Mbps > 8 Mbps threshold → Tier1
|
||||
let bytes_10mbps_10s: u64 = 12_500_000; // 10 Mbps * 10s / 8 = 12_500_000 bytes
|
||||
let duration_secs: f64 = 10.0;
|
||||
let avg_bps = (bytes_10mbps_10s as f64 * 8.0) / duration_secs;
|
||||
assert!(
|
||||
avg_bps >= THROUGHPUT_UP_BPS,
|
||||
"Should be above threshold: {} >= {}",
|
||||
avg_bps,
|
||||
THROUGHPUT_UP_BPS,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_throughput_short_session_should_return_base() {
|
||||
// Sessions shorter than 1 second should not promote (too little data to judge)
|
||||
let duration_secs: f64 = 0.5;
|
||||
// Even with high throughput, short sessions should return Base
|
||||
assert!(
|
||||
duration_secs < 1.0,
|
||||
"Short session duration guard should activate"
|
||||
);
|
||||
}
|
||||
|
||||
// ── me_flush_policy_for_tier ────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn adaptive_me_flush_base_unchanged() {
|
||||
let (frames, bytes, delay) =
|
||||
me_flush_policy_for_tier(AdaptiveTier::Base, 32, 65536, Duration::from_micros(1000));
|
||||
assert_eq!(frames, 32);
|
||||
assert_eq!(bytes, 65536);
|
||||
assert_eq!(delay, Duration::from_micros(1000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_me_flush_tier1_delay_reduced() {
|
||||
let (_, _, delay) =
|
||||
me_flush_policy_for_tier(AdaptiveTier::Tier1, 32, 65536, Duration::from_micros(1000));
|
||||
// Tier1: delay * 7/10 = 700 µs
|
||||
assert_eq!(delay, Duration::from_micros(700));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_me_flush_delay_never_below_minimum() {
|
||||
let (_, _, delay) =
|
||||
me_flush_policy_for_tier(AdaptiveTier::Tier3, 32, 65536, Duration::from_micros(200));
|
||||
// Tier3: 200 * 3/10 = 60, but min is ME_DELAY_MIN_US = 150
|
||||
assert!(delay.as_micros() >= 150, "Delay must respect minimum");
|
||||
}
|
||||
@@ -7,12 +7,6 @@ use std::time::{Duration, Instant};
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
|
||||
auth_probe_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.access.users.clear();
|
||||
@@ -147,8 +141,8 @@ fn make_valid_tls_client_hello_with_alpn(
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_minimum_viable_length_boundary() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret = [0x11u8; 16];
|
||||
let config = test_config_with_secret_hex("11111111111111111111111111111111");
|
||||
@@ -200,8 +194,8 @@ async fn tls_minimum_viable_length_boundary() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn mtproto_extreme_dc_index_serialization() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret_hex = "22222222222222222222222222222222";
|
||||
let config = test_config_with_secret_hex(secret_hex);
|
||||
@@ -241,8 +235,8 @@ async fn mtproto_extreme_dc_index_serialization() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn alpn_strict_case_and_padding_rejection() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret = [0x33u8; 16];
|
||||
let mut config = test_config_with_secret_hex("33333333333333333333333333333333");
|
||||
@@ -297,8 +291,8 @@ fn ipv4_mapped_ipv6_bucketing_anomaly() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn mtproto_invalid_ciphertext_does_not_poison_replay_cache() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret_hex = "55555555555555555555555555555555";
|
||||
let config = test_config_with_secret_hex(secret_hex);
|
||||
@@ -341,8 +335,8 @@ async fn mtproto_invalid_ciphertext_does_not_poison_replay_cache() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_invalid_session_does_not_poison_replay_cache() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret = [0x66u8; 16];
|
||||
let config = test_config_with_secret_hex("66666666666666666666666666666666");
|
||||
@@ -387,8 +381,8 @@ async fn tls_invalid_session_does_not_poison_replay_cache() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn server_hello_delay_timing_neutrality_on_hmac_failure() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret = [0x77u8; 16];
|
||||
let mut config = test_config_with_secret_hex("77777777777777777777777777777777");
|
||||
@@ -425,8 +419,8 @@ async fn server_hello_delay_timing_neutrality_on_hmac_failure() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn server_hello_delay_inversion_resilience() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret = [0x88u8; 16];
|
||||
let mut config = test_config_with_secret_hex("88888888888888888888888888888888");
|
||||
@@ -462,10 +456,9 @@ async fn server_hello_delay_inversion_resilience() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn mixed_valid_and_invalid_user_secrets_configuration() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let _warn_guard = warned_secrets_test_lock().lock().unwrap();
|
||||
clear_warned_secrets_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
clear_warned_secrets_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let mut config = ProxyConfig::default();
|
||||
config.access.ignore_time_skew = true;
|
||||
@@ -513,8 +506,8 @@ async fn mixed_valid_and_invalid_user_secrets_configuration() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_emulation_fallback_when_cache_missing() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret = [0xAAu8; 16];
|
||||
let mut config = test_config_with_secret_hex("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa");
|
||||
@@ -547,8 +540,8 @@ async fn tls_emulation_fallback_when_cache_missing() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn classic_mode_over_tls_transport_protocol_confusion() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret_hex = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb";
|
||||
let mut config = test_config_with_secret_hex(secret_hex);
|
||||
@@ -608,8 +601,8 @@ fn generate_tg_nonce_never_emits_reserved_bytes() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn dashmap_concurrent_saturation_stress() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let ip_a: IpAddr = "192.0.2.13".parse().unwrap();
|
||||
let ip_b: IpAddr = "198.51.100.13".parse().unwrap();
|
||||
@@ -617,9 +610,10 @@ async fn dashmap_concurrent_saturation_stress() {
|
||||
|
||||
for i in 0..100 {
|
||||
let target_ip = if i % 2 == 0 { ip_a } else { ip_b };
|
||||
let shared = shared.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
for _ in 0..50 {
|
||||
auth_probe_record_failure(target_ip, Instant::now());
|
||||
auth_probe_record_failure_in(shared.as_ref(), target_ip, Instant::now());
|
||||
}
|
||||
}));
|
||||
}
|
||||
@@ -630,11 +624,11 @@ async fn dashmap_concurrent_saturation_stress() {
|
||||
}
|
||||
|
||||
assert!(
|
||||
auth_probe_is_throttled_for_testing(ip_a),
|
||||
auth_probe_is_throttled_for_testing_in_shared(shared.as_ref(), ip_a),
|
||||
"IP A must be throttled after concurrent stress"
|
||||
);
|
||||
assert!(
|
||||
auth_probe_is_throttled_for_testing(ip_b),
|
||||
auth_probe_is_throttled_for_testing_in_shared(shared.as_ref(), ip_b),
|
||||
"IP B must be throttled after concurrent stress"
|
||||
);
|
||||
}
|
||||
@@ -661,15 +655,15 @@ fn prototag_invalid_bytes_fail_closed() {
|
||||
|
||||
#[test]
|
||||
fn auth_probe_eviction_hash_collision_stress() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let state = auth_probe_state_map();
|
||||
let state = auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
let now = Instant::now();
|
||||
|
||||
for i in 0..10_000u32 {
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(10, 0, (i >> 8) as u8, (i & 0xFF) as u8));
|
||||
auth_probe_record_failure_with_state(state, ip, now);
|
||||
auth_probe_record_failure_with_state_in(shared.as_ref(), state, ip, now);
|
||||
}
|
||||
|
||||
assert!(
|
||||
|
||||
@@ -44,12 +44,6 @@ fn make_valid_mtproto_handshake(
|
||||
handshake
|
||||
}
|
||||
|
||||
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
|
||||
auth_probe_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.access.users.clear();
|
||||
@@ -67,8 +61,8 @@ fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
|
||||
|
||||
#[tokio::test]
|
||||
async fn mtproto_handshake_bit_flip_anywhere_rejected() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret_hex = "11223344556677889900aabbccddeeff";
|
||||
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
|
||||
@@ -181,26 +175,26 @@ async fn mtproto_handshake_timing_neutrality_mocked() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_probe_throttle_saturation_stress() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
// Record enough failures for one IP to trigger backoff
|
||||
let target_ip = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
|
||||
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
|
||||
auth_probe_record_failure(target_ip, now);
|
||||
auth_probe_record_failure_in(shared.as_ref(), target_ip, now);
|
||||
}
|
||||
|
||||
assert!(auth_probe_is_throttled(target_ip, now));
|
||||
assert!(auth_probe_is_throttled_in(shared.as_ref(), target_ip, now));
|
||||
|
||||
// Stress test with many unique IPs
|
||||
for i in 0..500u32 {
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, (i % 256) as u8));
|
||||
auth_probe_record_failure(ip, now);
|
||||
auth_probe_record_failure_in(shared.as_ref(), ip, now);
|
||||
}
|
||||
|
||||
let tracked = AUTH_PROBE_STATE.get().map(|state| state.len()).unwrap_or(0);
|
||||
let tracked = auth_probe_state_for_testing_in_shared(shared.as_ref()).len();
|
||||
assert!(
|
||||
tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES,
|
||||
"auth probe state grew past hard cap: {tracked} > {AUTH_PROBE_TRACK_MAX_ENTRIES}"
|
||||
@@ -209,8 +203,8 @@ async fn auth_probe_throttle_saturation_stress() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn mtproto_handshake_abridged_prefix_rejected() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let mut handshake = [0x5Au8; HANDSHAKE_LEN];
|
||||
handshake[0] = 0xef; // Abridged prefix
|
||||
@@ -235,8 +229,8 @@ async fn mtproto_handshake_abridged_prefix_rejected() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn mtproto_handshake_preferred_user_mismatch_continues() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret1_hex = "11111111111111111111111111111111";
|
||||
let secret2_hex = "22222222222222222222222222222222";
|
||||
@@ -278,8 +272,8 @@ async fn mtproto_handshake_preferred_user_mismatch_continues() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn mtproto_handshake_concurrent_flood_stability() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret_hex = "00112233445566778899aabbccddeeff";
|
||||
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1);
|
||||
@@ -320,8 +314,8 @@ async fn mtproto_handshake_concurrent_flood_stability() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn mtproto_replay_is_rejected_across_distinct_peers() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret_hex = "0123456789abcdeffedcba9876543210";
|
||||
let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
|
||||
@@ -360,8 +354,8 @@ async fn mtproto_replay_is_rejected_across_distinct_peers() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn mtproto_blackhat_mutation_corpus_never_panics_and_stays_fail_closed() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret_hex = "89abcdef012345670123456789abcdef";
|
||||
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
|
||||
@@ -405,27 +399,27 @@ async fn mtproto_blackhat_mutation_corpus_never_panics_and_stays_fail_closed() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_probe_success_clears_throttled_peer_state() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let target_ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 90));
|
||||
let now = Instant::now();
|
||||
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
|
||||
auth_probe_record_failure(target_ip, now);
|
||||
auth_probe_record_failure_in(shared.as_ref(), target_ip, now);
|
||||
}
|
||||
assert!(auth_probe_is_throttled(target_ip, now));
|
||||
assert!(auth_probe_is_throttled_in(shared.as_ref(), target_ip, now));
|
||||
|
||||
auth_probe_record_success(target_ip);
|
||||
auth_probe_record_success_in(shared.as_ref(), target_ip);
|
||||
assert!(
|
||||
!auth_probe_is_throttled(target_ip, now + Duration::from_millis(1)),
|
||||
!auth_probe_is_throttled_in(shared.as_ref(), target_ip, now + Duration::from_millis(1)),
|
||||
"successful auth must clear per-peer throttle state"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mtproto_invalid_storm_over_cap_keeps_probe_map_hard_bounded() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret_hex = "00112233445566778899aabbccddeeff";
|
||||
let mut invalid = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
|
||||
@@ -458,7 +452,7 @@ async fn mtproto_invalid_storm_over_cap_keeps_probe_map_hard_bounded() {
|
||||
assert!(matches!(res, HandshakeResult::BadClient { .. }));
|
||||
}
|
||||
|
||||
let tracked = AUTH_PROBE_STATE.get().map(|state| state.len()).unwrap_or(0);
|
||||
let tracked = auth_probe_state_for_testing_in_shared(shared.as_ref()).len();
|
||||
assert!(
|
||||
tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES,
|
||||
"probe map must remain bounded under invalid storm: {tracked}"
|
||||
@@ -467,8 +461,8 @@ async fn mtproto_invalid_storm_over_cap_keeps_probe_map_hard_bounded() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn mtproto_property_style_multi_bit_mutations_fail_closed_or_auth_only() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret_hex = "f0e1d2c3b4a5968778695a4b3c2d1e0f";
|
||||
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
|
||||
@@ -520,8 +514,8 @@ async fn mtproto_property_style_multi_bit_mutations_fail_closed_or_auth_only() {
|
||||
#[tokio::test]
|
||||
#[ignore = "heavy soak; run manually"]
|
||||
async fn mtproto_blackhat_20k_mutation_soak_never_panics() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret_hex = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
|
||||
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
|
||||
|
||||
@@ -3,15 +3,9 @@ use std::collections::HashSet;
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
|
||||
auth_probe_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adversarial_large_state_offsets_escape_first_scan_window() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let shared = ProxySharedState::new();
|
||||
let base = Instant::now();
|
||||
let state_len = 65_536usize;
|
||||
let scan_limit = 1_024usize;
|
||||
@@ -25,7 +19,8 @@ fn adversarial_large_state_offsets_escape_first_scan_window() {
|
||||
((i.wrapping_mul(131)) & 0xff) as u8,
|
||||
));
|
||||
let now = base + Duration::from_nanos(i);
|
||||
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
|
||||
let start =
|
||||
auth_probe_scan_start_offset_in(shared.as_ref(), ip, now, state_len, scan_limit);
|
||||
if start >= scan_limit {
|
||||
saw_offset_outside_first_window = true;
|
||||
break;
|
||||
@@ -40,7 +35,7 @@ fn adversarial_large_state_offsets_escape_first_scan_window() {
|
||||
|
||||
#[test]
|
||||
fn stress_large_state_offsets_cover_many_scan_windows() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let shared = ProxySharedState::new();
|
||||
let base = Instant::now();
|
||||
let state_len = 65_536usize;
|
||||
let scan_limit = 1_024usize;
|
||||
@@ -54,7 +49,8 @@ fn stress_large_state_offsets_cover_many_scan_windows() {
|
||||
((i.wrapping_mul(17)) & 0xff) as u8,
|
||||
));
|
||||
let now = base + Duration::from_micros(i);
|
||||
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
|
||||
let start =
|
||||
auth_probe_scan_start_offset_in(shared.as_ref(), ip, now, state_len, scan_limit);
|
||||
covered_windows.insert(start / scan_limit);
|
||||
}
|
||||
|
||||
@@ -68,7 +64,7 @@ fn stress_large_state_offsets_cover_many_scan_windows() {
|
||||
|
||||
#[test]
|
||||
fn light_fuzz_offset_always_stays_inside_state_len() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let shared = ProxySharedState::new();
|
||||
let mut seed = 0xC0FF_EE12_3456_789Au64;
|
||||
let base = Instant::now();
|
||||
|
||||
@@ -86,7 +82,8 @@ fn light_fuzz_offset_always_stays_inside_state_len() {
|
||||
let state_len = ((seed >> 16) as usize % 200_000).saturating_add(1);
|
||||
let scan_limit = ((seed >> 40) as usize % 2_048).saturating_add(1);
|
||||
let now = base + Duration::from_nanos(seed & 0x0fff);
|
||||
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
|
||||
let start =
|
||||
auth_probe_scan_start_offset_in(shared.as_ref(), ip, now, state_len, scan_limit);
|
||||
|
||||
assert!(
|
||||
start < state_len,
|
||||
|
||||
@@ -2,68 +2,62 @@ use super::*;
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
|
||||
auth_probe_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn positive_preauth_throttle_activates_after_failure_threshold() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 20));
|
||||
let now = Instant::now();
|
||||
|
||||
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
|
||||
auth_probe_record_failure(ip, now);
|
||||
auth_probe_record_failure_in(shared.as_ref(), ip, now);
|
||||
}
|
||||
|
||||
assert!(
|
||||
auth_probe_is_throttled(ip, now),
|
||||
auth_probe_is_throttled_in(shared.as_ref(), ip, now),
|
||||
"peer must be throttled once fail streak reaches threshold"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn negative_unrelated_peer_remains_unthrottled() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let attacker = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 12));
|
||||
let benign = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 13));
|
||||
let now = Instant::now();
|
||||
|
||||
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
|
||||
auth_probe_record_failure(attacker, now);
|
||||
auth_probe_record_failure_in(shared.as_ref(), attacker, now);
|
||||
}
|
||||
|
||||
assert!(auth_probe_is_throttled(attacker, now));
|
||||
assert!(auth_probe_is_throttled_in(shared.as_ref(), attacker, now));
|
||||
assert!(
|
||||
!auth_probe_is_throttled(benign, now),
|
||||
!auth_probe_is_throttled_in(shared.as_ref(), benign, now),
|
||||
"throttle state must stay scoped to normalized peer key"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn edge_expired_entry_is_pruned_and_no_longer_throttled() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 41));
|
||||
let base = Instant::now();
|
||||
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
|
||||
auth_probe_record_failure(ip, base);
|
||||
auth_probe_record_failure_in(shared.as_ref(), ip, base);
|
||||
}
|
||||
|
||||
let expired_at = base + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 1);
|
||||
assert!(
|
||||
!auth_probe_is_throttled(ip, expired_at),
|
||||
!auth_probe_is_throttled_in(shared.as_ref(), ip, expired_at),
|
||||
"expired entries must not keep throttling peers"
|
||||
);
|
||||
|
||||
let state = auth_probe_state_map();
|
||||
let state = auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
assert!(
|
||||
state.get(&normalize_auth_probe_ip(ip)).is_none(),
|
||||
"expired lookup should prune stale state"
|
||||
@@ -72,36 +66,40 @@ fn edge_expired_entry_is_pruned_and_no_longer_throttled() {
|
||||
|
||||
#[test]
|
||||
fn adversarial_saturation_grace_requires_extra_failures_before_preauth_throttle() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(198, 18, 0, 7));
|
||||
let now = Instant::now();
|
||||
|
||||
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
|
||||
auth_probe_record_failure(ip, now);
|
||||
auth_probe_record_failure_in(shared.as_ref(), ip, now);
|
||||
}
|
||||
auth_probe_note_saturation(now);
|
||||
auth_probe_note_saturation_in(shared.as_ref(), now);
|
||||
|
||||
assert!(
|
||||
!auth_probe_should_apply_preauth_throttle(ip, now),
|
||||
!auth_probe_should_apply_preauth_throttle_in(shared.as_ref(), ip, now),
|
||||
"during global saturation, peer must receive configured grace window"
|
||||
);
|
||||
|
||||
for _ in 0..AUTH_PROBE_SATURATION_GRACE_FAILS {
|
||||
auth_probe_record_failure(ip, now + Duration::from_millis(1));
|
||||
auth_probe_record_failure_in(shared.as_ref(), ip, now + Duration::from_millis(1));
|
||||
}
|
||||
|
||||
assert!(
|
||||
auth_probe_should_apply_preauth_throttle(ip, now + Duration::from_millis(1)),
|
||||
auth_probe_should_apply_preauth_throttle_in(
|
||||
shared.as_ref(),
|
||||
ip,
|
||||
now + Duration::from_millis(1)
|
||||
),
|
||||
"after grace failures are exhausted, preauth throttle must activate"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn integration_over_cap_insertion_keeps_probe_map_bounded() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let now = Instant::now();
|
||||
for idx in 0..(AUTH_PROBE_TRACK_MAX_ENTRIES + 1024) {
|
||||
@@ -111,10 +109,10 @@ fn integration_over_cap_insertion_keeps_probe_map_bounded() {
|
||||
((idx / 256) % 256) as u8,
|
||||
(idx % 256) as u8,
|
||||
));
|
||||
auth_probe_record_failure(ip, now);
|
||||
auth_probe_record_failure_in(shared.as_ref(), ip, now);
|
||||
}
|
||||
|
||||
let tracked = auth_probe_state_map().len();
|
||||
let tracked = auth_probe_state_for_testing_in_shared(shared.as_ref()).len();
|
||||
assert!(
|
||||
tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES,
|
||||
"probe map must remain hard bounded under insertion storm"
|
||||
@@ -123,8 +121,8 @@ fn integration_over_cap_insertion_keeps_probe_map_bounded() {
|
||||
|
||||
#[test]
|
||||
fn light_fuzz_randomized_failures_preserve_cap_and_nonzero_streaks() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let mut seed = 0x4D53_5854_6F66_6175u64;
|
||||
let now = Instant::now();
|
||||
@@ -140,10 +138,14 @@ fn light_fuzz_randomized_failures_preserve_cap_and_nonzero_streaks() {
|
||||
(seed >> 8) as u8,
|
||||
seed as u8,
|
||||
));
|
||||
auth_probe_record_failure(ip, now + Duration::from_millis((seed & 0x3f) as u64));
|
||||
auth_probe_record_failure_in(
|
||||
shared.as_ref(),
|
||||
ip,
|
||||
now + Duration::from_millis((seed & 0x3f) as u64),
|
||||
);
|
||||
}
|
||||
|
||||
let state = auth_probe_state_map();
|
||||
let state = auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
assert!(state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES);
|
||||
for entry in state.iter() {
|
||||
assert!(entry.value().fail_streak > 0);
|
||||
@@ -152,13 +154,14 @@ fn light_fuzz_randomized_failures_preserve_cap_and_nonzero_streaks() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_parallel_failure_flood_keeps_state_hard_capped() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let start = Instant::now();
|
||||
let mut tasks = Vec::new();
|
||||
|
||||
for worker in 0..8u8 {
|
||||
let shared = shared.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
for i in 0..4096u32 {
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(
|
||||
@@ -167,7 +170,11 @@ async fn stress_parallel_failure_flood_keeps_state_hard_capped() {
|
||||
((i >> 8) & 0xff) as u8,
|
||||
(i & 0xff) as u8,
|
||||
));
|
||||
auth_probe_record_failure(ip, start + Duration::from_millis((i % 4) as u64));
|
||||
auth_probe_record_failure_in(
|
||||
shared.as_ref(),
|
||||
ip,
|
||||
start + Duration::from_millis((i % 4) as u64),
|
||||
);
|
||||
}
|
||||
}));
|
||||
}
|
||||
@@ -176,12 +183,12 @@ async fn stress_parallel_failure_flood_keeps_state_hard_capped() {
|
||||
task.await.expect("stress worker must not panic");
|
||||
}
|
||||
|
||||
let tracked = auth_probe_state_map().len();
|
||||
let tracked = auth_probe_state_for_testing_in_shared(shared.as_ref()).len();
|
||||
assert!(
|
||||
tracked <= AUTH_PROBE_TRACK_MAX_ENTRIES,
|
||||
"parallel failure flood must not exceed cap"
|
||||
);
|
||||
|
||||
let probe = IpAddr::V4(Ipv4Addr::new(172, 3, 4, 5));
|
||||
let _ = auth_probe_is_throttled(probe, start + Duration::from_millis(2));
|
||||
let _ = auth_probe_is_throttled_in(shared.as_ref(), probe, start + Duration::from_millis(2));
|
||||
}
|
||||
|
||||
@@ -2,20 +2,14 @@ use super::*;
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
|
||||
auth_probe_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn edge_zero_state_len_yields_zero_start_offset() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let shared = ProxySharedState::new();
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 44));
|
||||
let now = Instant::now();
|
||||
|
||||
assert_eq!(
|
||||
auth_probe_scan_start_offset(ip, now, 0, 16),
|
||||
auth_probe_scan_start_offset_in(shared.as_ref(), ip, now, 0, 16),
|
||||
0,
|
||||
"empty map must not produce non-zero scan offset"
|
||||
);
|
||||
@@ -23,7 +17,7 @@ fn edge_zero_state_len_yields_zero_start_offset() {
|
||||
|
||||
#[test]
|
||||
fn adversarial_large_state_must_allow_start_offset_outside_scan_budget_window() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let shared = ProxySharedState::new();
|
||||
let base = Instant::now();
|
||||
let scan_limit = 16usize;
|
||||
let state_len = 65_536usize;
|
||||
@@ -37,7 +31,8 @@ fn adversarial_large_state_must_allow_start_offset_outside_scan_budget_window()
|
||||
(i & 0xff) as u8,
|
||||
));
|
||||
let now = base + Duration::from_micros(i as u64);
|
||||
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
|
||||
let start =
|
||||
auth_probe_scan_start_offset_in(shared.as_ref(), ip, now, state_len, scan_limit);
|
||||
assert!(
|
||||
start < state_len,
|
||||
"start offset must stay within state length; start={start}, len={state_len}"
|
||||
@@ -56,12 +51,12 @@ fn adversarial_large_state_must_allow_start_offset_outside_scan_budget_window()
|
||||
|
||||
#[test]
|
||||
fn positive_state_smaller_than_scan_limit_caps_to_state_len() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let shared = ProxySharedState::new();
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 17));
|
||||
let now = Instant::now();
|
||||
|
||||
for state_len in 1..32usize {
|
||||
let start = auth_probe_scan_start_offset(ip, now, state_len, 64);
|
||||
let start = auth_probe_scan_start_offset_in(shared.as_ref(), ip, now, state_len, 64);
|
||||
assert!(
|
||||
start < state_len,
|
||||
"start offset must never exceed state length when scan limit is larger"
|
||||
@@ -71,7 +66,7 @@ fn positive_state_smaller_than_scan_limit_caps_to_state_len() {
|
||||
|
||||
#[test]
|
||||
fn light_fuzz_scan_offset_budget_never_exceeds_effective_window() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let shared = ProxySharedState::new();
|
||||
let mut seed = 0x5A41_5356_4C32_3236u64;
|
||||
let base = Instant::now();
|
||||
|
||||
@@ -89,7 +84,8 @@ fn light_fuzz_scan_offset_budget_never_exceeds_effective_window() {
|
||||
let state_len = ((seed >> 8) as usize % 131_072).saturating_add(1);
|
||||
let scan_limit = ((seed >> 32) as usize % 512).saturating_add(1);
|
||||
let now = base + Duration::from_nanos(seed & 0xffff);
|
||||
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
|
||||
let start =
|
||||
auth_probe_scan_start_offset_in(shared.as_ref(), ip, now, state_len, scan_limit);
|
||||
|
||||
assert!(
|
||||
start < state_len,
|
||||
|
||||
@@ -3,22 +3,16 @@ use std::collections::HashSet;
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
|
||||
auth_probe_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn positive_same_ip_moving_time_yields_diverse_scan_offsets() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let shared = ProxySharedState::new();
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77));
|
||||
let base = Instant::now();
|
||||
let mut uniq = HashSet::new();
|
||||
|
||||
for i in 0..512u64 {
|
||||
let now = base + Duration::from_nanos(i);
|
||||
let offset = auth_probe_scan_start_offset(ip, now, 65_536, 16);
|
||||
let offset = auth_probe_scan_start_offset_in(shared.as_ref(), ip, now, 65_536, 16);
|
||||
uniq.insert(offset);
|
||||
}
|
||||
|
||||
@@ -31,7 +25,7 @@ fn positive_same_ip_moving_time_yields_diverse_scan_offsets() {
|
||||
|
||||
#[test]
|
||||
fn adversarial_many_ips_same_time_spreads_offsets_without_bias_collapse() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let shared = ProxySharedState::new();
|
||||
let now = Instant::now();
|
||||
let mut uniq = HashSet::new();
|
||||
|
||||
@@ -42,7 +36,13 @@ fn adversarial_many_ips_same_time_spreads_offsets_without_bias_collapse() {
|
||||
i as u8,
|
||||
(255 - (i as u8)),
|
||||
));
|
||||
uniq.insert(auth_probe_scan_start_offset(ip, now, 65_536, 16));
|
||||
uniq.insert(auth_probe_scan_start_offset_in(
|
||||
shared.as_ref(),
|
||||
ip,
|
||||
now,
|
||||
65_536,
|
||||
16,
|
||||
));
|
||||
}
|
||||
|
||||
assert!(
|
||||
@@ -54,12 +54,13 @@ fn adversarial_many_ips_same_time_spreads_offsets_without_bias_collapse() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_parallel_failure_churn_under_saturation_remains_capped_and_live() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let start = Instant::now();
|
||||
let mut workers = Vec::new();
|
||||
for worker in 0..8u8 {
|
||||
let shared = shared.clone();
|
||||
workers.push(tokio::spawn(async move {
|
||||
for i in 0..8192u32 {
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(
|
||||
@@ -68,7 +69,11 @@ async fn stress_parallel_failure_churn_under_saturation_remains_capped_and_live(
|
||||
((i >> 8) & 0xff) as u8,
|
||||
(i & 0xff) as u8,
|
||||
));
|
||||
auth_probe_record_failure(ip, start + Duration::from_micros((i % 128) as u64));
|
||||
auth_probe_record_failure_in(
|
||||
shared.as_ref(),
|
||||
ip,
|
||||
start + Duration::from_micros((i % 128) as u64),
|
||||
);
|
||||
}
|
||||
}));
|
||||
}
|
||||
@@ -78,17 +83,22 @@ async fn stress_parallel_failure_churn_under_saturation_remains_capped_and_live(
|
||||
}
|
||||
|
||||
assert!(
|
||||
auth_probe_state_map().len() <= AUTH_PROBE_TRACK_MAX_ENTRIES,
|
||||
auth_probe_state_for_testing_in_shared(shared.as_ref()).len()
|
||||
<= AUTH_PROBE_TRACK_MAX_ENTRIES,
|
||||
"state must remain hard-capped under parallel saturation churn"
|
||||
);
|
||||
|
||||
let probe = IpAddr::V4(Ipv4Addr::new(10, 4, 1, 1));
|
||||
let _ = auth_probe_should_apply_preauth_throttle(probe, start + Duration::from_millis(1));
|
||||
let _ = auth_probe_should_apply_preauth_throttle_in(
|
||||
shared.as_ref(),
|
||||
probe,
|
||||
start + Duration::from_millis(1),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn light_fuzz_scan_offset_stays_within_window_for_randomized_inputs() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let shared = ProxySharedState::new();
|
||||
let mut seed = 0xA55A_1357_2468_9BDFu64;
|
||||
let base = Instant::now();
|
||||
|
||||
@@ -107,7 +117,8 @@ fn light_fuzz_scan_offset_stays_within_window_for_randomized_inputs() {
|
||||
let scan_limit = ((seed >> 40) as usize % 1024).saturating_add(1);
|
||||
let now = base + Duration::from_nanos(seed & 0x1fff);
|
||||
|
||||
let offset = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
|
||||
let offset =
|
||||
auth_probe_scan_start_offset_in(shared.as_ref(), ip, now, state_len, scan_limit);
|
||||
assert!(
|
||||
offset < state_len,
|
||||
"scan offset must always remain inside state length"
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
use super::*;
|
||||
use crate::crypto::sha256_hmac;
|
||||
use crate::stats::ReplayChecker;
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::time::timeout;
|
||||
|
||||
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.access.users.clear();
|
||||
cfg.access
|
||||
.users
|
||||
.insert("user".to_string(), secret_hex.to_string());
|
||||
cfg.access.ignore_time_skew = true;
|
||||
cfg.censorship.mask = true;
|
||||
cfg
|
||||
}
|
||||
|
||||
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
|
||||
let session_id_len: usize = 32;
|
||||
let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
|
||||
let mut handshake = vec![0x42u8; len];
|
||||
|
||||
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
|
||||
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
|
||||
|
||||
let computed = sha256_hmac(secret, &handshake);
|
||||
let mut digest = computed;
|
||||
let ts = timestamp.to_le_bytes();
|
||||
for i in 0..4 {
|
||||
digest[28 + i] ^= ts[i];
|
||||
}
|
||||
|
||||
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
|
||||
.copy_from_slice(&digest);
|
||||
handshake
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_baseline_probe_always_falls_back_to_masking() {
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let cfg = test_config_with_secret_hex("11111111111111111111111111111111");
|
||||
let replay_checker = ReplayChecker::new(64, Duration::from_secs(60));
|
||||
let rng = SecureRandom::new();
|
||||
let peer: SocketAddr = "198.51.100.210:44321".parse().unwrap();
|
||||
|
||||
let probe = b"not-a-tls-clienthello";
|
||||
let res = handle_tls_handshake(
|
||||
probe,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
peer,
|
||||
&cfg,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(res, HandshakeResult::BadClient { .. }));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_baseline_invalid_secret_triggers_fallback_not_error_response() {
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let good_secret = [0x22u8; 16];
|
||||
let bad_cfg = test_config_with_secret_hex("33333333333333333333333333333333");
|
||||
let replay_checker = ReplayChecker::new(64, Duration::from_secs(60));
|
||||
let rng = SecureRandom::new();
|
||||
let peer: SocketAddr = "198.51.100.211:44322".parse().unwrap();
|
||||
|
||||
let handshake = make_valid_tls_handshake(&good_secret, 0);
|
||||
let res = handle_tls_handshake(
|
||||
&handshake,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
peer,
|
||||
&bad_cfg,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(res, HandshakeResult::BadClient { .. }));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_baseline_auth_probe_streak_increments_per_ip() {
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let cfg = test_config_with_secret_hex("44444444444444444444444444444444");
|
||||
let replay_checker = ReplayChecker::new(64, Duration::from_secs(60));
|
||||
let rng = SecureRandom::new();
|
||||
|
||||
let peer: SocketAddr = "203.0.113.10:5555".parse().unwrap();
|
||||
let untouched_ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 11));
|
||||
let bad_probe = b"\x16\x03\x01\x00";
|
||||
|
||||
for expected in 1..=3 {
|
||||
let res = handle_tls_handshake_with_shared(
|
||||
bad_probe,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
peer,
|
||||
&cfg,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
shared.as_ref(),
|
||||
)
|
||||
.await;
|
||||
assert!(matches!(res, HandshakeResult::BadClient { .. }));
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(shared.as_ref(), peer.ip()),
|
||||
Some(expected)
|
||||
);
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(shared.as_ref(), untouched_ip),
|
||||
None
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handshake_baseline_saturation_fires_at_compile_time_threshold() {
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 33));
|
||||
let now = Instant::now();
|
||||
|
||||
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS.saturating_sub(1) {
|
||||
auth_probe_record_failure_in(shared.as_ref(), ip, now);
|
||||
}
|
||||
assert!(!auth_probe_is_throttled_in(shared.as_ref(), ip, now));
|
||||
|
||||
auth_probe_record_failure_in(shared.as_ref(), ip, now);
|
||||
assert!(auth_probe_is_throttled_in(shared.as_ref(), ip, now));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handshake_baseline_repeated_probes_streak_monotonic() {
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 42));
|
||||
let now = Instant::now();
|
||||
let mut prev = 0u32;
|
||||
|
||||
for _ in 0..100 {
|
||||
auth_probe_record_failure_in(shared.as_ref(), ip, now);
|
||||
let current =
|
||||
auth_probe_fail_streak_for_testing_in_shared(shared.as_ref(), ip).unwrap_or(0);
|
||||
assert!(current >= prev, "streak must be monotonic");
|
||||
prev = current;
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handshake_baseline_throttled_ip_incurs_backoff_delay() {
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 44));
|
||||
let now = Instant::now();
|
||||
|
||||
for _ in 0..AUTH_PROBE_BACKOFF_START_FAILS {
|
||||
auth_probe_record_failure_in(shared.as_ref(), ip, now);
|
||||
}
|
||||
|
||||
let delay = auth_probe_backoff(AUTH_PROBE_BACKOFF_START_FAILS);
|
||||
assert!(delay >= Duration::from_millis(AUTH_PROBE_BACKOFF_BASE_MS));
|
||||
|
||||
let before_expiry = now + delay.saturating_sub(Duration::from_millis(1));
|
||||
let after_expiry = now + delay + Duration::from_millis(1);
|
||||
|
||||
assert!(auth_probe_is_throttled_in(
|
||||
shared.as_ref(),
|
||||
ip,
|
||||
before_expiry
|
||||
));
|
||||
assert!(!auth_probe_is_throttled_in(
|
||||
shared.as_ref(),
|
||||
ip,
|
||||
after_expiry
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_baseline_malformed_probe_frames_fail_closed_to_masking() {
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let cfg = test_config_with_secret_hex("55555555555555555555555555555555");
|
||||
let replay_checker = ReplayChecker::new(64, Duration::from_secs(60));
|
||||
let rng = SecureRandom::new();
|
||||
let peer: SocketAddr = "198.51.100.212:44323".parse().unwrap();
|
||||
|
||||
let corpus: Vec<Vec<u8>> = vec![
|
||||
vec![0x16, 0x03, 0x01],
|
||||
vec![0x16, 0x03, 0x01, 0xFF, 0xFF],
|
||||
vec![0x00; 128],
|
||||
(0..64u8).collect(),
|
||||
];
|
||||
|
||||
for probe in corpus {
|
||||
let res = timeout(
|
||||
Duration::from_millis(250),
|
||||
handle_tls_handshake(
|
||||
&probe,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
peer,
|
||||
&cfg,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("malformed probe handling must complete in bounded time");
|
||||
|
||||
assert!(
|
||||
matches!(
|
||||
res,
|
||||
HandshakeResult::BadClient { .. } | HandshakeResult::Error(_)
|
||||
),
|
||||
"malformed probe must fail closed"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -67,16 +67,10 @@ fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
|
||||
cfg
|
||||
}
|
||||
|
||||
fn auth_probe_test_guard() -> MutexGuard<'static, ()> {
|
||||
auth_probe_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mtproto_handshake_duplicate_digest_is_replayed_on_second_attempt() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret_hex = "11223344556677889900aabbccddeeff";
|
||||
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 2);
|
||||
@@ -110,13 +104,13 @@ async fn mtproto_handshake_duplicate_digest_is_replayed_on_second_attempt() {
|
||||
.await;
|
||||
assert!(matches!(second, HandshakeResult::BadClient { .. }));
|
||||
|
||||
clear_auth_probe_state_for_testing();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mtproto_handshake_fuzz_corpus_never_panics_and_stays_fail_closed() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret_hex = "00112233445566778899aabbccddeeff";
|
||||
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1);
|
||||
@@ -178,13 +172,13 @@ async fn mtproto_handshake_fuzz_corpus_never_panics_and_stays_fail_closed() {
|
||||
);
|
||||
}
|
||||
|
||||
clear_auth_probe_state_for_testing();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mtproto_handshake_mixed_corpus_never_panics_and_exact_duplicates_are_rejected() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret_hex = "99887766554433221100ffeeddccbbaa";
|
||||
let base = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 4);
|
||||
@@ -274,5 +268,5 @@ async fn mtproto_handshake_mixed_corpus_never_panics_and_exact_duplicates_are_re
|
||||
);
|
||||
}
|
||||
|
||||
clear_auth_probe_state_for_testing();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
@@ -11,12 +11,6 @@ use tokio::sync::Barrier;
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
|
||||
auth_probe_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.access.users.clear();
|
||||
@@ -164,8 +158,8 @@ fn make_valid_tls_client_hello_with_sni_and_alpn(
|
||||
|
||||
#[tokio::test]
|
||||
async fn server_hello_delay_bypassed_if_max_is_zero_despite_high_min() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret = [0x1Au8; 16];
|
||||
let mut config = test_config_with_secret_hex("1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a");
|
||||
@@ -201,10 +195,10 @@ async fn server_hello_delay_bypassed_if_max_is_zero_despite_high_min() {
|
||||
|
||||
#[test]
|
||||
fn auth_probe_backoff_extreme_fail_streak_clamps_safely() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let state = auth_probe_state_map();
|
||||
let state = auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 99));
|
||||
let now = Instant::now();
|
||||
|
||||
@@ -217,7 +211,7 @@ fn auth_probe_backoff_extreme_fail_streak_clamps_safely() {
|
||||
},
|
||||
);
|
||||
|
||||
auth_probe_record_failure_with_state(&state, peer_ip, now);
|
||||
auth_probe_record_failure_with_state_in(shared.as_ref(), &state, peer_ip, now);
|
||||
|
||||
let updated = state.get(&peer_ip).unwrap();
|
||||
assert_eq!(updated.fail_streak, u32::MAX);
|
||||
@@ -270,8 +264,8 @@ fn generate_tg_nonce_cryptographic_uniqueness_and_entropy() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn mtproto_multi_user_decryption_isolation() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let mut config = ProxyConfig::default();
|
||||
config.general.modes.secure = true;
|
||||
@@ -323,10 +317,8 @@ async fn mtproto_multi_user_decryption_isolation() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn invalid_secret_warning_lock_contention_and_bound() {
|
||||
let _guard = warned_secrets_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
clear_warned_secrets_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_warned_secrets_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let tasks = 50;
|
||||
let iterations_per_task = 100;
|
||||
@@ -335,11 +327,18 @@ async fn invalid_secret_warning_lock_contention_and_bound() {
|
||||
|
||||
for t in 0..tasks {
|
||||
let b = barrier.clone();
|
||||
let shared = shared.clone();
|
||||
handles.push(tokio::spawn(async move {
|
||||
b.wait().await;
|
||||
for i in 0..iterations_per_task {
|
||||
let user_name = format!("contention_user_{}_{}", t, i);
|
||||
warn_invalid_secret_once(&user_name, "invalid_hex", ACCESS_SECRET_BYTES, None);
|
||||
warn_invalid_secret_once_in(
|
||||
shared.as_ref(),
|
||||
&user_name,
|
||||
"invalid_hex",
|
||||
ACCESS_SECRET_BYTES,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}));
|
||||
}
|
||||
@@ -348,7 +347,7 @@ async fn invalid_secret_warning_lock_contention_and_bound() {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
let warned = INVALID_SECRET_WARNED.get().unwrap();
|
||||
let warned = warned_secrets_for_testing_in_shared(shared.as_ref());
|
||||
let guard = warned
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
@@ -362,8 +361,8 @@ async fn invalid_secret_warning_lock_contention_and_bound() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn mtproto_strict_concurrent_replay_race_condition() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret_hex = "4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A";
|
||||
let config = Arc::new(test_config_with_secret_hex(secret_hex));
|
||||
@@ -428,8 +427,8 @@ async fn mtproto_strict_concurrent_replay_race_condition() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_alpn_zero_length_protocol_handled_safely() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret = [0x5Bu8; 16];
|
||||
let mut config = test_config_with_secret_hex("5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b");
|
||||
@@ -461,8 +460,8 @@ async fn tls_alpn_zero_length_protocol_handled_safely() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_sni_massive_hostname_does_not_panic() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret = [0x6Cu8; 16];
|
||||
let config = test_config_with_secret_hex("6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c");
|
||||
@@ -497,8 +496,8 @@ async fn tls_sni_massive_hostname_does_not_panic() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_progressive_truncation_fuzzing_no_panics() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret = [0x7Du8; 16];
|
||||
let config = test_config_with_secret_hex("7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d");
|
||||
@@ -535,8 +534,8 @@ async fn tls_progressive_truncation_fuzzing_no_panics() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn mtproto_pure_entropy_fuzzing_no_panics() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let config = test_config_with_secret_hex("8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e");
|
||||
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
|
||||
@@ -569,10 +568,8 @@ async fn mtproto_pure_entropy_fuzzing_no_panics() {
|
||||
|
||||
#[test]
|
||||
fn decode_user_secret_odd_length_hex_rejection() {
|
||||
let _guard = warned_secrets_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
clear_warned_secrets_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_warned_secrets_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let mut config = ProxyConfig::default();
|
||||
config.access.users.clear();
|
||||
@@ -581,7 +578,7 @@ fn decode_user_secret_odd_length_hex_rejection() {
|
||||
"1234567890123456789012345678901".to_string(),
|
||||
);
|
||||
|
||||
let decoded = decode_user_secrets(&config, None);
|
||||
let decoded = decode_user_secrets_in(shared.as_ref(), &config, None);
|
||||
assert!(
|
||||
decoded.is_empty(),
|
||||
"Odd-length hex string must be gracefully rejected by hex::decode without unwrapping"
|
||||
@@ -590,10 +587,10 @@ fn decode_user_secret_odd_length_hex_rejection() {
|
||||
|
||||
#[test]
|
||||
fn saturation_grace_pre_existing_high_fail_streak_immediate_throttle() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let state = auth_probe_state_map();
|
||||
let state = auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 112));
|
||||
let now = Instant::now();
|
||||
|
||||
@@ -608,7 +605,7 @@ fn saturation_grace_pre_existing_high_fail_streak_immediate_throttle() {
|
||||
);
|
||||
|
||||
{
|
||||
let mut guard = auth_probe_saturation_state_lock();
|
||||
let mut guard = auth_probe_saturation_state_lock_for_testing_in_shared(shared.as_ref());
|
||||
*guard = Some(AuthProbeSaturationState {
|
||||
fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
|
||||
blocked_until: now + Duration::from_secs(5),
|
||||
@@ -616,7 +613,7 @@ fn saturation_grace_pre_existing_high_fail_streak_immediate_throttle() {
|
||||
});
|
||||
}
|
||||
|
||||
let is_throttled = auth_probe_should_apply_preauth_throttle(peer_ip, now);
|
||||
let is_throttled = auth_probe_should_apply_preauth_throttle_in(shared.as_ref(), peer_ip, now);
|
||||
assert!(
|
||||
is_throttled,
|
||||
"A peer with a pre-existing high fail streak must be immediately throttled when saturation begins, receiving no unearned grace period"
|
||||
@@ -625,21 +622,22 @@ fn saturation_grace_pre_existing_high_fail_streak_immediate_throttle() {
|
||||
|
||||
#[test]
|
||||
fn auth_probe_saturation_note_resets_retention_window() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let base_time = Instant::now();
|
||||
|
||||
auth_probe_note_saturation(base_time);
|
||||
auth_probe_note_saturation_in(shared.as_ref(), base_time);
|
||||
let later = base_time + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS - 1);
|
||||
auth_probe_note_saturation(later);
|
||||
auth_probe_note_saturation_in(shared.as_ref(), later);
|
||||
|
||||
let check_time = base_time + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 5);
|
||||
|
||||
// This call may return false if backoff has elapsed, but it must not clear
|
||||
// the saturation state because `later` refreshed last_seen.
|
||||
let _ = auth_probe_saturation_is_throttled_at_for_testing(check_time);
|
||||
let guard = auth_probe_saturation_state_lock();
|
||||
let _ =
|
||||
auth_probe_saturation_is_throttled_at_for_testing_in_shared(shared.as_ref(), check_time);
|
||||
let guard = auth_probe_saturation_state_lock_for_testing_in_shared(shared.as_ref());
|
||||
assert!(
|
||||
guard.is_some(),
|
||||
"Ongoing saturation notes must refresh last_seen so saturation state remains retained past the original window"
|
||||
|
||||
@@ -6,12 +6,6 @@ use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::Barrier;
|
||||
|
||||
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
|
||||
auth_probe_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.access.users.clear();
|
||||
@@ -127,8 +121,8 @@ fn make_valid_mtproto_handshake(
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_alpn_reject_does_not_pollute_replay_cache() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let secret = [0x11u8; 16];
|
||||
let mut config = test_config_with_secret_hex("11111111111111111111111111111111");
|
||||
@@ -164,8 +158,8 @@ async fn tls_alpn_reject_does_not_pollute_replay_cache() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_truncated_session_id_len_fails_closed_without_panic() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let config = test_config_with_secret_hex("33333333333333333333333333333333");
|
||||
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
|
||||
@@ -193,10 +187,10 @@ async fn tls_truncated_session_id_len_fails_closed_without_panic() {
|
||||
|
||||
#[test]
|
||||
fn auth_probe_eviction_identical_timestamps_keeps_map_bounded() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let state = auth_probe_state_map();
|
||||
let state = auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
let same = Instant::now();
|
||||
|
||||
for i in 0..AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||
@@ -212,7 +206,12 @@ fn auth_probe_eviction_identical_timestamps_keeps_map_bounded() {
|
||||
}
|
||||
|
||||
let new_ip = IpAddr::V4(Ipv4Addr::new(192, 168, 21, 21));
|
||||
auth_probe_record_failure_with_state(state, new_ip, same + Duration::from_millis(1));
|
||||
auth_probe_record_failure_with_state_in(
|
||||
shared.as_ref(),
|
||||
state,
|
||||
new_ip,
|
||||
same + Duration::from_millis(1),
|
||||
);
|
||||
|
||||
assert_eq!(state.len(), AUTH_PROBE_TRACK_MAX_ENTRIES);
|
||||
assert!(state.contains_key(&new_ip));
|
||||
@@ -220,21 +219,21 @@ fn auth_probe_eviction_identical_timestamps_keeps_map_bounded() {
|
||||
|
||||
#[test]
|
||||
fn clear_auth_probe_state_recovers_from_poisoned_saturation_lock() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let saturation = auth_probe_saturation_state();
|
||||
let shared_for_poison = shared.clone();
|
||||
let poison_thread = std::thread::spawn(move || {
|
||||
let _hold = saturation
|
||||
let _hold = auth_probe_saturation_state_for_testing_in_shared(shared_for_poison.as_ref())
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
panic!("intentional poison for regression coverage");
|
||||
});
|
||||
let _ = poison_thread.join();
|
||||
|
||||
clear_auth_probe_state_for_testing();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let guard = auth_probe_saturation_state()
|
||||
let guard = auth_probe_saturation_state_for_testing_in_shared(shared.as_ref())
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
assert!(guard.is_none());
|
||||
@@ -242,12 +241,9 @@ fn clear_auth_probe_state_recovers_from_poisoned_saturation_lock() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn mtproto_invalid_length_secret_is_ignored_and_valid_user_still_auths() {
|
||||
let _probe_guard = auth_probe_test_guard();
|
||||
let _warn_guard = warned_secrets_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
clear_auth_probe_state_for_testing();
|
||||
clear_warned_secrets_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
clear_warned_secrets_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let mut config = ProxyConfig::default();
|
||||
config.general.modes.secure = true;
|
||||
@@ -285,14 +281,14 @@ async fn mtproto_invalid_length_secret_is_ignored_and_valid_user_still_auths() {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn saturation_grace_exhaustion_under_concurrency_keeps_peer_throttled() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 80));
|
||||
let now = Instant::now();
|
||||
|
||||
{
|
||||
let mut guard = auth_probe_saturation_state()
|
||||
let mut guard = auth_probe_saturation_state_for_testing_in_shared(shared.as_ref())
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
*guard = Some(AuthProbeSaturationState {
|
||||
@@ -302,7 +298,7 @@ async fn saturation_grace_exhaustion_under_concurrency_keeps_peer_throttled() {
|
||||
});
|
||||
}
|
||||
|
||||
let state = auth_probe_state_map();
|
||||
let state = auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
state.insert(
|
||||
peer_ip,
|
||||
AuthProbeState {
|
||||
@@ -318,9 +314,10 @@ async fn saturation_grace_exhaustion_under_concurrency_keeps_peer_throttled() {
|
||||
|
||||
for _ in 0..tasks {
|
||||
let b = barrier.clone();
|
||||
let shared = shared.clone();
|
||||
handles.push(tokio::spawn(async move {
|
||||
b.wait().await;
|
||||
auth_probe_record_failure(peer_ip, Instant::now());
|
||||
auth_probe_record_failure_in(shared.as_ref(), peer_ip, Instant::now());
|
||||
}));
|
||||
}
|
||||
|
||||
@@ -333,7 +330,8 @@ async fn saturation_grace_exhaustion_under_concurrency_keeps_peer_throttled() {
|
||||
final_state.fail_streak
|
||||
>= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
|
||||
);
|
||||
assert!(auth_probe_should_apply_preauth_throttle(
|
||||
assert!(auth_probe_should_apply_preauth_throttle_in(
|
||||
shared.as_ref(),
|
||||
peer_ip,
|
||||
Instant::now()
|
||||
));
|
||||
|
||||
@@ -1,46 +1,39 @@
|
||||
use super::*;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
|
||||
auth_probe_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
fn poison_saturation_mutex() {
|
||||
let saturation = auth_probe_saturation_state();
|
||||
let poison_thread = std::thread::spawn(move || {
|
||||
fn poison_saturation_mutex(shared: &ProxySharedState) {
|
||||
let saturation = auth_probe_saturation_state_for_testing_in_shared(shared);
|
||||
let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
let _guard = saturation
|
||||
.lock()
|
||||
.expect("saturation mutex must be lockable for poison setup");
|
||||
panic!("intentional poison for saturation mutex resilience test");
|
||||
});
|
||||
let _ = poison_thread.join();
|
||||
}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_probe_saturation_note_recovers_after_mutex_poison() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
poison_saturation_mutex();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
poison_saturation_mutex(shared.as_ref());
|
||||
|
||||
let now = Instant::now();
|
||||
auth_probe_note_saturation(now);
|
||||
auth_probe_note_saturation_in(shared.as_ref(), now);
|
||||
|
||||
assert!(
|
||||
auth_probe_saturation_is_throttled_at_for_testing(now),
|
||||
auth_probe_saturation_is_throttled_at_for_testing_in_shared(shared.as_ref(), now),
|
||||
"poisoned saturation mutex must not disable saturation throttling"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auth_probe_saturation_check_recovers_after_mutex_poison() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
poison_saturation_mutex();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
poison_saturation_mutex(shared.as_ref());
|
||||
|
||||
{
|
||||
let mut guard = auth_probe_saturation_state_lock();
|
||||
let mut guard = auth_probe_saturation_state_lock_for_testing_in_shared(shared.as_ref());
|
||||
*guard = Some(AuthProbeSaturationState {
|
||||
fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
|
||||
blocked_until: Instant::now() + Duration::from_millis(10),
|
||||
@@ -49,23 +42,25 @@ fn auth_probe_saturation_check_recovers_after_mutex_poison() {
|
||||
}
|
||||
|
||||
assert!(
|
||||
auth_probe_saturation_is_throttled_for_testing(),
|
||||
auth_probe_saturation_is_throttled_for_testing_in_shared(shared.as_ref()),
|
||||
"throttle check must recover poisoned saturation mutex and stay fail-closed"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clear_auth_probe_state_clears_saturation_even_if_poisoned() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
poison_saturation_mutex();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
poison_saturation_mutex(shared.as_ref());
|
||||
|
||||
auth_probe_note_saturation(Instant::now());
|
||||
assert!(auth_probe_saturation_is_throttled_for_testing());
|
||||
auth_probe_note_saturation_in(shared.as_ref(), Instant::now());
|
||||
assert!(auth_probe_saturation_is_throttled_for_testing_in_shared(
|
||||
shared.as_ref()
|
||||
));
|
||||
|
||||
clear_auth_probe_state_for_testing();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
assert!(
|
||||
!auth_probe_saturation_is_throttled_for_testing(),
|
||||
!auth_probe_saturation_is_throttled_for_testing_in_shared(shared.as_ref()),
|
||||
"clear helper must clear saturation state even after poison"
|
||||
);
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,12 +4,6 @@ use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION};
|
||||
use std::net::SocketAddr;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
|
||||
auth_probe_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
fn make_valid_mtproto_handshake(
|
||||
secret_hex: &str,
|
||||
proto_tag: ProtoTag,
|
||||
@@ -149,8 +143,8 @@ fn median_ns(samples: &mut [u128]) -> u128 {
|
||||
#[tokio::test]
|
||||
#[ignore = "manual benchmark: timing-sensitive and host-dependent"]
|
||||
async fn mtproto_user_scan_timing_manual_benchmark() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
clear_auth_probe_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
const DECOY_USERS: usize = 8_000;
|
||||
const ITERATIONS: usize = 250;
|
||||
@@ -243,7 +237,7 @@ async fn mtproto_user_scan_timing_manual_benchmark() {
|
||||
#[tokio::test]
|
||||
#[ignore = "manual benchmark: timing-sensitive and host-dependent"]
|
||||
async fn tls_sni_preferred_vs_no_sni_fallback_manual_benchmark() {
|
||||
let _guard = auth_probe_test_guard();
|
||||
let shared = ProxySharedState::new();
|
||||
|
||||
const DECOY_USERS: usize = 8_000;
|
||||
const ITERATIONS: usize = 250;
|
||||
@@ -281,7 +275,7 @@ async fn tls_sni_preferred_vs_no_sni_fallback_manual_benchmark() {
|
||||
let no_sni = make_valid_tls_handshake(&target_secret, (i as u32).wrapping_add(10_000));
|
||||
|
||||
let started_sni = Instant::now();
|
||||
let sni_secrets = decode_user_secrets(&config, Some(preferred_user));
|
||||
let sni_secrets = decode_user_secrets_in(shared.as_ref(), &config, Some(preferred_user));
|
||||
let sni_result = tls::validate_tls_handshake_with_replay_window(
|
||||
&with_sni,
|
||||
&sni_secrets,
|
||||
@@ -292,7 +286,7 @@ async fn tls_sni_preferred_vs_no_sni_fallback_manual_benchmark() {
|
||||
assert!(sni_result.is_some());
|
||||
|
||||
let started_no_sni = Instant::now();
|
||||
let no_sni_secrets = decode_user_secrets(&config, None);
|
||||
let no_sni_secrets = decode_user_secrets_in(shared.as_ref(), &config, None);
|
||||
let no_sni_result = tls::validate_tls_handshake_with_replay_window(
|
||||
&no_sni,
|
||||
&no_sni_secrets,
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
use super::*;
|
||||
use tokio::io::duplex;
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
|
||||
#[test]
|
||||
fn masking_baseline_timing_normalization_budget_within_bounds() {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.censorship.mask_timing_normalization_enabled = true;
|
||||
config.censorship.mask_timing_normalization_floor_ms = 120;
|
||||
config.censorship.mask_timing_normalization_ceiling_ms = 180;
|
||||
|
||||
for _ in 0..256 {
|
||||
let budget = mask_outcome_target_budget(&config);
|
||||
assert!(budget >= Duration::from_millis(120));
|
||||
assert!(budget <= Duration::from_millis(180));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn masking_baseline_fallback_relays_to_mask_host() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let backend_addr = listener.local_addr().unwrap();
|
||||
let initial = b"GET /baseline HTTP/1.1\r\nHost: x\r\n\r\n".to_vec();
|
||||
let reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
|
||||
|
||||
let accept_task = tokio::spawn({
|
||||
let initial = initial.clone();
|
||||
let reply = reply.clone();
|
||||
async move {
|
||||
let (mut stream, _) = listener.accept().await.unwrap();
|
||||
let mut seen = vec![0u8; initial.len()];
|
||||
stream.read_exact(&mut seen).await.unwrap();
|
||||
assert_eq!(seen, initial);
|
||||
stream.write_all(&reply).await.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
let mut config = ProxyConfig::default();
|
||||
config.general.beobachten = false;
|
||||
config.censorship.mask = true;
|
||||
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
config.censorship.mask_port = backend_addr.port();
|
||||
config.censorship.mask_unix_sock = None;
|
||||
config.censorship.mask_proxy_protocol = 0;
|
||||
|
||||
let peer: SocketAddr = "203.0.113.70:55070".parse().unwrap();
|
||||
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
|
||||
let (client_reader, _client_writer) = duplex(1024);
|
||||
let (mut visible_reader, visible_writer) = duplex(2048);
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
||||
handle_bad_client(
|
||||
client_reader,
|
||||
visible_writer,
|
||||
&initial,
|
||||
peer,
|
||||
local_addr,
|
||||
&config,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut observed = vec![0u8; reply.len()];
|
||||
visible_reader.read_exact(&mut observed).await.unwrap();
|
||||
assert_eq!(observed, reply);
|
||||
accept_task.await.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn masking_baseline_no_normalization_returns_default_budget() {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.censorship.mask_timing_normalization_enabled = false;
|
||||
let budget = mask_outcome_target_budget(&config);
|
||||
assert_eq!(budget, MASK_TIMEOUT);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn masking_baseline_unreachable_mask_host_silent_failure() {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.general.beobachten = false;
|
||||
config.censorship.mask = true;
|
||||
config.censorship.mask_unix_sock = None;
|
||||
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
config.censorship.mask_port = 1;
|
||||
config.censorship.mask_timing_normalization_enabled = false;
|
||||
|
||||
let peer: SocketAddr = "203.0.113.71:55071".parse().unwrap();
|
||||
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
||||
let (client_reader, _client_writer) = duplex(1024);
|
||||
let (mut visible_reader, visible_writer) = duplex(1024);
|
||||
|
||||
let started = Instant::now();
|
||||
handle_bad_client(
|
||||
client_reader,
|
||||
visible_writer,
|
||||
b"GET / HTTP/1.1\r\n\r\n",
|
||||
peer,
|
||||
local_addr,
|
||||
&config,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
let elapsed = started.elapsed();
|
||||
|
||||
assert!(elapsed < Duration::from_secs(1));
|
||||
|
||||
let mut buf = [0u8; 1];
|
||||
let read_res = timeout(Duration::from_millis(50), visible_reader.read(&mut buf)).await;
|
||||
match read_res {
|
||||
Ok(Ok(0)) | Err(_) => {}
|
||||
Ok(Ok(n)) => panic!("expected no response bytes, got {n}"),
|
||||
Ok(Err(e)) => panic!("unexpected client-side read error: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn masking_baseline_light_fuzz_initial_data_no_panic() {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.general.beobachten = false;
|
||||
config.censorship.mask = false;
|
||||
|
||||
let peer: SocketAddr = "203.0.113.72:55072".parse().unwrap();
|
||||
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
let beobachten = BeobachtenStore::new();
|
||||
|
||||
let corpus: Vec<Vec<u8>> = vec![
|
||||
vec![],
|
||||
vec![0x00],
|
||||
vec![0xFF; 1024],
|
||||
(0..255u8).collect(),
|
||||
b"\xF0\x28\x8C\x28".to_vec(),
|
||||
];
|
||||
|
||||
for sample in corpus {
|
||||
let (client_reader, _client_writer) = duplex(1024);
|
||||
let (_visible_reader, visible_writer) = duplex(1024);
|
||||
timeout(
|
||||
Duration::from_millis(300),
|
||||
handle_bad_client(
|
||||
client_reader,
|
||||
visible_writer,
|
||||
&sample,
|
||||
peer,
|
||||
local_addr,
|
||||
&config,
|
||||
&beobachten,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("fuzz sample must complete in bounded time");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,336 @@
|
||||
use super::*;
|
||||
use rand::SeedableRng;
|
||||
use rand::rngs::StdRng;
|
||||
|
||||
fn seeded_rng(seed: u64) -> StdRng {
|
||||
StdRng::seed_from_u64(seed)
|
||||
}
|
||||
|
||||
// ── Positive: all samples within configured envelope ────────────────────
|
||||
|
||||
#[test]
|
||||
fn masking_lognormal_all_samples_within_configured_envelope() {
|
||||
let mut rng = seeded_rng(42);
|
||||
let floor: u64 = 500;
|
||||
let ceiling: u64 = 2000;
|
||||
for _ in 0..10_000 {
|
||||
let val = sample_lognormal_percentile_bounded(floor, ceiling, &mut rng);
|
||||
assert!(
|
||||
val >= floor && val <= ceiling,
|
||||
"sample {} outside [{}, {}]",
|
||||
val,
|
||||
floor,
|
||||
ceiling,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Statistical: median near geometric mean ─────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn masking_lognormal_sample_median_near_geometric_mean_of_range() {
|
||||
let mut rng = seeded_rng(42);
|
||||
let floor: u64 = 500;
|
||||
let ceiling: u64 = 2000;
|
||||
let geometric_mean = ((floor as f64) * (ceiling as f64)).sqrt();
|
||||
|
||||
let mut samples: Vec<u64> = (0..10_000)
|
||||
.map(|_| sample_lognormal_percentile_bounded(floor, ceiling, &mut rng))
|
||||
.collect();
|
||||
samples.sort();
|
||||
let median = samples[samples.len() / 2] as f64;
|
||||
|
||||
let tolerance = geometric_mean * 0.10;
|
||||
assert!(
|
||||
(median - geometric_mean).abs() <= tolerance,
|
||||
"median {} not within 10% of geometric mean {} (tolerance {})",
|
||||
median,
|
||||
geometric_mean,
|
||||
tolerance,
|
||||
);
|
||||
}
|
||||
|
||||
// ── Edge: degenerate floor == ceiling returns exactly that value ─────────
|
||||
|
||||
#[test]
|
||||
fn masking_lognormal_degenerate_floor_eq_ceiling_returns_floor() {
|
||||
let mut rng = seeded_rng(99);
|
||||
for _ in 0..100 {
|
||||
let val = sample_lognormal_percentile_bounded(1000, 1000, &mut rng);
|
||||
assert_eq!(
|
||||
val, 1000,
|
||||
"floor == ceiling must always return exactly that value"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Edge: floor > ceiling (misconfiguration) clamps safely ──────────────
|
||||
|
||||
#[test]
|
||||
fn masking_lognormal_floor_greater_than_ceiling_returns_ceiling() {
|
||||
let mut rng = seeded_rng(77);
|
||||
let val = sample_lognormal_percentile_bounded(2000, 500, &mut rng);
|
||||
assert_eq!(
|
||||
val, 500,
|
||||
"floor > ceiling misconfiguration must return ceiling (the minimum)"
|
||||
);
|
||||
}
|
||||
|
||||
// ── Edge: floor == 1, ceiling == 1 ──────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn masking_lognormal_floor_1_ceiling_1_returns_1() {
|
||||
let mut rng = seeded_rng(12);
|
||||
let val = sample_lognormal_percentile_bounded(1, 1, &mut rng);
|
||||
assert_eq!(val, 1);
|
||||
}
|
||||
|
||||
// ── Edge: floor == 1, ceiling very large ────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn masking_lognormal_wide_range_all_samples_within_bounds() {
|
||||
let mut rng = seeded_rng(55);
|
||||
let floor: u64 = 1;
|
||||
let ceiling: u64 = 100_000;
|
||||
for _ in 0..10_000 {
|
||||
let val = sample_lognormal_percentile_bounded(floor, ceiling, &mut rng);
|
||||
assert!(
|
||||
val >= floor && val <= ceiling,
|
||||
"sample {} outside [{}, {}]",
|
||||
val,
|
||||
floor,
|
||||
ceiling,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Adversarial: extreme sigma (floor very close to ceiling) ────────────
|
||||
|
||||
#[test]
|
||||
fn masking_lognormal_narrow_range_does_not_panic() {
|
||||
let mut rng = seeded_rng(88);
|
||||
let floor: u64 = 999;
|
||||
let ceiling: u64 = 1001;
|
||||
for _ in 0..10_000 {
|
||||
let val = sample_lognormal_percentile_bounded(floor, ceiling, &mut rng);
|
||||
assert!(
|
||||
val >= floor && val <= ceiling,
|
||||
"narrow range sample {} outside [{}, {}]",
|
||||
val,
|
||||
floor,
|
||||
ceiling,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Adversarial: u64::MAX ceiling does not overflow ──────────────────────
|
||||
|
||||
#[test]
|
||||
fn masking_lognormal_u64_max_ceiling_no_overflow() {
|
||||
let mut rng = seeded_rng(123);
|
||||
let floor: u64 = 1;
|
||||
let ceiling: u64 = u64::MAX;
|
||||
for _ in 0..1000 {
|
||||
let val = sample_lognormal_percentile_bounded(floor, ceiling, &mut rng);
|
||||
assert!(val >= floor, "sample {} below floor {}", val, floor);
|
||||
// u64::MAX clamp ensures no overflow
|
||||
}
|
||||
}
|
||||
|
||||
// ── Adversarial: floor == 0 guard ───────────────────────────────────────
|
||||
// The function should handle floor=0 gracefully even though callers
|
||||
// should never pass it. Verifies no panic on ln(0).
|
||||
|
||||
#[test]
|
||||
fn masking_lognormal_floor_zero_no_panic() {
|
||||
let mut rng = seeded_rng(200);
|
||||
let val = sample_lognormal_percentile_bounded(0, 1000, &mut rng);
|
||||
assert!(val <= 1000, "sample {} exceeds ceiling 1000", val);
|
||||
}
|
||||
|
||||
// ── Adversarial: both zero → returns 0 ──────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn masking_lognormal_both_zero_returns_zero() {
|
||||
let mut rng = seeded_rng(201);
|
||||
let val = sample_lognormal_percentile_bounded(0, 0, &mut rng);
|
||||
assert_eq!(val, 0, "floor=0 ceiling=0 must return 0");
|
||||
}
|
||||
|
||||
// ── Distribution shape: not uniform ─────────────────────────────────────
|
||||
// A DPI classifier trained on uniform delay samples should detect a
|
||||
// distribution where > 60% of samples fall in the lower half of the range.
|
||||
// Log-normal is right-skewed: more samples near floor than ceiling.
|
||||
|
||||
#[test]
|
||||
fn masking_lognormal_distribution_is_right_skewed() {
|
||||
let mut rng = seeded_rng(42);
|
||||
let floor: u64 = 100;
|
||||
let ceiling: u64 = 5000;
|
||||
let midpoint = (floor + ceiling) / 2;
|
||||
|
||||
let samples: Vec<u64> = (0..10_000)
|
||||
.map(|_| sample_lognormal_percentile_bounded(floor, ceiling, &mut rng))
|
||||
.collect();
|
||||
|
||||
let below_mid = samples.iter().filter(|&&s| s < midpoint).count();
|
||||
let ratio = below_mid as f64 / samples.len() as f64;
|
||||
|
||||
assert!(
|
||||
ratio > 0.55,
|
||||
"Log-normal should be right-skewed (>55% below midpoint), got {}%",
|
||||
ratio * 100.0,
|
||||
);
|
||||
}
|
||||
|
||||
// ── Determinism: same seed produces same sequence ───────────────────────
|
||||
|
||||
#[test]
|
||||
fn masking_lognormal_deterministic_with_same_seed() {
|
||||
let mut rng1 = seeded_rng(42);
|
||||
let mut rng2 = seeded_rng(42);
|
||||
for _ in 0..100 {
|
||||
let a = sample_lognormal_percentile_bounded(500, 2000, &mut rng1);
|
||||
let b = sample_lognormal_percentile_bounded(500, 2000, &mut rng2);
|
||||
assert_eq!(a, b, "Same seed must produce same output");
|
||||
}
|
||||
}
|
||||
|
||||
// ── Fuzz: 1000 random (floor, ceiling) pairs, no panics ─────────────────
|
||||
|
||||
#[test]
|
||||
fn masking_lognormal_fuzz_random_params_no_panic() {
|
||||
use rand::Rng;
|
||||
let mut rng = seeded_rng(999);
|
||||
for _ in 0..1000 {
|
||||
let a: u64 = rng.random_range(0..=10_000);
|
||||
let b: u64 = rng.random_range(0..=10_000);
|
||||
let floor = a.min(b);
|
||||
let ceiling = a.max(b);
|
||||
let val = sample_lognormal_percentile_bounded(floor, ceiling, &mut rng);
|
||||
assert!(
|
||||
val >= floor && val <= ceiling,
|
||||
"fuzz: sample {} outside [{}, {}]",
|
||||
val,
|
||||
floor,
|
||||
ceiling,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Fuzz: adversarial floor > ceiling pairs ──────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn masking_lognormal_fuzz_inverted_params_no_panic() {
|
||||
use rand::Rng;
|
||||
let mut rng = seeded_rng(777);
|
||||
for _ in 0..500 {
|
||||
let floor: u64 = rng.random_range(1..=10_000);
|
||||
let ceiling: u64 = rng.random_range(0..floor);
|
||||
// When floor > ceiling, must return ceiling (the smaller value)
|
||||
let val = sample_lognormal_percentile_bounded(floor, ceiling, &mut rng);
|
||||
assert_eq!(
|
||||
val, ceiling,
|
||||
"inverted: floor={} ceiling={} should return ceiling, got {}",
|
||||
floor, ceiling, val,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Security: clamp spike check ─────────────────────────────────────────
|
||||
// With well-parameterized sigma, no more than 5% of samples should be
|
||||
// at exactly floor or exactly ceiling (clamp spikes). A spike > 10%
|
||||
// is detectable by DPI as bimodal.
|
||||
|
||||
#[test]
|
||||
fn masking_lognormal_no_clamp_spike_at_boundaries() {
|
||||
let mut rng = seeded_rng(42);
|
||||
let floor: u64 = 500;
|
||||
let ceiling: u64 = 2000;
|
||||
let n = 10_000;
|
||||
let samples: Vec<u64> = (0..n)
|
||||
.map(|_| sample_lognormal_percentile_bounded(floor, ceiling, &mut rng))
|
||||
.collect();
|
||||
|
||||
let at_floor = samples.iter().filter(|&&s| s == floor).count();
|
||||
let at_ceiling = samples.iter().filter(|&&s| s == ceiling).count();
|
||||
let floor_pct = at_floor as f64 / n as f64;
|
||||
let ceiling_pct = at_ceiling as f64 / n as f64;
|
||||
|
||||
assert!(
|
||||
floor_pct < 0.05,
|
||||
"floor clamp spike: {}% of samples at exactly floor (max 5%)",
|
||||
floor_pct * 100.0,
|
||||
);
|
||||
assert!(
|
||||
ceiling_pct < 0.05,
|
||||
"ceiling clamp spike: {}% of samples at exactly ceiling (max 5%)",
|
||||
ceiling_pct * 100.0,
|
||||
);
|
||||
}
|
||||
|
||||
// ── Integration: mask_outcome_target_budget uses log-normal for path 3 ──
|
||||
|
||||
#[tokio::test]
|
||||
async fn masking_lognormal_integration_budget_within_bounds() {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.censorship.mask_timing_normalization_enabled = true;
|
||||
config.censorship.mask_timing_normalization_floor_ms = 500;
|
||||
config.censorship.mask_timing_normalization_ceiling_ms = 2000;
|
||||
|
||||
for _ in 0..100 {
|
||||
let budget = mask_outcome_target_budget(&config);
|
||||
let ms = budget.as_millis() as u64;
|
||||
assert!(
|
||||
ms >= 500 && ms <= 2000,
|
||||
"budget {} ms outside [500, 2000]",
|
||||
ms,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Integration: floor == 0 path stays uniform (NOT log-normal) ─────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn masking_lognormal_floor_zero_path_stays_uniform() {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.censorship.mask_timing_normalization_enabled = true;
|
||||
config.censorship.mask_timing_normalization_floor_ms = 0;
|
||||
config.censorship.mask_timing_normalization_ceiling_ms = 1000;
|
||||
|
||||
for _ in 0..100 {
|
||||
let budget = mask_outcome_target_budget(&config);
|
||||
let ms = budget.as_millis() as u64;
|
||||
// floor=0 path uses uniform [0, ceiling], not log-normal
|
||||
assert!(ms <= 1000, "budget {} ms exceeds ceiling 1000", ms);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Integration: floor > ceiling misconfiguration is safe ───────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn masking_lognormal_misconfigured_floor_gt_ceiling_safe() {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.censorship.mask_timing_normalization_enabled = true;
|
||||
config.censorship.mask_timing_normalization_floor_ms = 2000;
|
||||
config.censorship.mask_timing_normalization_ceiling_ms = 500;
|
||||
|
||||
let budget = mask_outcome_target_budget(&config);
|
||||
let ms = budget.as_millis() as u64;
|
||||
// floor > ceiling: should not exceed the minimum of the two
|
||||
assert!(
|
||||
ms <= 2000,
|
||||
"misconfigured budget {} ms should be bounded",
|
||||
ms,
|
||||
);
|
||||
}
|
||||
|
||||
// ── Stress: rapid repeated calls do not panic or starve ─────────────────
|
||||
|
||||
#[test]
|
||||
fn masking_lognormal_stress_rapid_calls_no_panic() {
|
||||
let mut rng = seeded_rng(42);
|
||||
for _ in 0..100_000 {
|
||||
let _ = sample_lognormal_percentile_bounded(100, 5000, &mut rng);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
use super::*;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
#[test]
|
||||
fn middle_relay_baseline_public_api_idle_roundtrip_contract() {
|
||||
let shared = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 7001));
|
||||
assert_eq!(
|
||||
oldest_relay_idle_candidate_for_testing(shared.as_ref()),
|
||||
Some(7001)
|
||||
);
|
||||
|
||||
clear_relay_idle_candidate_for_testing(shared.as_ref(), 7001);
|
||||
assert_ne!(
|
||||
oldest_relay_idle_candidate_for_testing(shared.as_ref()),
|
||||
Some(7001)
|
||||
);
|
||||
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 7001));
|
||||
assert_eq!(
|
||||
oldest_relay_idle_candidate_for_testing(shared.as_ref()),
|
||||
Some(7001)
|
||||
);
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn middle_relay_baseline_public_api_desync_window_contract() {
|
||||
let shared = ProxySharedState::new();
|
||||
clear_desync_dedup_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let key = 0xDEAD_BEEF_0000_0001u64;
|
||||
let t0 = Instant::now();
|
||||
|
||||
assert!(should_emit_full_desync_for_testing(
|
||||
shared.as_ref(),
|
||||
key,
|
||||
false,
|
||||
t0
|
||||
));
|
||||
assert!(!should_emit_full_desync_for_testing(
|
||||
shared.as_ref(),
|
||||
key,
|
||||
false,
|
||||
t0 + Duration::from_secs(1)
|
||||
));
|
||||
|
||||
let t1 = t0 + DESYNC_DEDUP_WINDOW + Duration::from_millis(10);
|
||||
assert!(should_emit_full_desync_for_testing(
|
||||
shared.as_ref(),
|
||||
key,
|
||||
false,
|
||||
t1
|
||||
));
|
||||
|
||||
clear_desync_dedup_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
@@ -5,22 +5,25 @@ use std::thread;
|
||||
|
||||
#[test]
|
||||
fn desync_all_full_bypass_does_not_initialize_or_grow_dedup_cache() {
|
||||
let _guard = desync_dedup_test_lock()
|
||||
.lock()
|
||||
.expect("desync dedup test lock must be available");
|
||||
clear_desync_dedup_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_desync_dedup_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let initial_len = DESYNC_DEDUP.get().map(|dedup| dedup.len()).unwrap_or(0);
|
||||
let initial_len = desync_dedup_len_for_testing(shared.as_ref());
|
||||
let now = Instant::now();
|
||||
|
||||
for i in 0..20_000u64 {
|
||||
assert!(
|
||||
should_emit_full_desync(0xD35E_D000_0000_0000u64 ^ i, true, now),
|
||||
should_emit_full_desync_for_testing(
|
||||
shared.as_ref(),
|
||||
0xD35E_D000_0000_0000u64 ^ i,
|
||||
true,
|
||||
now
|
||||
),
|
||||
"desync_all_full path must always emit"
|
||||
);
|
||||
}
|
||||
|
||||
let after_len = DESYNC_DEDUP.get().map(|dedup| dedup.len()).unwrap_or(0);
|
||||
let after_len = desync_dedup_len_for_testing(shared.as_ref());
|
||||
assert_eq!(
|
||||
after_len, initial_len,
|
||||
"desync_all_full bypass must not allocate or accumulate dedup entries"
|
||||
@@ -29,39 +32,39 @@ fn desync_all_full_bypass_does_not_initialize_or_grow_dedup_cache() {
|
||||
|
||||
#[test]
|
||||
fn desync_all_full_bypass_keeps_existing_dedup_entries_unchanged() {
|
||||
let _guard = desync_dedup_test_lock()
|
||||
.lock()
|
||||
.expect("desync dedup test lock must be available");
|
||||
clear_desync_dedup_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_desync_dedup_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
|
||||
let seed_time = Instant::now() - Duration::from_secs(7);
|
||||
dedup.insert(0xAAAABBBBCCCCDDDD, seed_time);
|
||||
dedup.insert(0x1111222233334444, seed_time);
|
||||
desync_dedup_insert_for_testing(shared.as_ref(), 0xAAAABBBBCCCCDDDD, seed_time);
|
||||
desync_dedup_insert_for_testing(shared.as_ref(), 0x1111222233334444, seed_time);
|
||||
|
||||
let now = Instant::now();
|
||||
for i in 0..2048u64 {
|
||||
assert!(
|
||||
should_emit_full_desync(0xF011_F000_0000_0000u64 ^ i, true, now),
|
||||
should_emit_full_desync_for_testing(
|
||||
shared.as_ref(),
|
||||
0xF011_F000_0000_0000u64 ^ i,
|
||||
true,
|
||||
now
|
||||
),
|
||||
"desync_all_full must bypass suppression and dedup refresh"
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
dedup.len(),
|
||||
desync_dedup_len_for_testing(shared.as_ref()),
|
||||
2,
|
||||
"bypass path must not mutate dedup cardinality"
|
||||
);
|
||||
assert_eq!(
|
||||
*dedup
|
||||
.get(&0xAAAABBBBCCCCDDDD)
|
||||
desync_dedup_get_for_testing(shared.as_ref(), 0xAAAABBBBCCCCDDDD)
|
||||
.expect("seed key must remain"),
|
||||
seed_time,
|
||||
"bypass path must not refresh existing dedup timestamps"
|
||||
);
|
||||
assert_eq!(
|
||||
*dedup
|
||||
.get(&0x1111222233334444)
|
||||
desync_dedup_get_for_testing(shared.as_ref(), 0x1111222233334444)
|
||||
.expect("seed key must remain"),
|
||||
seed_time,
|
||||
"bypass path must not touch unrelated dedup entries"
|
||||
@@ -70,14 +73,13 @@ fn desync_all_full_bypass_keeps_existing_dedup_entries_unchanged() {
|
||||
|
||||
#[test]
|
||||
fn edge_all_full_burst_does_not_poison_later_false_path_tracking() {
|
||||
let _guard = desync_dedup_test_lock()
|
||||
.lock()
|
||||
.expect("desync dedup test lock must be available");
|
||||
clear_desync_dedup_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_desync_dedup_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let now = Instant::now();
|
||||
for i in 0..8192u64 {
|
||||
assert!(should_emit_full_desync(
|
||||
assert!(should_emit_full_desync_for_testing(
|
||||
shared.as_ref(),
|
||||
0xABCD_0000_0000_0000 ^ i,
|
||||
true,
|
||||
now
|
||||
@@ -86,26 +88,20 @@ fn edge_all_full_burst_does_not_poison_later_false_path_tracking() {
|
||||
|
||||
let tracked_key = 0xDEAD_BEEF_0000_0001u64;
|
||||
assert!(
|
||||
should_emit_full_desync(tracked_key, false, now),
|
||||
should_emit_full_desync_for_testing(shared.as_ref(), tracked_key, false, now),
|
||||
"first false-path event after all_full burst must still be tracked and emitted"
|
||||
);
|
||||
|
||||
let dedup = DESYNC_DEDUP
|
||||
.get()
|
||||
.expect("false path should initialize dedup");
|
||||
assert!(dedup.get(&tracked_key).is_some());
|
||||
assert!(desync_dedup_get_for_testing(shared.as_ref(), tracked_key).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adversarial_mixed_sequence_true_steps_never_change_cache_len() {
|
||||
let _guard = desync_dedup_test_lock()
|
||||
.lock()
|
||||
.expect("desync dedup test lock must be available");
|
||||
clear_desync_dedup_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_desync_dedup_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
|
||||
for i in 0..256u64 {
|
||||
dedup.insert(0x1000_0000_0000_0000 ^ i, Instant::now());
|
||||
desync_dedup_insert_for_testing(shared.as_ref(), 0x1000_0000_0000_0000 ^ i, Instant::now());
|
||||
}
|
||||
|
||||
let mut seed = 0xC0DE_CAFE_BAAD_F00Du64;
|
||||
@@ -116,9 +112,14 @@ fn adversarial_mixed_sequence_true_steps_never_change_cache_len() {
|
||||
|
||||
let flag_all_full = (seed & 0x1) == 1;
|
||||
let key = 0x7000_0000_0000_0000u64 ^ i ^ seed;
|
||||
let before = dedup.len();
|
||||
let _ = should_emit_full_desync(key, flag_all_full, Instant::now());
|
||||
let after = dedup.len();
|
||||
let before = desync_dedup_len_for_testing(shared.as_ref());
|
||||
let _ = should_emit_full_desync_for_testing(
|
||||
shared.as_ref(),
|
||||
key,
|
||||
flag_all_full,
|
||||
Instant::now(),
|
||||
);
|
||||
let after = desync_dedup_len_for_testing(shared.as_ref());
|
||||
|
||||
if flag_all_full {
|
||||
assert_eq!(after, before, "all_full step must not mutate dedup length");
|
||||
@@ -128,50 +129,51 @@ fn adversarial_mixed_sequence_true_steps_never_change_cache_len() {
|
||||
|
||||
#[test]
|
||||
fn light_fuzz_all_full_mode_always_emits_and_stays_bounded() {
|
||||
let _guard = desync_dedup_test_lock()
|
||||
.lock()
|
||||
.expect("desync dedup test lock must be available");
|
||||
clear_desync_dedup_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_desync_dedup_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let mut seed = 0x1234_5678_9ABC_DEF0u64;
|
||||
let before = DESYNC_DEDUP.get().map(|d| d.len()).unwrap_or(0);
|
||||
let before = desync_dedup_len_for_testing(shared.as_ref());
|
||||
|
||||
for _ in 0..20_000 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
let key = seed ^ 0x55AA_55AA_55AA_55AAu64;
|
||||
assert!(should_emit_full_desync(key, true, Instant::now()));
|
||||
assert!(should_emit_full_desync_for_testing(
|
||||
shared.as_ref(),
|
||||
key,
|
||||
true,
|
||||
Instant::now()
|
||||
));
|
||||
}
|
||||
|
||||
let after = DESYNC_DEDUP.get().map(|d| d.len()).unwrap_or(0);
|
||||
let after = desync_dedup_len_for_testing(shared.as_ref());
|
||||
assert_eq!(after, before);
|
||||
assert!(after <= DESYNC_DEDUP_MAX_ENTRIES);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stress_parallel_all_full_storm_does_not_grow_or_mutate_cache() {
|
||||
let _guard = desync_dedup_test_lock()
|
||||
.lock()
|
||||
.expect("desync dedup test lock must be available");
|
||||
clear_desync_dedup_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_desync_dedup_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
|
||||
let seed_time = Instant::now() - Duration::from_secs(2);
|
||||
for i in 0..1024u64 {
|
||||
dedup.insert(0x8888_0000_0000_0000 ^ i, seed_time);
|
||||
desync_dedup_insert_for_testing(shared.as_ref(), 0x8888_0000_0000_0000 ^ i, seed_time);
|
||||
}
|
||||
let before_len = dedup.len();
|
||||
let before_len = desync_dedup_len_for_testing(shared.as_ref());
|
||||
|
||||
let emits = Arc::new(AtomicUsize::new(0));
|
||||
let mut workers = Vec::new();
|
||||
for worker in 0..16u64 {
|
||||
let emits = Arc::clone(&emits);
|
||||
let shared = shared.clone();
|
||||
workers.push(thread::spawn(move || {
|
||||
let now = Instant::now();
|
||||
for i in 0..4096u64 {
|
||||
let key = 0xFACE_0000_0000_0000u64 ^ (worker << 20) ^ i;
|
||||
if should_emit_full_desync(key, true, now) {
|
||||
if should_emit_full_desync_for_testing(shared.as_ref(), key, true, now) {
|
||||
emits.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
@@ -184,7 +186,7 @@ fn stress_parallel_all_full_storm_does_not_grow_or_mutate_cache() {
|
||||
|
||||
assert_eq!(emits.load(Ordering::Relaxed), 16 * 4096);
|
||||
assert_eq!(
|
||||
dedup.len(),
|
||||
desync_dedup_len_for_testing(shared.as_ref()),
|
||||
before_len,
|
||||
"parallel all_full storm must not mutate cache len"
|
||||
);
|
||||
|
||||
@@ -360,73 +360,103 @@ async fn stress_many_idle_sessions_fail_closed_without_hang() {
|
||||
|
||||
#[test]
|
||||
fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() {
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
let stats = Stats::new();
|
||||
|
||||
assert!(mark_relay_idle_candidate(10));
|
||||
assert!(mark_relay_idle_candidate(11));
|
||||
assert_eq!(oldest_relay_idle_candidate(), Some(10));
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 10));
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 11));
|
||||
assert_eq!(
|
||||
oldest_relay_idle_candidate_for_testing(shared.as_ref()),
|
||||
Some(10)
|
||||
);
|
||||
|
||||
note_relay_pressure_event();
|
||||
note_relay_pressure_event_for_testing(shared.as_ref());
|
||||
|
||||
let mut seen_for_newer = 0u64;
|
||||
assert!(
|
||||
!maybe_evict_idle_candidate_on_pressure(11, &mut seen_for_newer, &stats),
|
||||
!maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
shared.as_ref(),
|
||||
11,
|
||||
&mut seen_for_newer,
|
||||
&stats
|
||||
),
|
||||
"newer idle candidate must not be evicted while older candidate exists"
|
||||
);
|
||||
assert_eq!(oldest_relay_idle_candidate(), Some(10));
|
||||
assert_eq!(
|
||||
oldest_relay_idle_candidate_for_testing(shared.as_ref()),
|
||||
Some(10)
|
||||
);
|
||||
|
||||
let mut seen_for_oldest = 0u64;
|
||||
assert!(
|
||||
maybe_evict_idle_candidate_on_pressure(10, &mut seen_for_oldest, &stats),
|
||||
maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
shared.as_ref(),
|
||||
10,
|
||||
&mut seen_for_oldest,
|
||||
&stats
|
||||
),
|
||||
"oldest idle candidate must be evicted first under pressure"
|
||||
);
|
||||
assert_eq!(oldest_relay_idle_candidate(), Some(11));
|
||||
assert_eq!(
|
||||
oldest_relay_idle_candidate_for_testing(shared.as_ref()),
|
||||
Some(11)
|
||||
);
|
||||
assert_eq!(stats.get_relay_pressure_evict_total(), 1);
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pressure_does_not_evict_without_new_pressure_signal() {
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
let stats = Stats::new();
|
||||
|
||||
assert!(mark_relay_idle_candidate(21));
|
||||
let mut seen = relay_pressure_event_seq();
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 21));
|
||||
let mut seen = relay_pressure_event_seq_for_testing(shared.as_ref());
|
||||
|
||||
assert!(
|
||||
!maybe_evict_idle_candidate_on_pressure(21, &mut seen, &stats),
|
||||
!maybe_evict_idle_candidate_on_pressure_for_testing(shared.as_ref(), 21, &mut seen, &stats),
|
||||
"without new pressure signal, candidate must stay"
|
||||
);
|
||||
assert_eq!(stats.get_relay_pressure_evict_total(), 0);
|
||||
assert_eq!(oldest_relay_idle_candidate(), Some(21));
|
||||
assert_eq!(
|
||||
oldest_relay_idle_candidate_for_testing(shared.as_ref()),
|
||||
Some(21)
|
||||
);
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn stress_pressure_eviction_preserves_fifo_across_many_candidates() {
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
let stats = Stats::new();
|
||||
|
||||
let mut seen_per_conn = std::collections::HashMap::new();
|
||||
for conn_id in 1000u64..1064u64 {
|
||||
assert!(mark_relay_idle_candidate(conn_id));
|
||||
assert!(mark_relay_idle_candidate_for_testing(
|
||||
shared.as_ref(),
|
||||
conn_id
|
||||
));
|
||||
seen_per_conn.insert(conn_id, 0u64);
|
||||
}
|
||||
|
||||
for expected in 1000u64..1064u64 {
|
||||
note_relay_pressure_event();
|
||||
note_relay_pressure_event_for_testing(shared.as_ref());
|
||||
|
||||
let mut seen = *seen_per_conn
|
||||
.get(&expected)
|
||||
.expect("per-conn pressure cursor must exist");
|
||||
assert!(
|
||||
maybe_evict_idle_candidate_on_pressure(expected, &mut seen, &stats),
|
||||
maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
shared.as_ref(),
|
||||
expected,
|
||||
&mut seen,
|
||||
&stats
|
||||
),
|
||||
"expected conn_id {expected} must be evicted next by deterministic FIFO ordering"
|
||||
);
|
||||
seen_per_conn.insert(expected, seen);
|
||||
@@ -436,33 +466,51 @@ fn stress_pressure_eviction_preserves_fifo_across_many_candidates() {
|
||||
} else {
|
||||
Some(expected + 1)
|
||||
};
|
||||
assert_eq!(oldest_relay_idle_candidate(), next);
|
||||
assert_eq!(
|
||||
oldest_relay_idle_candidate_for_testing(shared.as_ref()),
|
||||
next
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(stats.get_relay_pressure_evict_total(), 64);
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() {
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
let stats = Stats::new();
|
||||
|
||||
assert!(mark_relay_idle_candidate(301));
|
||||
assert!(mark_relay_idle_candidate(302));
|
||||
assert!(mark_relay_idle_candidate(303));
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 301));
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 302));
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 303));
|
||||
|
||||
let mut seen_301 = 0u64;
|
||||
let mut seen_302 = 0u64;
|
||||
let mut seen_303 = 0u64;
|
||||
|
||||
// Single pressure event should authorize at most one eviction globally.
|
||||
note_relay_pressure_event();
|
||||
note_relay_pressure_event_for_testing(shared.as_ref());
|
||||
|
||||
let evicted_301 = maybe_evict_idle_candidate_on_pressure(301, &mut seen_301, &stats);
|
||||
let evicted_302 = maybe_evict_idle_candidate_on_pressure(302, &mut seen_302, &stats);
|
||||
let evicted_303 = maybe_evict_idle_candidate_on_pressure(303, &mut seen_303, &stats);
|
||||
let evicted_301 = maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
shared.as_ref(),
|
||||
301,
|
||||
&mut seen_301,
|
||||
&stats,
|
||||
);
|
||||
let evicted_302 = maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
shared.as_ref(),
|
||||
302,
|
||||
&mut seen_302,
|
||||
&stats,
|
||||
);
|
||||
let evicted_303 = maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
shared.as_ref(),
|
||||
303,
|
||||
&mut seen_303,
|
||||
&stats,
|
||||
);
|
||||
|
||||
let evicted_total = [evicted_301, evicted_302, evicted_303]
|
||||
.iter()
|
||||
@@ -474,30 +522,40 @@ fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() {
|
||||
"single pressure event must not cascade-evict multiple idle candidates"
|
||||
);
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() {
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
let stats = Stats::new();
|
||||
|
||||
assert!(mark_relay_idle_candidate(401));
|
||||
assert!(mark_relay_idle_candidate(402));
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 401));
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 402));
|
||||
|
||||
let mut seen_oldest = 0u64;
|
||||
let mut seen_next = 0u64;
|
||||
|
||||
note_relay_pressure_event();
|
||||
note_relay_pressure_event_for_testing(shared.as_ref());
|
||||
|
||||
assert!(
|
||||
maybe_evict_idle_candidate_on_pressure(401, &mut seen_oldest, &stats),
|
||||
maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
shared.as_ref(),
|
||||
401,
|
||||
&mut seen_oldest,
|
||||
&stats
|
||||
),
|
||||
"oldest candidate must consume pressure budget first"
|
||||
);
|
||||
|
||||
assert!(
|
||||
!maybe_evict_idle_candidate_on_pressure(402, &mut seen_next, &stats),
|
||||
!maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
shared.as_ref(),
|
||||
402,
|
||||
&mut seen_next,
|
||||
&stats
|
||||
),
|
||||
"next candidate must not consume the same pressure budget"
|
||||
);
|
||||
|
||||
@@ -507,47 +565,67 @@ fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() {
|
||||
"single pressure budget must produce exactly one eviction"
|
||||
);
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() {
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
let stats = Stats::new();
|
||||
|
||||
// Pressure happened before any idle candidate existed.
|
||||
note_relay_pressure_event();
|
||||
assert!(mark_relay_idle_candidate(501));
|
||||
note_relay_pressure_event_for_testing(shared.as_ref());
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 501));
|
||||
|
||||
let mut seen = 0u64;
|
||||
assert!(
|
||||
!maybe_evict_idle_candidate_on_pressure(501, &mut seen, &stats),
|
||||
!maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
shared.as_ref(),
|
||||
501,
|
||||
&mut seen,
|
||||
&stats
|
||||
),
|
||||
"stale pressure (before soft-idle mark) must not evict newly marked candidate"
|
||||
);
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() {
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
let stats = Stats::new();
|
||||
|
||||
note_relay_pressure_event();
|
||||
assert!(mark_relay_idle_candidate(511));
|
||||
assert!(mark_relay_idle_candidate(512));
|
||||
assert!(mark_relay_idle_candidate(513));
|
||||
note_relay_pressure_event_for_testing(shared.as_ref());
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 511));
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 512));
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 513));
|
||||
|
||||
let mut seen_511 = 0u64;
|
||||
let mut seen_512 = 0u64;
|
||||
let mut seen_513 = 0u64;
|
||||
|
||||
let evicted = [
|
||||
maybe_evict_idle_candidate_on_pressure(511, &mut seen_511, &stats),
|
||||
maybe_evict_idle_candidate_on_pressure(512, &mut seen_512, &stats),
|
||||
maybe_evict_idle_candidate_on_pressure(513, &mut seen_513, &stats),
|
||||
maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
shared.as_ref(),
|
||||
511,
|
||||
&mut seen_511,
|
||||
&stats,
|
||||
),
|
||||
maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
shared.as_ref(),
|
||||
512,
|
||||
&mut seen_512,
|
||||
&stats,
|
||||
),
|
||||
maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
shared.as_ref(),
|
||||
513,
|
||||
&mut seen_513,
|
||||
&stats,
|
||||
),
|
||||
]
|
||||
.iter()
|
||||
.filter(|value| **value)
|
||||
@@ -558,111 +636,118 @@ fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() {
|
||||
"stale pressure event must not evict any candidate from a newly marked batch"
|
||||
);
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated() {
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
let stats = Stats::new();
|
||||
|
||||
note_relay_pressure_event();
|
||||
note_relay_pressure_event_for_testing(shared.as_ref());
|
||||
|
||||
// Session A observed pressure while there were no candidates.
|
||||
let mut seen_a = 0u64;
|
||||
assert!(
|
||||
!maybe_evict_idle_candidate_on_pressure(999_001, &mut seen_a, &stats),
|
||||
!maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
shared.as_ref(),
|
||||
999_001,
|
||||
&mut seen_a,
|
||||
&stats
|
||||
),
|
||||
"no candidate existed, so no eviction is possible"
|
||||
);
|
||||
|
||||
// Candidate appears later; Session B must not be able to consume stale pressure.
|
||||
assert!(mark_relay_idle_candidate(521));
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 521));
|
||||
let mut seen_b = 0u64;
|
||||
assert!(
|
||||
!maybe_evict_idle_candidate_on_pressure(521, &mut seen_b, &stats),
|
||||
!maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
shared.as_ref(),
|
||||
521,
|
||||
&mut seen_b,
|
||||
&stats
|
||||
),
|
||||
"once pressure is observed with empty candidate set, it must not be replayed later"
|
||||
);
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blackhat_stale_pressure_must_not_survive_candidate_churn() {
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
let stats = Stats::new();
|
||||
|
||||
note_relay_pressure_event();
|
||||
assert!(mark_relay_idle_candidate(531));
|
||||
clear_relay_idle_candidate(531);
|
||||
assert!(mark_relay_idle_candidate(532));
|
||||
note_relay_pressure_event_for_testing(shared.as_ref());
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 531));
|
||||
clear_relay_idle_candidate_for_testing(shared.as_ref(), 531);
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 532));
|
||||
|
||||
let mut seen = 0u64;
|
||||
assert!(
|
||||
!maybe_evict_idle_candidate_on_pressure(532, &mut seen, &stats),
|
||||
!maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
shared.as_ref(),
|
||||
532,
|
||||
&mut seen,
|
||||
&stats
|
||||
),
|
||||
"stale pressure must not survive clear+remark churn cycles"
|
||||
);
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting() {
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
{
|
||||
let mut guard = relay_idle_candidate_registry()
|
||||
.lock()
|
||||
.expect("registry lock must be available");
|
||||
guard.pressure_event_seq = u64::MAX;
|
||||
guard.pressure_consumed_seq = u64::MAX - 1;
|
||||
set_relay_pressure_state_for_testing(shared.as_ref(), u64::MAX, u64::MAX - 1);
|
||||
}
|
||||
|
||||
// A new pressure event should still be representable; saturating at MAX creates a permanent lockout.
|
||||
note_relay_pressure_event();
|
||||
let after = relay_pressure_event_seq();
|
||||
note_relay_pressure_event_for_testing(shared.as_ref());
|
||||
let after = relay_pressure_event_seq_for_testing(shared.as_ref());
|
||||
assert_ne!(
|
||||
after,
|
||||
u64::MAX,
|
||||
"pressure sequence saturation must not permanently freeze event progression"
|
||||
);
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() {
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
{
|
||||
let mut guard = relay_idle_candidate_registry()
|
||||
.lock()
|
||||
.expect("registry lock must be available");
|
||||
guard.pressure_event_seq = u64::MAX;
|
||||
guard.pressure_consumed_seq = u64::MAX;
|
||||
set_relay_pressure_state_for_testing(shared.as_ref(), u64::MAX, u64::MAX);
|
||||
}
|
||||
|
||||
note_relay_pressure_event();
|
||||
let first = relay_pressure_event_seq();
|
||||
note_relay_pressure_event();
|
||||
let second = relay_pressure_event_seq();
|
||||
note_relay_pressure_event_for_testing(shared.as_ref());
|
||||
let first = relay_pressure_event_seq_for_testing(shared.as_ref());
|
||||
note_relay_pressure_event_for_testing(shared.as_ref());
|
||||
let second = relay_pressure_event_seq_for_testing(shared.as_ref());
|
||||
|
||||
assert!(
|
||||
second > first,
|
||||
"distinct pressure events must remain distinguishable even at sequence boundary"
|
||||
);
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims()
|
||||
{
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let sessions = 16usize;
|
||||
@@ -671,20 +756,28 @@ async fn integration_race_single_pressure_event_allows_at_most_one_eviction_unde
|
||||
let mut seen_per_session = vec![0u64; sessions];
|
||||
|
||||
for conn_id in &conn_ids {
|
||||
assert!(mark_relay_idle_candidate(*conn_id));
|
||||
assert!(mark_relay_idle_candidate_for_testing(
|
||||
shared.as_ref(),
|
||||
*conn_id
|
||||
));
|
||||
}
|
||||
|
||||
for round in 0..rounds {
|
||||
note_relay_pressure_event();
|
||||
note_relay_pressure_event_for_testing(shared.as_ref());
|
||||
|
||||
let mut joins = Vec::with_capacity(sessions);
|
||||
for (idx, conn_id) in conn_ids.iter().enumerate() {
|
||||
let mut seen = seen_per_session[idx];
|
||||
let conn_id = *conn_id;
|
||||
let stats = stats.clone();
|
||||
let shared = shared.clone();
|
||||
joins.push(tokio::spawn(async move {
|
||||
let evicted =
|
||||
maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref());
|
||||
let evicted = maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
shared.as_ref(),
|
||||
conn_id,
|
||||
&mut seen,
|
||||
stats.as_ref(),
|
||||
);
|
||||
(idx, conn_id, seen, evicted)
|
||||
}));
|
||||
}
|
||||
@@ -706,7 +799,7 @@ async fn integration_race_single_pressure_event_allows_at_most_one_eviction_unde
|
||||
);
|
||||
if let Some(conn) = evicted_conn {
|
||||
assert!(
|
||||
mark_relay_idle_candidate(conn),
|
||||
mark_relay_idle_candidate_for_testing(shared.as_ref(), conn),
|
||||
"round {round}: evicted conn must be re-markable as idle candidate"
|
||||
);
|
||||
}
|
||||
@@ -721,13 +814,13 @@ async fn integration_race_single_pressure_event_allows_at_most_one_eviction_unde
|
||||
"parallel race must still observe at least one successful eviction"
|
||||
);
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalidation_and_budget() {
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let sessions = 12usize;
|
||||
@@ -736,7 +829,10 @@ async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalida
|
||||
let mut seen_per_session = vec![0u64; sessions];
|
||||
|
||||
for conn_id in &conn_ids {
|
||||
assert!(mark_relay_idle_candidate(*conn_id));
|
||||
assert!(mark_relay_idle_candidate_for_testing(
|
||||
shared.as_ref(),
|
||||
*conn_id
|
||||
));
|
||||
}
|
||||
|
||||
let mut expected_total_evictions = 0u64;
|
||||
@@ -745,20 +841,25 @@ async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalida
|
||||
let empty_phase = round % 5 == 0;
|
||||
if empty_phase {
|
||||
for conn_id in &conn_ids {
|
||||
clear_relay_idle_candidate(*conn_id);
|
||||
clear_relay_idle_candidate_for_testing(shared.as_ref(), *conn_id);
|
||||
}
|
||||
}
|
||||
|
||||
note_relay_pressure_event();
|
||||
note_relay_pressure_event_for_testing(shared.as_ref());
|
||||
|
||||
let mut joins = Vec::with_capacity(sessions);
|
||||
for (idx, conn_id) in conn_ids.iter().enumerate() {
|
||||
let mut seen = seen_per_session[idx];
|
||||
let conn_id = *conn_id;
|
||||
let stats = stats.clone();
|
||||
let shared = shared.clone();
|
||||
joins.push(tokio::spawn(async move {
|
||||
let evicted =
|
||||
maybe_evict_idle_candidate_on_pressure(conn_id, &mut seen, stats.as_ref());
|
||||
let evicted = maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
shared.as_ref(),
|
||||
conn_id,
|
||||
&mut seen,
|
||||
stats.as_ref(),
|
||||
);
|
||||
(idx, conn_id, seen, evicted)
|
||||
}));
|
||||
}
|
||||
@@ -780,7 +881,10 @@ async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalida
|
||||
"round {round}: empty candidate phase must not allow stale-pressure eviction"
|
||||
);
|
||||
for conn_id in &conn_ids {
|
||||
assert!(mark_relay_idle_candidate(*conn_id));
|
||||
assert!(mark_relay_idle_candidate_for_testing(
|
||||
shared.as_ref(),
|
||||
*conn_id
|
||||
));
|
||||
}
|
||||
} else {
|
||||
assert!(
|
||||
@@ -789,7 +893,10 @@ async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalida
|
||||
);
|
||||
if let Some(conn_id) = evicted_conn {
|
||||
expected_total_evictions = expected_total_evictions.saturating_add(1);
|
||||
assert!(mark_relay_idle_candidate(conn_id));
|
||||
assert!(mark_relay_idle_candidate_for_testing(
|
||||
shared.as_ref(),
|
||||
conn_id
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -800,5 +907,5 @@ async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalida
|
||||
"global pressure eviction counter must match observed per-round successful consumes"
|
||||
);
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
@@ -3,12 +3,13 @@ use std::panic::{AssertUnwindSafe, catch_unwind};
|
||||
|
||||
#[test]
|
||||
fn blackhat_registry_poison_recovers_with_fail_closed_reset_and_pressure_accounting() {
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let _ = catch_unwind(AssertUnwindSafe(|| {
|
||||
let registry = relay_idle_candidate_registry();
|
||||
let mut guard = registry
|
||||
let mut guard = shared
|
||||
.middle_relay
|
||||
.relay_idle_registry
|
||||
.lock()
|
||||
.expect("registry lock must be acquired before poison");
|
||||
guard.by_conn_id.insert(
|
||||
@@ -23,40 +24,50 @@ fn blackhat_registry_poison_recovers_with_fail_closed_reset_and_pressure_account
|
||||
}));
|
||||
|
||||
// Helper lock must recover from poison, reset stale state, and continue.
|
||||
assert!(mark_relay_idle_candidate(42));
|
||||
assert_eq!(oldest_relay_idle_candidate(), Some(42));
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 42));
|
||||
assert_eq!(
|
||||
oldest_relay_idle_candidate_for_testing(shared.as_ref()),
|
||||
Some(42)
|
||||
);
|
||||
|
||||
let before = relay_pressure_event_seq();
|
||||
note_relay_pressure_event();
|
||||
let after = relay_pressure_event_seq();
|
||||
let before = relay_pressure_event_seq_for_testing(shared.as_ref());
|
||||
note_relay_pressure_event_for_testing(shared.as_ref());
|
||||
let after = relay_pressure_event_seq_for_testing(shared.as_ref());
|
||||
assert!(
|
||||
after > before,
|
||||
"pressure accounting must still advance after poison"
|
||||
);
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clear_state_helper_must_reset_poisoned_registry_for_deterministic_fifo_tests() {
|
||||
let _guard = relay_idle_pressure_test_scope();
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let _ = catch_unwind(AssertUnwindSafe(|| {
|
||||
let registry = relay_idle_candidate_registry();
|
||||
let _guard = registry
|
||||
let _guard = shared
|
||||
.middle_relay
|
||||
.relay_idle_registry
|
||||
.lock()
|
||||
.expect("registry lock must be acquired before poison");
|
||||
panic!("intentional poison while lock held");
|
||||
}));
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
assert_eq!(oldest_relay_idle_candidate(), None);
|
||||
assert_eq!(relay_pressure_event_seq(), 0);
|
||||
assert_eq!(
|
||||
oldest_relay_idle_candidate_for_testing(shared.as_ref()),
|
||||
None
|
||||
);
|
||||
assert_eq!(relay_pressure_event_seq_for_testing(shared.as_ref()), 0);
|
||||
|
||||
assert!(mark_relay_idle_candidate(7));
|
||||
assert_eq!(oldest_relay_idle_candidate(), Some(7));
|
||||
assert!(mark_relay_idle_candidate_for_testing(shared.as_ref(), 7));
|
||||
assert_eq!(
|
||||
oldest_relay_idle_candidate_for_testing(shared.as_ref()),
|
||||
Some(7)
|
||||
);
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::BufferPool;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::{Duration as TokioDuration, timeout};
|
||||
|
||||
@@ -15,32 +15,30 @@ fn make_pooled_payload(data: &[u8]) -> PooledBuffer {
|
||||
#[test]
|
||||
#[ignore = "Tracking for M-04: Verify should_emit_full_desync returns true on first occurrence and false on duplicate within window"]
|
||||
fn should_emit_full_desync_filters_duplicates() {
|
||||
let _guard = desync_dedup_test_lock()
|
||||
.lock()
|
||||
.expect("desync dedup test lock must be available");
|
||||
clear_desync_dedup_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_desync_dedup_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let key = 0x4D04_0000_0000_0001_u64;
|
||||
let base = Instant::now();
|
||||
|
||||
assert!(
|
||||
should_emit_full_desync(key, false, base),
|
||||
should_emit_full_desync_for_testing(shared.as_ref(), key, false, base),
|
||||
"first occurrence must emit full forensic record"
|
||||
);
|
||||
assert!(
|
||||
!should_emit_full_desync(key, false, base),
|
||||
!should_emit_full_desync_for_testing(shared.as_ref(), key, false, base),
|
||||
"duplicate at same timestamp must be suppressed"
|
||||
);
|
||||
|
||||
let within_window = base + DESYNC_DEDUP_WINDOW - TokioDuration::from_millis(1);
|
||||
assert!(
|
||||
!should_emit_full_desync(key, false, within_window),
|
||||
!should_emit_full_desync_for_testing(shared.as_ref(), key, false, within_window),
|
||||
"duplicate strictly inside dedup window must stay suppressed"
|
||||
);
|
||||
|
||||
let on_window_edge = base + DESYNC_DEDUP_WINDOW;
|
||||
assert!(
|
||||
should_emit_full_desync(key, false, on_window_edge),
|
||||
should_emit_full_desync_for_testing(shared.as_ref(), key, false, on_window_edge),
|
||||
"duplicate at window boundary must re-emit and refresh"
|
||||
);
|
||||
}
|
||||
@@ -48,39 +46,34 @@ fn should_emit_full_desync_filters_duplicates() {
|
||||
#[test]
|
||||
#[ignore = "Tracking for M-04: Verify desync dedup eviction behaves correctly under map-full condition"]
|
||||
fn desync_dedup_eviction_under_map_full_condition() {
|
||||
let _guard = desync_dedup_test_lock()
|
||||
.lock()
|
||||
.expect("desync dedup test lock must be available");
|
||||
clear_desync_dedup_for_testing();
|
||||
let shared = ProxySharedState::new();
|
||||
clear_desync_dedup_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let base = Instant::now();
|
||||
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
|
||||
assert!(
|
||||
should_emit_full_desync(key, false, base),
|
||||
should_emit_full_desync_for_testing(shared.as_ref(), key, false, base),
|
||||
"unique key should be inserted while warming dedup cache"
|
||||
);
|
||||
}
|
||||
|
||||
let dedup = DESYNC_DEDUP
|
||||
.get()
|
||||
.expect("dedup map must exist after warm-up insertions");
|
||||
assert_eq!(
|
||||
dedup.len(),
|
||||
desync_dedup_len_for_testing(shared.as_ref()),
|
||||
DESYNC_DEDUP_MAX_ENTRIES,
|
||||
"cache warm-up must reach exact hard cap"
|
||||
);
|
||||
|
||||
let before_keys: HashSet<u64> = dedup.iter().map(|entry| *entry.key()).collect();
|
||||
let before_keys = desync_dedup_keys_for_testing(shared.as_ref());
|
||||
let newcomer_key = 0x4D04_FFFF_FFFF_0001_u64;
|
||||
|
||||
assert!(
|
||||
should_emit_full_desync(newcomer_key, false, base),
|
||||
should_emit_full_desync_for_testing(shared.as_ref(), newcomer_key, false, base),
|
||||
"first newcomer at map-full must emit under bounded full-cache gate"
|
||||
);
|
||||
|
||||
let after_keys: HashSet<u64> = dedup.iter().map(|entry| *entry.key()).collect();
|
||||
let after_keys = desync_dedup_keys_for_testing(shared.as_ref());
|
||||
assert_eq!(
|
||||
dedup.len(),
|
||||
desync_dedup_len_for_testing(shared.as_ref()),
|
||||
DESYNC_DEDUP_MAX_ENTRIES,
|
||||
"map-full insertion must preserve hard capacity bound"
|
||||
);
|
||||
@@ -101,7 +94,7 @@ fn desync_dedup_eviction_under_map_full_condition() {
|
||||
);
|
||||
|
||||
assert!(
|
||||
!should_emit_full_desync(newcomer_key, false, base),
|
||||
!should_emit_full_desync_for_testing(shared.as_ref(), newcomer_key, false, base),
|
||||
"immediate duplicate newcomer must remain suppressed"
|
||||
);
|
||||
}
|
||||
@@ -119,6 +112,7 @@ async fn c2me_channel_full_path_yields_then_sends() {
|
||||
.expect("priming queue with one frame must succeed");
|
||||
|
||||
let tx2 = tx.clone();
|
||||
let stats = Stats::default();
|
||||
let producer = tokio::spawn(async move {
|
||||
enqueue_c2me_command(
|
||||
&tx2,
|
||||
@@ -127,6 +121,7 @@ async fn c2me_channel_full_path_yields_then_sends() {
|
||||
flags: 2,
|
||||
},
|
||||
None,
|
||||
&stats,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
@@ -0,0 +1,674 @@
|
||||
use crate::proxy::client::handle_client_stream_with_shared;
|
||||
use crate::proxy::handshake::{
|
||||
auth_probe_fail_streak_for_testing_in_shared, auth_probe_is_throttled_for_testing_in_shared,
|
||||
auth_probe_record_failure_for_testing, clear_auth_probe_state_for_testing_in_shared,
|
||||
clear_unknown_sni_warn_state_for_testing_in_shared, clear_warned_secrets_for_testing_in_shared,
|
||||
should_emit_unknown_sni_warn_for_testing_in_shared, warned_secrets_for_testing_in_shared,
|
||||
};
|
||||
use crate::proxy::middle_relay::{
|
||||
clear_desync_dedup_for_testing_in_shared, clear_relay_idle_candidate_for_testing,
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared, mark_relay_idle_candidate_for_testing,
|
||||
maybe_evict_idle_candidate_on_pressure_for_testing, note_relay_pressure_event_for_testing,
|
||||
oldest_relay_idle_candidate_for_testing, relay_idle_mark_seq_for_testing,
|
||||
relay_pressure_event_seq_for_testing, should_emit_full_desync_for_testing,
|
||||
};
|
||||
use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController};
|
||||
use crate::proxy::shared_state::ProxySharedState;
|
||||
use crate::{
|
||||
config::{ProxyConfig, UpstreamConfig, UpstreamType},
|
||||
crypto::SecureRandom,
|
||||
ip_tracker::UserIpTracker,
|
||||
stats::{ReplayChecker, Stats, beobachten::BeobachtenStore},
|
||||
stream::BufferPool,
|
||||
transport::UpstreamManager,
|
||||
};
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::io::{AsyncWriteExt, duplex};
|
||||
use tokio::sync::Barrier;
|
||||
|
||||
struct ClientHarness {
|
||||
config: Arc<ProxyConfig>,
|
||||
stats: Arc<Stats>,
|
||||
upstream_manager: Arc<UpstreamManager>,
|
||||
replay_checker: Arc<ReplayChecker>,
|
||||
buffer_pool: Arc<BufferPool>,
|
||||
rng: Arc<SecureRandom>,
|
||||
route_runtime: Arc<RouteRuntimeController>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
beobachten: Arc<BeobachtenStore>,
|
||||
}
|
||||
|
||||
fn new_client_harness() -> ClientHarness {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.censorship.mask = false;
|
||||
cfg.general.modes.classic = true;
|
||||
cfg.general.modes.secure = true;
|
||||
let config = Arc::new(cfg);
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let upstream_manager = Arc::new(UpstreamManager::new(
|
||||
vec![UpstreamConfig {
|
||||
upstream_type: UpstreamType::Direct {
|
||||
interface: None,
|
||||
bind_addresses: None,
|
||||
},
|
||||
weight: 1,
|
||||
enabled: true,
|
||||
scopes: String::new(),
|
||||
selected_scope: String::new(),
|
||||
}],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
10,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
));
|
||||
|
||||
ClientHarness {
|
||||
config,
|
||||
stats,
|
||||
upstream_manager,
|
||||
replay_checker: Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
|
||||
buffer_pool: Arc::new(BufferPool::new()),
|
||||
rng: Arc::new(SecureRandom::new()),
|
||||
route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
|
||||
ip_tracker: Arc::new(UserIpTracker::new()),
|
||||
beobachten: Arc::new(BeobachtenStore::new()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn drive_invalid_mtproto_handshake(
|
||||
shared: Arc<ProxySharedState>,
|
||||
peer: std::net::SocketAddr,
|
||||
) {
|
||||
let harness = new_client_harness();
|
||||
let (server_side, mut client_side) = duplex(4096);
|
||||
let invalid = [0u8; 64];
|
||||
|
||||
let task = tokio::spawn(handle_client_stream_with_shared(
|
||||
server_side,
|
||||
peer,
|
||||
harness.config,
|
||||
harness.stats,
|
||||
harness.upstream_manager,
|
||||
harness.replay_checker,
|
||||
harness.buffer_pool,
|
||||
harness.rng,
|
||||
None,
|
||||
harness.route_runtime,
|
||||
None,
|
||||
harness.ip_tracker,
|
||||
harness.beobachten,
|
||||
shared,
|
||||
false,
|
||||
));
|
||||
|
||||
client_side
|
||||
.write_all(&invalid)
|
||||
.await
|
||||
.expect("failed to write invalid handshake");
|
||||
client_side
|
||||
.shutdown()
|
||||
.await
|
||||
.expect("failed to shutdown client");
|
||||
let _ = tokio::time::timeout(Duration::from_secs(3), task)
|
||||
.await
|
||||
.expect("client task timed out")
|
||||
.expect("client task join failed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_shared_state_two_instances_do_not_share_auth_probe_state() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(a.as_ref());
|
||||
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 10));
|
||||
auth_probe_record_failure_for_testing(a.as_ref(), ip, Instant::now());
|
||||
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(a.as_ref(), ip),
|
||||
Some(1)
|
||||
);
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(b.as_ref(), ip),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_shared_state_two_instances_do_not_share_desync_dedup() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_desync_dedup_for_testing_in_shared(a.as_ref());
|
||||
|
||||
let now = Instant::now();
|
||||
let key = 0xA5A5_u64;
|
||||
assert!(should_emit_full_desync_for_testing(
|
||||
a.as_ref(),
|
||||
key,
|
||||
false,
|
||||
now
|
||||
));
|
||||
assert!(should_emit_full_desync_for_testing(
|
||||
b.as_ref(),
|
||||
key,
|
||||
false,
|
||||
now
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_shared_state_two_instances_do_not_share_idle_registry() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(a.as_ref());
|
||||
|
||||
assert!(mark_relay_idle_candidate_for_testing(a.as_ref(), 111));
|
||||
assert_eq!(
|
||||
oldest_relay_idle_candidate_for_testing(a.as_ref()),
|
||||
Some(111)
|
||||
);
|
||||
assert_eq!(oldest_relay_idle_candidate_for_testing(b.as_ref()), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_shared_state_reset_in_one_instance_does_not_affect_another() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(a.as_ref());
|
||||
|
||||
let ip_a = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 1));
|
||||
let ip_b = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 2));
|
||||
let now = Instant::now();
|
||||
|
||||
auth_probe_record_failure_for_testing(a.as_ref(), ip_a, now);
|
||||
auth_probe_record_failure_for_testing(b.as_ref(), ip_b, now);
|
||||
clear_auth_probe_state_for_testing_in_shared(a.as_ref());
|
||||
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(a.as_ref(), ip_a),
|
||||
None
|
||||
);
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(b.as_ref(), ip_b),
|
||||
Some(1)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_shared_state_parallel_auth_probe_updates_stay_per_instance() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(a.as_ref());
|
||||
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 77));
|
||||
let now = Instant::now();
|
||||
|
||||
for _ in 0..5 {
|
||||
auth_probe_record_failure_for_testing(a.as_ref(), ip, now);
|
||||
}
|
||||
for _ in 0..3 {
|
||||
auth_probe_record_failure_for_testing(b.as_ref(), ip, now + Duration::from_millis(1));
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(a.as_ref(), ip),
|
||||
Some(5)
|
||||
);
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(b.as_ref(), ip),
|
||||
Some(3)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn proxy_shared_state_client_pipeline_records_probe_failures_in_instance_state() {
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 200));
|
||||
let peer = std::net::SocketAddr::new(peer_ip, 54001);
|
||||
|
||||
drive_invalid_mtproto_handshake(shared.clone(), peer).await;
|
||||
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(shared.as_ref(), peer_ip),
|
||||
Some(1),
|
||||
"invalid handshake in client pipeline must update injected shared auth-probe state"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn proxy_shared_state_client_pipeline_keeps_auth_probe_isolated_between_instances() {
|
||||
let shared_a = ProxySharedState::new();
|
||||
let shared_b = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared_a.as_ref());
|
||||
clear_auth_probe_state_for_testing_in_shared(shared_b.as_ref());
|
||||
|
||||
let peer_a_ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 210));
|
||||
let peer_b_ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 211));
|
||||
|
||||
drive_invalid_mtproto_handshake(
|
||||
shared_a.clone(),
|
||||
std::net::SocketAddr::new(peer_a_ip, 54110),
|
||||
)
|
||||
.await;
|
||||
drive_invalid_mtproto_handshake(
|
||||
shared_b.clone(),
|
||||
std::net::SocketAddr::new(peer_b_ip, 54111),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(shared_a.as_ref(), peer_a_ip),
|
||||
Some(1)
|
||||
);
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(shared_b.as_ref(), peer_b_ip),
|
||||
Some(1)
|
||||
);
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(shared_a.as_ref(), peer_b_ip),
|
||||
None
|
||||
);
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(shared_b.as_ref(), peer_a_ip),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn proxy_shared_state_client_pipeline_high_contention_same_ip_stays_lossless_per_instance() {
|
||||
let shared_a = ProxySharedState::new();
|
||||
let shared_b = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared_a.as_ref());
|
||||
clear_auth_probe_state_for_testing_in_shared(shared_b.as_ref());
|
||||
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 250));
|
||||
let workers = 48u16;
|
||||
let barrier = Arc::new(Barrier::new((workers as usize) * 2));
|
||||
let mut tasks = Vec::new();
|
||||
|
||||
for i in 0..workers {
|
||||
let shared_a = shared_a.clone();
|
||||
let barrier_a = barrier.clone();
|
||||
let peer_a = std::net::SocketAddr::new(ip, 56000 + i);
|
||||
tasks.push(tokio::spawn(async move {
|
||||
barrier_a.wait().await;
|
||||
drive_invalid_mtproto_handshake(shared_a, peer_a).await;
|
||||
}));
|
||||
|
||||
let shared_b = shared_b.clone();
|
||||
let barrier_b = barrier.clone();
|
||||
let peer_b = std::net::SocketAddr::new(ip, 56100 + i);
|
||||
tasks.push(tokio::spawn(async move {
|
||||
barrier_b.wait().await;
|
||||
drive_invalid_mtproto_handshake(shared_b, peer_b).await;
|
||||
}));
|
||||
}
|
||||
|
||||
for task in tasks {
|
||||
task.await.expect("pipeline task join failed");
|
||||
}
|
||||
|
||||
let streak_a = auth_probe_fail_streak_for_testing_in_shared(shared_a.as_ref(), ip)
|
||||
.expect("instance A must track probe failures");
|
||||
let streak_b = auth_probe_fail_streak_for_testing_in_shared(shared_b.as_ref(), ip)
|
||||
.expect("instance B must track probe failures");
|
||||
|
||||
assert!(streak_a > 0);
|
||||
assert!(streak_b > 0);
|
||||
|
||||
clear_auth_probe_state_for_testing_in_shared(shared_a.as_ref());
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(shared_a.as_ref(), ip),
|
||||
None,
|
||||
"clearing one instance must reset only that instance"
|
||||
);
|
||||
assert!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(shared_b.as_ref(), ip).is_some(),
|
||||
"clearing one instance must not clear the other instance"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_shared_state_auth_saturation_does_not_bleed_across_instances() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(a.as_ref());
|
||||
clear_auth_probe_state_for_testing_in_shared(b.as_ref());
|
||||
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 77));
|
||||
let future_now = Instant::now() + Duration::from_secs(1);
|
||||
for _ in 0..8 {
|
||||
auth_probe_record_failure_for_testing(a.as_ref(), ip, future_now);
|
||||
}
|
||||
|
||||
assert!(auth_probe_is_throttled_for_testing_in_shared(
|
||||
a.as_ref(),
|
||||
ip
|
||||
));
|
||||
assert!(!auth_probe_is_throttled_for_testing_in_shared(
|
||||
b.as_ref(),
|
||||
ip
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_shared_state_poison_clear_in_one_instance_does_not_affect_other_instance() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(a.as_ref());
|
||||
clear_auth_probe_state_for_testing_in_shared(b.as_ref());
|
||||
|
||||
let ip_a = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 31));
|
||||
let ip_b = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 32));
|
||||
let now = Instant::now();
|
||||
|
||||
auth_probe_record_failure_for_testing(a.as_ref(), ip_a, now);
|
||||
auth_probe_record_failure_for_testing(b.as_ref(), ip_b, now);
|
||||
|
||||
let a_for_poison = a.clone();
|
||||
let _ = std::thread::spawn(move || {
|
||||
let _hold = a_for_poison
|
||||
.handshake
|
||||
.auth_probe_saturation
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
panic!("intentional poison for per-instance isolation regression coverage");
|
||||
})
|
||||
.join();
|
||||
|
||||
clear_auth_probe_state_for_testing_in_shared(a.as_ref());
|
||||
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(a.as_ref(), ip_a),
|
||||
None
|
||||
);
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(b.as_ref(), ip_b),
|
||||
Some(1),
|
||||
"poison recovery and clear in one instance must not touch other instance state"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_shared_state_unknown_sni_cooldown_does_not_bleed_across_instances() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_unknown_sni_warn_state_for_testing_in_shared(a.as_ref());
|
||||
clear_unknown_sni_warn_state_for_testing_in_shared(b.as_ref());
|
||||
|
||||
let now = Instant::now();
|
||||
assert!(should_emit_unknown_sni_warn_for_testing_in_shared(
|
||||
a.as_ref(),
|
||||
now
|
||||
));
|
||||
assert!(should_emit_unknown_sni_warn_for_testing_in_shared(
|
||||
b.as_ref(),
|
||||
now
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_shared_state_warned_secret_cache_does_not_bleed_across_instances() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_warned_secrets_for_testing_in_shared(a.as_ref());
|
||||
clear_warned_secrets_for_testing_in_shared(b.as_ref());
|
||||
|
||||
let key = ("isolation-user".to_string(), "invalid_hex".to_string());
|
||||
{
|
||||
let warned = warned_secrets_for_testing_in_shared(a.as_ref());
|
||||
let mut guard = warned
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
guard.insert(key.clone());
|
||||
}
|
||||
|
||||
let contains_in_a = {
|
||||
let warned = warned_secrets_for_testing_in_shared(a.as_ref());
|
||||
let guard = warned
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
guard.contains(&key)
|
||||
};
|
||||
let contains_in_b = {
|
||||
let warned = warned_secrets_for_testing_in_shared(b.as_ref());
|
||||
let guard = warned
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
guard.contains(&key)
|
||||
};
|
||||
|
||||
assert!(contains_in_a);
|
||||
assert!(!contains_in_b);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_shared_state_idle_mark_seq_is_per_instance() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(a.as_ref());
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(b.as_ref());
|
||||
|
||||
assert_eq!(relay_idle_mark_seq_for_testing(a.as_ref()), 0);
|
||||
assert_eq!(relay_idle_mark_seq_for_testing(b.as_ref()), 0);
|
||||
|
||||
assert!(mark_relay_idle_candidate_for_testing(a.as_ref(), 9001));
|
||||
assert_eq!(relay_idle_mark_seq_for_testing(a.as_ref()), 1);
|
||||
assert_eq!(relay_idle_mark_seq_for_testing(b.as_ref()), 0);
|
||||
|
||||
assert!(mark_relay_idle_candidate_for_testing(b.as_ref(), 9002));
|
||||
assert_eq!(relay_idle_mark_seq_for_testing(a.as_ref()), 1);
|
||||
assert_eq!(relay_idle_mark_seq_for_testing(b.as_ref()), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_shared_state_unknown_sni_clear_in_one_instance_does_not_reset_other() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_unknown_sni_warn_state_for_testing_in_shared(a.as_ref());
|
||||
clear_unknown_sni_warn_state_for_testing_in_shared(b.as_ref());
|
||||
|
||||
let now = Instant::now();
|
||||
assert!(should_emit_unknown_sni_warn_for_testing_in_shared(
|
||||
a.as_ref(),
|
||||
now
|
||||
));
|
||||
assert!(should_emit_unknown_sni_warn_for_testing_in_shared(
|
||||
b.as_ref(),
|
||||
now
|
||||
));
|
||||
|
||||
clear_unknown_sni_warn_state_for_testing_in_shared(a.as_ref());
|
||||
assert!(should_emit_unknown_sni_warn_for_testing_in_shared(
|
||||
a.as_ref(),
|
||||
now + Duration::from_millis(1)
|
||||
));
|
||||
assert!(!should_emit_unknown_sni_warn_for_testing_in_shared(
|
||||
b.as_ref(),
|
||||
now + Duration::from_millis(1)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_shared_state_warned_secret_clear_in_one_instance_does_not_clear_other() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_warned_secrets_for_testing_in_shared(a.as_ref());
|
||||
clear_warned_secrets_for_testing_in_shared(b.as_ref());
|
||||
|
||||
let key = (
|
||||
"clear-isolation-user".to_string(),
|
||||
"invalid_length".to_string(),
|
||||
);
|
||||
{
|
||||
let warned_a = warned_secrets_for_testing_in_shared(a.as_ref());
|
||||
let mut guard_a = warned_a
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
guard_a.insert(key.clone());
|
||||
|
||||
let warned_b = warned_secrets_for_testing_in_shared(b.as_ref());
|
||||
let mut guard_b = warned_b
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
guard_b.insert(key.clone());
|
||||
}
|
||||
|
||||
clear_warned_secrets_for_testing_in_shared(a.as_ref());
|
||||
|
||||
let has_a = {
|
||||
let warned = warned_secrets_for_testing_in_shared(a.as_ref());
|
||||
let guard = warned
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
guard.contains(&key)
|
||||
};
|
||||
let has_b = {
|
||||
let warned = warned_secrets_for_testing_in_shared(b.as_ref());
|
||||
let guard = warned
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
guard.contains(&key)
|
||||
};
|
||||
|
||||
assert!(!has_a);
|
||||
assert!(has_b);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_shared_state_desync_duplicate_suppression_is_instance_scoped() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_desync_dedup_for_testing_in_shared(a.as_ref());
|
||||
clear_desync_dedup_for_testing_in_shared(b.as_ref());
|
||||
|
||||
let now = Instant::now();
|
||||
let key = 0xBEEF_0000_0000_0001u64;
|
||||
assert!(should_emit_full_desync_for_testing(
|
||||
a.as_ref(),
|
||||
key,
|
||||
false,
|
||||
now
|
||||
));
|
||||
assert!(!should_emit_full_desync_for_testing(
|
||||
a.as_ref(),
|
||||
key,
|
||||
false,
|
||||
now + Duration::from_millis(1)
|
||||
));
|
||||
assert!(should_emit_full_desync_for_testing(
|
||||
b.as_ref(),
|
||||
key,
|
||||
false,
|
||||
now
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_shared_state_desync_clear_in_one_instance_does_not_clear_other() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_desync_dedup_for_testing_in_shared(a.as_ref());
|
||||
clear_desync_dedup_for_testing_in_shared(b.as_ref());
|
||||
|
||||
let now = Instant::now();
|
||||
let key = 0xCAFE_0000_0000_0001u64;
|
||||
assert!(should_emit_full_desync_for_testing(
|
||||
a.as_ref(),
|
||||
key,
|
||||
false,
|
||||
now
|
||||
));
|
||||
assert!(should_emit_full_desync_for_testing(
|
||||
b.as_ref(),
|
||||
key,
|
||||
false,
|
||||
now
|
||||
));
|
||||
|
||||
clear_desync_dedup_for_testing_in_shared(a.as_ref());
|
||||
|
||||
assert!(should_emit_full_desync_for_testing(
|
||||
a.as_ref(),
|
||||
key,
|
||||
false,
|
||||
now + Duration::from_millis(2)
|
||||
));
|
||||
assert!(!should_emit_full_desync_for_testing(
|
||||
b.as_ref(),
|
||||
key,
|
||||
false,
|
||||
now + Duration::from_millis(2)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_shared_state_idle_candidate_clear_in_one_instance_does_not_affect_other() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(a.as_ref());
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(b.as_ref());
|
||||
|
||||
assert!(mark_relay_idle_candidate_for_testing(a.as_ref(), 1001));
|
||||
assert!(mark_relay_idle_candidate_for_testing(b.as_ref(), 2002));
|
||||
clear_relay_idle_candidate_for_testing(a.as_ref(), 1001);
|
||||
|
||||
assert_eq!(oldest_relay_idle_candidate_for_testing(a.as_ref()), None);
|
||||
assert_eq!(
|
||||
oldest_relay_idle_candidate_for_testing(b.as_ref()),
|
||||
Some(2002)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_shared_state_pressure_seq_increments_are_instance_scoped() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(a.as_ref());
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(b.as_ref());
|
||||
|
||||
assert_eq!(relay_pressure_event_seq_for_testing(a.as_ref()), 0);
|
||||
assert_eq!(relay_pressure_event_seq_for_testing(b.as_ref()), 0);
|
||||
|
||||
note_relay_pressure_event_for_testing(a.as_ref());
|
||||
note_relay_pressure_event_for_testing(a.as_ref());
|
||||
|
||||
assert_eq!(relay_pressure_event_seq_for_testing(a.as_ref()), 2);
|
||||
assert_eq!(relay_pressure_event_seq_for_testing(b.as_ref()), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_shared_state_pressure_consumption_does_not_cross_instances() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(a.as_ref());
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(b.as_ref());
|
||||
|
||||
assert!(mark_relay_idle_candidate_for_testing(a.as_ref(), 7001));
|
||||
assert!(mark_relay_idle_candidate_for_testing(b.as_ref(), 7001));
|
||||
note_relay_pressure_event_for_testing(a.as_ref());
|
||||
|
||||
let stats = Stats::new();
|
||||
let mut seen_a = 0u64;
|
||||
let mut seen_b = 0u64;
|
||||
|
||||
assert!(maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
a.as_ref(),
|
||||
7001,
|
||||
&mut seen_a,
|
||||
&stats
|
||||
));
|
||||
assert!(!maybe_evict_idle_candidate_on_pressure_for_testing(
|
||||
b.as_ref(),
|
||||
7001,
|
||||
&mut seen_b,
|
||||
&stats
|
||||
));
|
||||
}
|
||||
@@ -0,0 +1,265 @@
|
||||
use crate::proxy::handshake::{
|
||||
auth_probe_fail_streak_for_testing_in_shared, auth_probe_record_failure_for_testing,
|
||||
clear_auth_probe_state_for_testing_in_shared,
|
||||
clear_unknown_sni_warn_state_for_testing_in_shared,
|
||||
should_emit_unknown_sni_warn_for_testing_in_shared,
|
||||
};
|
||||
use crate::proxy::middle_relay::{
|
||||
clear_desync_dedup_for_testing_in_shared,
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared, mark_relay_idle_candidate_for_testing,
|
||||
oldest_relay_idle_candidate_for_testing, should_emit_full_desync_for_testing,
|
||||
};
|
||||
use crate::proxy::shared_state::ProxySharedState;
|
||||
use rand::RngExt;
|
||||
use rand::SeedableRng;
|
||||
use rand::rngs::StdRng;
|
||||
use std::net::{IpAddr, Ipv4Addr};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::Barrier;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn proxy_shared_state_50_concurrent_instances_no_counter_bleed() {
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..50_u8 {
|
||||
handles.push(tokio::spawn(async move {
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 200));
|
||||
auth_probe_record_failure_for_testing(shared.as_ref(), ip, Instant::now());
|
||||
auth_probe_fail_streak_for_testing_in_shared(shared.as_ref(), ip)
|
||||
}));
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
let streak = handle.await.expect("task join failed");
|
||||
assert_eq!(streak, Some(1));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn proxy_shared_state_desync_rotation_concurrent_20_instances() {
|
||||
let now = Instant::now();
|
||||
let key = 0xD35E_D35E_u64;
|
||||
let mut handles = Vec::new();
|
||||
for _ in 0..20_u64 {
|
||||
handles.push(tokio::spawn(async move {
|
||||
let shared = ProxySharedState::new();
|
||||
clear_desync_dedup_for_testing_in_shared(shared.as_ref());
|
||||
should_emit_full_desync_for_testing(shared.as_ref(), key, false, now)
|
||||
}));
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
let emitted = handle.await.expect("task join failed");
|
||||
assert!(emitted);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn proxy_shared_state_idle_registry_concurrent_10_instances() {
|
||||
let mut handles = Vec::new();
|
||||
let conn_id = 42_u64;
|
||||
for _ in 1..=10_u64 {
|
||||
handles.push(tokio::spawn(async move {
|
||||
let shared = ProxySharedState::new();
|
||||
clear_relay_idle_pressure_state_for_testing_in_shared(shared.as_ref());
|
||||
let marked = mark_relay_idle_candidate_for_testing(shared.as_ref(), conn_id);
|
||||
let oldest = oldest_relay_idle_candidate_for_testing(shared.as_ref());
|
||||
(marked, oldest)
|
||||
}));
|
||||
}
|
||||
|
||||
for (i, handle) in handles.into_iter().enumerate() {
|
||||
let (marked, oldest) = handle.await.expect("task join failed");
|
||||
assert!(marked, "instance {} failed to mark", i);
|
||||
assert_eq!(oldest, Some(conn_id));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn proxy_shared_state_dual_instance_same_ip_high_contention_no_counter_bleed() {
|
||||
let a = ProxySharedState::new();
|
||||
let b = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(a.as_ref());
|
||||
clear_auth_probe_state_for_testing_in_shared(b.as_ref());
|
||||
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(203, 0, 113, 200));
|
||||
let mut handles = Vec::new();
|
||||
|
||||
for _ in 0..64 {
|
||||
let a = a.clone();
|
||||
let b = b.clone();
|
||||
handles.push(tokio::spawn(async move {
|
||||
auth_probe_record_failure_for_testing(a.as_ref(), ip, Instant::now());
|
||||
auth_probe_record_failure_for_testing(b.as_ref(), ip, Instant::now());
|
||||
}));
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
handle.await.expect("task join failed");
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(a.as_ref(), ip),
|
||||
Some(64)
|
||||
);
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(b.as_ref(), ip),
|
||||
Some(64)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn proxy_shared_state_unknown_sni_parallel_instances_no_cross_cooldown() {
|
||||
let mut handles = Vec::new();
|
||||
let now = Instant::now();
|
||||
|
||||
for _ in 0..32 {
|
||||
handles.push(tokio::spawn(async move {
|
||||
let shared = ProxySharedState::new();
|
||||
clear_unknown_sni_warn_state_for_testing_in_shared(shared.as_ref());
|
||||
let first = should_emit_unknown_sni_warn_for_testing_in_shared(shared.as_ref(), now);
|
||||
let second = should_emit_unknown_sni_warn_for_testing_in_shared(
|
||||
shared.as_ref(),
|
||||
now + std::time::Duration::from_millis(1),
|
||||
);
|
||||
(first, second)
|
||||
}));
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
let (first, second) = handle.await.expect("task join failed");
|
||||
assert!(first);
|
||||
assert!(!second);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn proxy_shared_state_auth_probe_high_contention_increments_are_lossless() {
|
||||
let shared = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared.as_ref());
|
||||
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 33));
|
||||
let workers = 128usize;
|
||||
let rounds = 20usize;
|
||||
|
||||
for _ in 0..rounds {
|
||||
let start = Arc::new(Barrier::new(workers));
|
||||
let mut handles = Vec::with_capacity(workers);
|
||||
|
||||
for _ in 0..workers {
|
||||
let shared = shared.clone();
|
||||
let start = start.clone();
|
||||
handles.push(tokio::spawn(async move {
|
||||
start.wait().await;
|
||||
auth_probe_record_failure_for_testing(shared.as_ref(), ip, Instant::now());
|
||||
}));
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
handle.await.expect("task join failed");
|
||||
}
|
||||
}
|
||||
|
||||
let expected = (workers * rounds) as u32;
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(shared.as_ref(), ip),
|
||||
Some(expected),
|
||||
"auth probe fail streak must account for every concurrent update"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn proxy_shared_state_seed_matrix_concurrency_isolation_no_counter_bleed() {
|
||||
let seeds: [u64; 8] = [
|
||||
0x0000_0000_0000_0001,
|
||||
0x1111_1111_1111_1111,
|
||||
0xA5A5_A5A5_A5A5_A5A5,
|
||||
0xDEAD_BEEF_CAFE_BABE,
|
||||
0x0123_4567_89AB_CDEF,
|
||||
0xFEDC_BA98_7654_3210,
|
||||
0x0F0F_F0F0_55AA_AA55,
|
||||
0x1357_9BDF_2468_ACE0,
|
||||
];
|
||||
|
||||
for seed in seeds {
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
let shared_a = ProxySharedState::new();
|
||||
let shared_b = ProxySharedState::new();
|
||||
clear_auth_probe_state_for_testing_in_shared(shared_a.as_ref());
|
||||
clear_auth_probe_state_for_testing_in_shared(shared_b.as_ref());
|
||||
|
||||
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, rng.random_range(1_u8..=250_u8)));
|
||||
let workers = rng.random_range(16_usize..=48_usize);
|
||||
let rounds = rng.random_range(4_usize..=10_usize);
|
||||
|
||||
let mut expected_a: u32 = 0;
|
||||
let mut expected_b: u32 = 0;
|
||||
|
||||
for _ in 0..rounds {
|
||||
let start = Arc::new(Barrier::new(workers * 2));
|
||||
let mut handles = Vec::with_capacity(workers * 2);
|
||||
|
||||
for _ in 0..workers {
|
||||
let a_ops = rng.random_range(1_u32..=3_u32);
|
||||
let b_ops = rng.random_range(1_u32..=3_u32);
|
||||
expected_a = expected_a.saturating_add(a_ops);
|
||||
expected_b = expected_b.saturating_add(b_ops);
|
||||
|
||||
let shared_a = shared_a.clone();
|
||||
let start_a = start.clone();
|
||||
handles.push(tokio::spawn(async move {
|
||||
start_a.wait().await;
|
||||
for _ in 0..a_ops {
|
||||
auth_probe_record_failure_for_testing(
|
||||
shared_a.as_ref(),
|
||||
ip,
|
||||
Instant::now(),
|
||||
);
|
||||
}
|
||||
}));
|
||||
|
||||
let shared_b = shared_b.clone();
|
||||
let start_b = start.clone();
|
||||
handles.push(tokio::spawn(async move {
|
||||
start_b.wait().await;
|
||||
for _ in 0..b_ops {
|
||||
auth_probe_record_failure_for_testing(
|
||||
shared_b.as_ref(),
|
||||
ip,
|
||||
Instant::now(),
|
||||
);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
for handle in handles {
|
||||
handle.await.expect("task join failed");
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(shared_a.as_ref(), ip),
|
||||
Some(expected_a),
|
||||
"seed {seed:#x}: instance A streak mismatch"
|
||||
);
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(shared_b.as_ref(), ip),
|
||||
Some(expected_b),
|
||||
"seed {seed:#x}: instance B streak mismatch"
|
||||
);
|
||||
|
||||
clear_auth_probe_state_for_testing_in_shared(shared_a.as_ref());
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(shared_a.as_ref(), ip),
|
||||
None,
|
||||
"seed {seed:#x}: clearing A must reset only A"
|
||||
);
|
||||
assert_eq!(
|
||||
auth_probe_fail_streak_for_testing_in_shared(shared_b.as_ref(), ip),
|
||||
Some(expected_b),
|
||||
"seed {seed:#x}: clearing A must not mutate B"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,284 @@
|
||||
use super::*;
|
||||
use crate::error::ProxyError;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::BufferPool;
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf, duplex};
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
struct BrokenPipeWriter;
|
||||
|
||||
impl AsyncWrite for BrokenPipeWriter {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Poll::Ready(Err(io::Error::new(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
"forced broken pipe",
|
||||
)))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn relay_baseline_activity_timeout_fires_after_inactivity() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "relay-baseline-idle-timeout";
|
||||
|
||||
let (_client_peer, relay_client) = duplex(1024);
|
||||
let (_server_peer, relay_server) = duplex(1024);
|
||||
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_task = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
1024,
|
||||
1024,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
None,
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
tokio::task::yield_now().await;
|
||||
tokio::time::advance(ACTIVITY_TIMEOUT.saturating_sub(Duration::from_secs(1))).await;
|
||||
tokio::task::yield_now().await;
|
||||
assert!(
|
||||
!relay_task.is_finished(),
|
||||
"relay must stay alive before inactivity timeout"
|
||||
);
|
||||
|
||||
tokio::time::advance(WATCHDOG_INTERVAL + Duration::from_secs(2)).await;
|
||||
|
||||
let done = timeout(Duration::from_secs(1), relay_task)
|
||||
.await
|
||||
.expect("relay must complete after inactivity timeout")
|
||||
.expect("relay task must not panic");
|
||||
|
||||
assert!(
|
||||
done.is_ok(),
|
||||
"relay must return Ok(()) after inactivity timeout"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_baseline_zero_bytes_returns_ok_and_counters_zero() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "relay-baseline-zero-bytes";
|
||||
|
||||
let (client_peer, relay_client) = duplex(1024);
|
||||
let (server_peer, relay_server) = duplex(1024);
|
||||
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_task = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
1024,
|
||||
1024,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
None,
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let done = timeout(Duration::from_secs(2), relay_task)
|
||||
.await
|
||||
.expect("relay must stop after both peers close")
|
||||
.expect("relay task must not panic");
|
||||
|
||||
assert!(done.is_ok(), "relay must return Ok(()) on immediate EOF");
|
||||
assert_eq!(stats.get_user_total_octets(user), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_baseline_bidirectional_bytes_counted_symmetrically() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "relay-baseline-bidir-counters";
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(16 * 1024);
|
||||
let (relay_server, mut server_peer) = duplex(16 * 1024);
|
||||
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_task = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
4096,
|
||||
4096,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
None,
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
let c2s = vec![0xAA; 4096];
|
||||
let s2c = vec![0xBB; 2048];
|
||||
|
||||
client_peer.write_all(&c2s).await.unwrap();
|
||||
server_peer.write_all(&s2c).await.unwrap();
|
||||
|
||||
let mut seen_c2s = vec![0u8; c2s.len()];
|
||||
let mut seen_s2c = vec![0u8; s2c.len()];
|
||||
server_peer.read_exact(&mut seen_c2s).await.unwrap();
|
||||
client_peer.read_exact(&mut seen_s2c).await.unwrap();
|
||||
|
||||
assert_eq!(seen_c2s, c2s);
|
||||
assert_eq!(seen_s2c, s2c);
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let done = timeout(Duration::from_secs(2), relay_task)
|
||||
.await
|
||||
.expect("relay must complete after both peers close")
|
||||
.expect("relay task must not panic");
|
||||
assert!(done.is_ok());
|
||||
|
||||
assert_eq!(
|
||||
stats.get_user_total_octets(user),
|
||||
(c2s.len() + s2c.len()) as u64
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_baseline_both_sides_close_simultaneously_no_panic() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let (client_peer, relay_client) = duplex(1024);
|
||||
let (relay_server, server_peer) = duplex(1024);
|
||||
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_task = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
1024,
|
||||
1024,
|
||||
"relay-baseline-sim-close",
|
||||
Arc::clone(&stats),
|
||||
None,
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let done = timeout(Duration::from_secs(2), relay_task)
|
||||
.await
|
||||
.expect("relay must complete")
|
||||
.expect("relay task must not panic");
|
||||
assert!(done.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_baseline_broken_pipe_midtransfer_returns_error() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "relay-baseline-broken-pipe";
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(1024);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
|
||||
let relay_task = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
tokio::io::empty(),
|
||||
BrokenPipeWriter,
|
||||
1024,
|
||||
1024,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
None,
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
client_peer.write_all(b"trigger").await.unwrap();
|
||||
|
||||
let done = timeout(Duration::from_secs(2), relay_task)
|
||||
.await
|
||||
.expect("relay must return after broken pipe")
|
||||
.expect("relay task must not panic");
|
||||
|
||||
match done {
|
||||
Err(ProxyError::Io(err)) => {
|
||||
assert!(
|
||||
matches!(
|
||||
err.kind(),
|
||||
io::ErrorKind::BrokenPipe | io::ErrorKind::ConnectionReset
|
||||
),
|
||||
"expected BrokenPipe/ConnectionReset, got {:?}",
|
||||
err.kind()
|
||||
);
|
||||
}
|
||||
other => panic!("expected ProxyError::Io, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn relay_baseline_many_small_writes_exact_counter() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "relay-baseline-many-small";
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(4096);
|
||||
let (relay_server, mut server_peer) = duplex(4096);
|
||||
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_task = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
1024,
|
||||
1024,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
None,
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
for i in 0..10_000u32 {
|
||||
let b = [(i & 0xFF) as u8];
|
||||
client_peer.write_all(&b).await.unwrap();
|
||||
let mut seen = [0u8; 1];
|
||||
server_peer.read_exact(&mut seen).await.unwrap();
|
||||
assert_eq!(seen, b);
|
||||
}
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let done = timeout(Duration::from_secs(3), relay_task)
|
||||
.await
|
||||
.expect("relay must complete for many small writes")
|
||||
.expect("relay task must not panic");
|
||||
assert!(done.is_ok());
|
||||
assert_eq!(stats.get_user_total_octets(user), 10_000);
|
||||
}
|
||||
@@ -0,0 +1,205 @@
|
||||
use crate::config::ProxyConfig;
|
||||
use rand::SeedableRng;
|
||||
use rand::rngs::StdRng;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::AsyncWrite;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::task::{RawWaker, RawWakerVTable, Waker};
|
||||
|
||||
unsafe fn wake_counter_clone(data: *const ()) -> RawWaker {
|
||||
let arc = Arc::<AtomicUsize>::from_raw(data.cast::<AtomicUsize>());
|
||||
let cloned = Arc::clone(&arc);
|
||||
let _ = Arc::into_raw(arc);
|
||||
RawWaker::new(
|
||||
Arc::into_raw(cloned).cast::<()>(),
|
||||
&WAKE_COUNTER_WAKER_VTABLE,
|
||||
)
|
||||
}
|
||||
|
||||
unsafe fn wake_counter_wake(data: *const ()) {
|
||||
let arc = Arc::<AtomicUsize>::from_raw(data.cast::<AtomicUsize>());
|
||||
arc.fetch_add(1, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
unsafe fn wake_counter_wake_by_ref(data: *const ()) {
|
||||
let arc = Arc::<AtomicUsize>::from_raw(data.cast::<AtomicUsize>());
|
||||
arc.fetch_add(1, Ordering::SeqCst);
|
||||
let _ = Arc::into_raw(arc);
|
||||
}
|
||||
|
||||
unsafe fn wake_counter_drop(data: *const ()) {
|
||||
let _ = Arc::<AtomicUsize>::from_raw(data.cast::<AtomicUsize>());
|
||||
}
|
||||
|
||||
static WAKE_COUNTER_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
|
||||
wake_counter_clone,
|
||||
wake_counter_wake,
|
||||
wake_counter_wake_by_ref,
|
||||
wake_counter_drop,
|
||||
);
|
||||
|
||||
fn wake_counter_waker(counter: Arc<AtomicUsize>) -> Waker {
|
||||
let raw = RawWaker::new(
|
||||
Arc::into_raw(counter).cast::<()>(),
|
||||
&WAKE_COUNTER_WAKER_VTABLE,
|
||||
);
|
||||
// SAFETY: `raw` points to a valid `Arc<AtomicUsize>` and uses a vtable
|
||||
// that preserves Arc reference-counting semantics.
|
||||
unsafe { Waker::from_raw(raw) }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pending_count_writer_write_pending_does_not_spurious_wake() {
|
||||
let counter = Arc::new(AtomicUsize::new(0));
|
||||
let waker = wake_counter_waker(Arc::clone(&counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let mut writer = PendingCountWriter::new(RecordingWriter::new(), 1, 0);
|
||||
let poll = Pin::new(&mut writer).poll_write(&mut cx, b"x");
|
||||
|
||||
assert!(matches!(poll, Poll::Pending));
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pending_count_writer_flush_pending_does_not_spurious_wake() {
|
||||
let counter = Arc::new(AtomicUsize::new(0));
|
||||
let waker = wake_counter_waker(Arc::clone(&counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let mut writer = PendingCountWriter::new(RecordingWriter::new(), 0, 1);
|
||||
let poll = Pin::new(&mut writer).poll_flush(&mut cx);
|
||||
|
||||
assert!(matches!(poll, Poll::Pending));
|
||||
assert_eq!(counter.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
}
|
||||
|
||||
// In-memory AsyncWrite that records both per-write and per-flush granularity.
|
||||
pub struct RecordingWriter {
|
||||
pub writes: Vec<Vec<u8>>,
|
||||
pub flushed: Vec<Vec<u8>>,
|
||||
current_record: Vec<u8>,
|
||||
}
|
||||
|
||||
impl RecordingWriter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
writes: Vec::new(),
|
||||
flushed: Vec::new(),
|
||||
current_record: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn total_bytes(&self) -> usize {
|
||||
self.writes.iter().map(|w| w.len()).sum()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RecordingWriter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for RecordingWriter {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let me = self.as_mut().get_mut();
|
||||
me.writes.push(buf.to_vec());
|
||||
me.current_record.extend_from_slice(buf);
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let me = self.as_mut().get_mut();
|
||||
let record = std::mem::take(&mut me.current_record);
|
||||
if !record.is_empty() {
|
||||
me.flushed.push(record);
|
||||
}
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
// Returns Poll::Pending for the first N write/flush calls, then delegates.
|
||||
pub struct PendingCountWriter<W> {
|
||||
pub inner: W,
|
||||
pub write_pending_remaining: usize,
|
||||
pub flush_pending_remaining: usize,
|
||||
}
|
||||
|
||||
impl<W> PendingCountWriter<W> {
|
||||
pub fn new(inner: W, write_pending: usize, flush_pending: usize) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
write_pending_remaining: write_pending,
|
||||
flush_pending_remaining: flush_pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin> AsyncWrite for PendingCountWriter<W> {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let me = self.as_mut().get_mut();
|
||||
if me.write_pending_remaining > 0 {
|
||||
me.write_pending_remaining -= 1;
|
||||
return Poll::Pending;
|
||||
}
|
||||
Pin::new(&mut me.inner).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
let me = self.as_mut().get_mut();
|
||||
if me.flush_pending_remaining > 0 {
|
||||
me.flush_pending_remaining -= 1;
|
||||
return Poll::Pending;
|
||||
}
|
||||
Pin::new(&mut me.inner).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn seeded_rng(seed: u64) -> StdRng {
|
||||
StdRng::seed_from_u64(seed)
|
||||
}
|
||||
|
||||
pub fn tls_only_config() -> Arc<ProxyConfig> {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.general.modes.tls = true;
|
||||
Arc::new(cfg)
|
||||
}
|
||||
|
||||
pub fn handshake_test_config(secret_hex: &str) -> ProxyConfig {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.access.users.clear();
|
||||
cfg.access
|
||||
.users
|
||||
.insert("test-user".to_string(), secret_hex.to_string());
|
||||
cfg.access.ignore_time_skew = true;
|
||||
cfg.censorship.mask = true;
|
||||
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||
cfg.censorship.mask_port = 0;
|
||||
cfg
|
||||
}
|
||||
+2
-2
@@ -159,8 +159,8 @@ MemoryDenyWriteExecute=true
|
||||
LockPersonality=true
|
||||
|
||||
# Allow binding to privileged ports and writing to specific paths
|
||||
AmbientCapabilities=CAP_NET_BIND_SERVICE
|
||||
CapabilityBoundingSet=CAP_NET_BIND_SERVICE
|
||||
AmbientCapabilities=CAP_NET_BIND_SERVICE CAP_NET_ADMIN
|
||||
CapabilityBoundingSet=CAP_NET_BIND_SERVICE CAP_NET_ADMIN
|
||||
ReadWritePaths=/etc/telemt /var/run /var/lib/telemt
|
||||
|
||||
[Install]
|
||||
|
||||
+205
-1
@@ -91,6 +91,17 @@ pub struct Stats {
|
||||
current_connections_direct: AtomicU64,
|
||||
current_connections_me: AtomicU64,
|
||||
handshake_timeouts: AtomicU64,
|
||||
accept_permit_timeout_total: AtomicU64,
|
||||
conntrack_control_enabled_gauge: AtomicBool,
|
||||
conntrack_control_available_gauge: AtomicBool,
|
||||
conntrack_pressure_active_gauge: AtomicBool,
|
||||
conntrack_event_queue_depth_gauge: AtomicU64,
|
||||
conntrack_rule_apply_ok_gauge: AtomicBool,
|
||||
conntrack_delete_attempt_total: AtomicU64,
|
||||
conntrack_delete_success_total: AtomicU64,
|
||||
conntrack_delete_not_found_total: AtomicU64,
|
||||
conntrack_delete_error_total: AtomicU64,
|
||||
conntrack_close_event_drop_total: AtomicU64,
|
||||
upstream_connect_attempt_total: AtomicU64,
|
||||
upstream_connect_success_total: AtomicU64,
|
||||
upstream_connect_fail_total: AtomicU64,
|
||||
@@ -200,6 +211,14 @@ pub struct Stats {
|
||||
me_d2c_flush_duration_us_bucket_1001_5000: AtomicU64,
|
||||
me_d2c_flush_duration_us_bucket_5001_20000: AtomicU64,
|
||||
me_d2c_flush_duration_us_bucket_gt_20000: AtomicU64,
|
||||
// Buffer pool gauges
|
||||
buffer_pool_pooled_gauge: AtomicU64,
|
||||
buffer_pool_allocated_gauge: AtomicU64,
|
||||
buffer_pool_in_use_gauge: AtomicU64,
|
||||
// C2ME enqueue observability
|
||||
me_c2me_send_full_total: AtomicU64,
|
||||
me_c2me_send_high_water_total: AtomicU64,
|
||||
me_c2me_send_timeout_total: AtomicU64,
|
||||
me_d2c_batch_timeout_armed_total: AtomicU64,
|
||||
me_d2c_batch_timeout_fired_total: AtomicU64,
|
||||
me_writer_pick_sorted_rr_success_try_total: AtomicU64,
|
||||
@@ -520,6 +539,74 @@ impl Stats {
|
||||
self.handshake_timeouts.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increment_accept_permit_timeout_total(&self) {
|
||||
if self.telemetry_core_enabled() {
|
||||
self.accept_permit_timeout_total
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_conntrack_control_enabled(&self, enabled: bool) {
|
||||
self.conntrack_control_enabled_gauge
|
||||
.store(enabled, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn set_conntrack_control_available(&self, available: bool) {
|
||||
self.conntrack_control_available_gauge
|
||||
.store(available, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn set_conntrack_pressure_active(&self, active: bool) {
|
||||
self.conntrack_pressure_active_gauge
|
||||
.store(active, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn set_conntrack_event_queue_depth(&self, depth: u64) {
|
||||
self.conntrack_event_queue_depth_gauge
|
||||
.store(depth, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn set_conntrack_rule_apply_ok(&self, ok: bool) {
|
||||
self.conntrack_rule_apply_ok_gauge
|
||||
.store(ok, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn increment_conntrack_delete_attempt_total(&self) {
|
||||
if self.telemetry_core_enabled() {
|
||||
self.conntrack_delete_attempt_total
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increment_conntrack_delete_success_total(&self) {
|
||||
if self.telemetry_core_enabled() {
|
||||
self.conntrack_delete_success_total
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increment_conntrack_delete_not_found_total(&self) {
|
||||
if self.telemetry_core_enabled() {
|
||||
self.conntrack_delete_not_found_total
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increment_conntrack_delete_error_total(&self) {
|
||||
if self.telemetry_core_enabled() {
|
||||
self.conntrack_delete_error_total
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increment_conntrack_close_event_drop_total(&self) {
|
||||
if self.telemetry_core_enabled() {
|
||||
self.conntrack_close_event_drop_total
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increment_upstream_connect_attempt_total(&self) {
|
||||
if self.telemetry_core_enabled() {
|
||||
self.upstream_connect_attempt_total
|
||||
@@ -1414,6 +1501,37 @@ impl Stats {
|
||||
.store(value, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_buffer_pool_gauges(&self, pooled: usize, allocated: usize, in_use: usize) {
|
||||
if self.telemetry_me_allows_normal() {
|
||||
self.buffer_pool_pooled_gauge
|
||||
.store(pooled as u64, Ordering::Relaxed);
|
||||
self.buffer_pool_allocated_gauge
|
||||
.store(allocated as u64, Ordering::Relaxed);
|
||||
self.buffer_pool_in_use_gauge
|
||||
.store(in_use as u64, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increment_me_c2me_send_full_total(&self) {
|
||||
if self.telemetry_me_allows_normal() {
|
||||
self.me_c2me_send_full_total.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increment_me_c2me_send_high_water_total(&self) {
|
||||
if self.telemetry_me_allows_normal() {
|
||||
self.me_c2me_send_high_water_total
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn increment_me_c2me_send_timeout_total(&self) {
|
||||
if self.telemetry_me_allows_normal() {
|
||||
self.me_c2me_send_timeout_total
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
pub fn increment_me_floor_cap_block_total(&self) {
|
||||
if self.telemetry_me_allows_normal() {
|
||||
self.me_floor_cap_block_total
|
||||
@@ -1438,6 +1556,9 @@ impl Stats {
|
||||
pub fn get_connects_bad(&self) -> u64 {
|
||||
self.connects_bad.load(Ordering::Relaxed)
|
||||
}
|
||||
pub fn get_accept_permit_timeout_total(&self) -> u64 {
|
||||
self.accept_permit_timeout_total.load(Ordering::Relaxed)
|
||||
}
|
||||
pub fn get_current_connections_direct(&self) -> u64 {
|
||||
self.current_connections_direct.load(Ordering::Relaxed)
|
||||
}
|
||||
@@ -1448,6 +1569,40 @@ impl Stats {
|
||||
self.get_current_connections_direct()
|
||||
.saturating_add(self.get_current_connections_me())
|
||||
}
|
||||
pub fn get_conntrack_control_enabled(&self) -> bool {
|
||||
self.conntrack_control_enabled_gauge.load(Ordering::Relaxed)
|
||||
}
|
||||
pub fn get_conntrack_control_available(&self) -> bool {
|
||||
self.conntrack_control_available_gauge
|
||||
.load(Ordering::Relaxed)
|
||||
}
|
||||
pub fn get_conntrack_pressure_active(&self) -> bool {
|
||||
self.conntrack_pressure_active_gauge.load(Ordering::Relaxed)
|
||||
}
|
||||
pub fn get_conntrack_event_queue_depth(&self) -> u64 {
|
||||
self.conntrack_event_queue_depth_gauge
|
||||
.load(Ordering::Relaxed)
|
||||
}
|
||||
pub fn get_conntrack_rule_apply_ok(&self) -> bool {
|
||||
self.conntrack_rule_apply_ok_gauge.load(Ordering::Relaxed)
|
||||
}
|
||||
pub fn get_conntrack_delete_attempt_total(&self) -> u64 {
|
||||
self.conntrack_delete_attempt_total.load(Ordering::Relaxed)
|
||||
}
|
||||
pub fn get_conntrack_delete_success_total(&self) -> u64 {
|
||||
self.conntrack_delete_success_total.load(Ordering::Relaxed)
|
||||
}
|
||||
pub fn get_conntrack_delete_not_found_total(&self) -> u64 {
|
||||
self.conntrack_delete_not_found_total
|
||||
.load(Ordering::Relaxed)
|
||||
}
|
||||
pub fn get_conntrack_delete_error_total(&self) -> u64 {
|
||||
self.conntrack_delete_error_total.load(Ordering::Relaxed)
|
||||
}
|
||||
pub fn get_conntrack_close_event_drop_total(&self) -> u64 {
|
||||
self.conntrack_close_event_drop_total
|
||||
.load(Ordering::Relaxed)
|
||||
}
|
||||
pub fn get_me_keepalive_sent(&self) -> u64 {
|
||||
self.me_keepalive_sent.load(Ordering::Relaxed)
|
||||
}
|
||||
@@ -1780,6 +1935,30 @@ impl Stats {
|
||||
self.me_d2c_flush_duration_us_bucket_gt_20000
|
||||
.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn get_buffer_pool_pooled_gauge(&self) -> u64 {
|
||||
self.buffer_pool_pooled_gauge.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn get_buffer_pool_allocated_gauge(&self) -> u64 {
|
||||
self.buffer_pool_allocated_gauge.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn get_buffer_pool_in_use_gauge(&self) -> u64 {
|
||||
self.buffer_pool_in_use_gauge.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn get_me_c2me_send_full_total(&self) -> u64 {
|
||||
self.me_c2me_send_full_total.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn get_me_c2me_send_high_water_total(&self) -> u64 {
|
||||
self.me_c2me_send_high_water_total.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn get_me_c2me_send_timeout_total(&self) -> u64 {
|
||||
self.me_c2me_send_timeout_total.load(Ordering::Relaxed)
|
||||
}
|
||||
pub fn get_me_d2c_batch_timeout_armed_total(&self) -> u64 {
|
||||
self.me_d2c_batch_timeout_armed_total
|
||||
.load(Ordering::Relaxed)
|
||||
@@ -2171,6 +2350,8 @@ impl ReplayShard {
|
||||
|
||||
fn cleanup(&mut self, now: Instant, window: Duration) {
|
||||
if window.is_zero() {
|
||||
self.cache.clear();
|
||||
self.queue.clear();
|
||||
return;
|
||||
}
|
||||
let cutoff = now.checked_sub(window).unwrap_or(now);
|
||||
@@ -2192,13 +2373,22 @@ impl ReplayShard {
|
||||
}
|
||||
|
||||
fn check(&mut self, key: &[u8], now: Instant, window: Duration) -> bool {
|
||||
if window.is_zero() {
|
||||
return false;
|
||||
}
|
||||
self.cleanup(now, window);
|
||||
// key is &[u8], resolves Q=[u8] via Box<[u8]>: Borrow<[u8]>
|
||||
self.cache.get(key).is_some()
|
||||
}
|
||||
|
||||
fn add(&mut self, key: &[u8], now: Instant, window: Duration) {
|
||||
if window.is_zero() {
|
||||
return;
|
||||
}
|
||||
self.cleanup(now, window);
|
||||
if self.cache.peek(key).is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
let seq = self.next_seq();
|
||||
let boxed_key: Box<[u8]> = key.into();
|
||||
@@ -2341,7 +2531,7 @@ impl ReplayChecker {
|
||||
let interval = if self.window.as_secs() > 60 {
|
||||
Duration::from_secs(30)
|
||||
} else {
|
||||
Duration::from_secs(self.window.as_secs().max(1) / 2)
|
||||
Duration::from_secs((self.window.as_secs().max(1) / 2).max(1))
|
||||
};
|
||||
|
||||
loop {
|
||||
@@ -2553,6 +2743,20 @@ mod tests {
|
||||
assert!(!checker.check_handshake(b"expire"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replay_checker_zero_window_does_not_retain_entries() {
|
||||
let checker = ReplayChecker::new(100, Duration::ZERO);
|
||||
|
||||
for _ in 0..1_000 {
|
||||
assert!(!checker.check_handshake(b"no-retain"));
|
||||
checker.add_handshake(b"no-retain");
|
||||
}
|
||||
|
||||
let stats = checker.stats();
|
||||
assert_eq!(stats.total_entries, 0);
|
||||
assert_eq!(stats.total_queue_len, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replay_checker_stats() {
|
||||
let checker = ReplayChecker::new(100, Duration::from_secs(60));
|
||||
|
||||
@@ -35,6 +35,10 @@ pub struct BufferPool {
|
||||
misses: AtomicUsize,
|
||||
/// Number of successful reuses
|
||||
hits: AtomicUsize,
|
||||
/// Number of non-standard buffers replaced with a fresh default-sized buffer
|
||||
replaced_nonstandard: AtomicUsize,
|
||||
/// Number of buffers dropped because the pool queue was full
|
||||
dropped_pool_full: AtomicUsize,
|
||||
}
|
||||
|
||||
impl BufferPool {
|
||||
@@ -52,6 +56,8 @@ impl BufferPool {
|
||||
allocated: AtomicUsize::new(0),
|
||||
misses: AtomicUsize::new(0),
|
||||
hits: AtomicUsize::new(0),
|
||||
replaced_nonstandard: AtomicUsize::new(0),
|
||||
dropped_pool_full: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,17 +97,36 @@ impl BufferPool {
|
||||
|
||||
/// Return a buffer to the pool
|
||||
fn return_buffer(&self, mut buffer: BytesMut) {
|
||||
// Clear the buffer but keep capacity
|
||||
buffer.clear();
|
||||
const MAX_RETAINED_BUFFER_FACTOR: usize = 2;
|
||||
|
||||
// Only return if we haven't exceeded max and buffer is right size
|
||||
if buffer.capacity() >= self.buffer_size {
|
||||
// Try to push to pool, if full just drop
|
||||
let _ = self.buffers.push(buffer);
|
||||
// Clear the buffer but keep capacity.
|
||||
buffer.clear();
|
||||
let max_retained_capacity = self
|
||||
.buffer_size
|
||||
.saturating_mul(MAX_RETAINED_BUFFER_FACTOR)
|
||||
.max(self.buffer_size);
|
||||
|
||||
// Keep only near-default capacities in the pool. Oversized buffers keep
|
||||
// RSS elevated for hours under churn; replace them with default-sized
|
||||
// buffers before re-pooling.
|
||||
if buffer.capacity() < self.buffer_size || buffer.capacity() > max_retained_capacity {
|
||||
self.replaced_nonstandard.fetch_add(1, Ordering::Relaxed);
|
||||
buffer = BytesMut::with_capacity(self.buffer_size);
|
||||
}
|
||||
// If buffer was dropped (pool full), decrement allocated
|
||||
// Actually we don't decrement here because the buffer might have been
|
||||
// grown beyond our size - we just let it go
|
||||
|
||||
// Try to return into the queue; if full, drop and update accounting.
|
||||
if self.buffers.push(buffer).is_err() {
|
||||
self.dropped_pool_full.fetch_add(1, Ordering::Relaxed);
|
||||
self.decrement_allocated();
|
||||
}
|
||||
}
|
||||
|
||||
fn decrement_allocated(&self) {
|
||||
let _ = self
|
||||
.allocated
|
||||
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
|
||||
Some(current.saturating_sub(1))
|
||||
});
|
||||
}
|
||||
|
||||
/// Get pool statistics
|
||||
@@ -113,6 +138,8 @@ impl BufferPool {
|
||||
buffer_size: self.buffer_size,
|
||||
hits: self.hits.load(Ordering::Relaxed),
|
||||
misses: self.misses.load(Ordering::Relaxed),
|
||||
replaced_nonstandard: self.replaced_nonstandard.load(Ordering::Relaxed),
|
||||
dropped_pool_full: self.dropped_pool_full.load(Ordering::Relaxed),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,6 +148,41 @@ impl BufferPool {
|
||||
self.buffer_size
|
||||
}
|
||||
|
||||
/// Maximum number of buffers the pool will retain.
|
||||
pub fn max_buffers(&self) -> usize {
|
||||
self.max_buffers
|
||||
}
|
||||
|
||||
/// Current number of pooled buffers.
|
||||
pub fn pooled(&self) -> usize {
|
||||
self.buffers.len()
|
||||
}
|
||||
|
||||
/// Total buffers allocated (pooled + checked out).
|
||||
pub fn allocated(&self) -> usize {
|
||||
self.allocated.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Best-effort number of buffers currently checked out.
|
||||
pub fn in_use(&self) -> usize {
|
||||
self.allocated().saturating_sub(self.pooled())
|
||||
}
|
||||
|
||||
/// Trim pooled buffers down to a target count.
|
||||
pub fn trim_to(&self, target_pooled: usize) {
|
||||
let target = target_pooled.min(self.max_buffers);
|
||||
loop {
|
||||
if self.buffers.len() <= target {
|
||||
break;
|
||||
}
|
||||
if self.buffers.pop().is_some() {
|
||||
self.decrement_allocated();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Preallocate buffers to fill the pool
|
||||
pub fn preallocate(&self, count: usize) {
|
||||
let to_alloc = count.min(self.max_buffers);
|
||||
@@ -160,6 +222,10 @@ pub struct PoolStats {
|
||||
pub hits: usize,
|
||||
/// Number of cache misses (new allocation)
|
||||
pub misses: usize,
|
||||
/// Number of non-standard buffers replaced during return
|
||||
pub replaced_nonstandard: usize,
|
||||
/// Number of buffers dropped because the pool queue was full
|
||||
pub dropped_pool_full: usize,
|
||||
}
|
||||
|
||||
impl PoolStats {
|
||||
@@ -185,6 +251,7 @@ pub struct PooledBuffer {
|
||||
impl PooledBuffer {
|
||||
/// Take the inner buffer, preventing return to pool
|
||||
pub fn take(mut self) -> BytesMut {
|
||||
self.pool.decrement_allocated();
|
||||
self.buffer.take().unwrap()
|
||||
}
|
||||
|
||||
@@ -364,6 +431,25 @@ mod tests {
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.pooled, 0);
|
||||
assert_eq!(stats.allocated, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pool_replaces_oversized_buffers() {
|
||||
let pool = Arc::new(BufferPool::with_config(1024, 10));
|
||||
|
||||
{
|
||||
let mut buf = pool.get();
|
||||
buf.reserve(8192);
|
||||
assert!(buf.capacity() > 2048);
|
||||
}
|
||||
|
||||
let stats = pool.stats();
|
||||
assert_eq!(stats.replaced_nonstandard, 1);
|
||||
assert_eq!(stats.pooled, 1);
|
||||
|
||||
let buf = pool.get();
|
||||
assert!(buf.capacity() <= 2048);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#![allow(clippy::too_many_arguments)]
|
||||
|
||||
use crc32fast::Hasher;
|
||||
use crate::crypto::{SecureRandom, sha256_hmac};
|
||||
use crate::protocol::constants::{
|
||||
MAX_TLS_CIPHERTEXT_SIZE, TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER,
|
||||
@@ -98,6 +99,31 @@ fn build_compact_cert_info_payload(cert_info: &ParsedCertificateInfo) -> Option<
|
||||
Some(payload)
|
||||
}
|
||||
|
||||
fn hash_compact_cert_info_payload(cert_payload: Vec<u8>) -> Option<Vec<u8>> {
|
||||
if cert_payload.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut hashed = Vec::with_capacity(cert_payload.len());
|
||||
let mut seed_hasher = Hasher::new();
|
||||
seed_hasher.update(&cert_payload);
|
||||
let mut state = seed_hasher.finalize();
|
||||
|
||||
while hashed.len() < cert_payload.len() {
|
||||
let mut hasher = Hasher::new();
|
||||
hasher.update(&state.to_le_bytes());
|
||||
hasher.update(&cert_payload);
|
||||
state = hasher.finalize();
|
||||
|
||||
let block = state.to_le_bytes();
|
||||
let remaining = cert_payload.len() - hashed.len();
|
||||
let copy_len = remaining.min(block.len());
|
||||
hashed.extend_from_slice(&block[..copy_len]);
|
||||
}
|
||||
|
||||
Some(hashed)
|
||||
}
|
||||
|
||||
/// Build a ServerHello + CCS + ApplicationData sequence using cached TLS metadata.
|
||||
pub fn build_emulated_server_hello(
|
||||
secret: &[u8],
|
||||
@@ -190,7 +216,8 @@ pub fn build_emulated_server_hello(
|
||||
let compact_payload = cached
|
||||
.cert_info
|
||||
.as_ref()
|
||||
.and_then(build_compact_cert_info_payload);
|
||||
.and_then(build_compact_cert_info_payload)
|
||||
.and_then(hash_compact_cert_info_payload);
|
||||
let selected_payload: Option<&[u8]> = if use_full_cert_payload {
|
||||
cached
|
||||
.cert_payload
|
||||
@@ -221,7 +248,6 @@ pub fn build_emulated_server_hello(
|
||||
marker.extend_from_slice(proto);
|
||||
marker
|
||||
});
|
||||
let mut payload_offset = 0usize;
|
||||
for (idx, size) in sizes.into_iter().enumerate() {
|
||||
let mut rec = Vec::with_capacity(5 + size);
|
||||
rec.push(TLS_RECORD_APPLICATION);
|
||||
@@ -231,11 +257,10 @@ pub fn build_emulated_server_hello(
|
||||
if let Some(payload) = selected_payload {
|
||||
if size > 17 {
|
||||
let body_len = size - 17;
|
||||
let remaining = payload.len().saturating_sub(payload_offset);
|
||||
let remaining = payload.len();
|
||||
let copy_len = remaining.min(body_len);
|
||||
if copy_len > 0 {
|
||||
rec.extend_from_slice(&payload[payload_offset..payload_offset + copy_len]);
|
||||
payload_offset += copy_len;
|
||||
rec.extend_from_slice(&payload[..copy_len]);
|
||||
}
|
||||
if body_len > copy_len {
|
||||
rec.extend_from_slice(&rng.bytes(body_len - copy_len));
|
||||
@@ -317,7 +342,9 @@ mod tests {
|
||||
CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsCertPayload, TlsProfileSource,
|
||||
};
|
||||
|
||||
use super::build_emulated_server_hello;
|
||||
use super::{
|
||||
build_compact_cert_info_payload, build_emulated_server_hello, hash_compact_cert_info_payload,
|
||||
};
|
||||
use crate::crypto::SecureRandom;
|
||||
use crate::protocol::constants::{
|
||||
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE,
|
||||
@@ -432,7 +459,21 @@ mod tests {
|
||||
);
|
||||
|
||||
let payload = first_app_data_payload(&response);
|
||||
assert!(payload.starts_with(b"CN=example.com"));
|
||||
let expected_hashed_payload = build_compact_cert_info_payload(
|
||||
cached
|
||||
.cert_info
|
||||
.as_ref()
|
||||
.expect("test fixture must provide certificate info"),
|
||||
)
|
||||
.and_then(hash_compact_cert_info_payload)
|
||||
.expect("compact certificate info payload must be present for this test");
|
||||
let copied_prefix_len = expected_hashed_payload
|
||||
.len()
|
||||
.min(payload.len().saturating_sub(17));
|
||||
assert_eq!(
|
||||
&payload[..copied_prefix_len],
|
||||
&expected_hashed_payload[..copied_prefix_len]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -55,6 +55,20 @@ struct RoutingTable {
|
||||
map: DashMap<u64, mpsc::Sender<MeResponse>>,
|
||||
}
|
||||
|
||||
struct WriterTable {
|
||||
map: DashMap<u64, mpsc::Sender<WriterCommand>>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct HotConnBinding {
|
||||
writer_id: u64,
|
||||
meta: ConnMeta,
|
||||
}
|
||||
|
||||
struct HotBindingTable {
|
||||
map: DashMap<u64, HotConnBinding>,
|
||||
}
|
||||
|
||||
struct BindingState {
|
||||
inner: Mutex<BindingInner>,
|
||||
}
|
||||
@@ -83,6 +97,8 @@ impl BindingInner {
|
||||
|
||||
pub struct ConnRegistry {
|
||||
routing: RoutingTable,
|
||||
writers: WriterTable,
|
||||
hot_binding: HotBindingTable,
|
||||
binding: BindingState,
|
||||
next_id: AtomicU64,
|
||||
route_channel_capacity: usize,
|
||||
@@ -105,6 +121,12 @@ impl ConnRegistry {
|
||||
routing: RoutingTable {
|
||||
map: DashMap::new(),
|
||||
},
|
||||
writers: WriterTable {
|
||||
map: DashMap::new(),
|
||||
},
|
||||
hot_binding: HotBindingTable {
|
||||
map: DashMap::new(),
|
||||
},
|
||||
binding: BindingState {
|
||||
inner: Mutex::new(BindingInner::new()),
|
||||
},
|
||||
@@ -149,16 +171,18 @@ impl ConnRegistry {
|
||||
|
||||
pub async fn register_writer(&self, writer_id: u64, tx: mpsc::Sender<WriterCommand>) {
|
||||
let mut binding = self.binding.inner.lock().await;
|
||||
binding.writers.insert(writer_id, tx);
|
||||
binding.writers.insert(writer_id, tx.clone());
|
||||
binding
|
||||
.conns_for_writer
|
||||
.entry(writer_id)
|
||||
.or_insert_with(HashSet::new);
|
||||
self.writers.map.insert(writer_id, tx);
|
||||
}
|
||||
|
||||
/// Unregister connection, returning associated writer_id if any.
|
||||
pub async fn unregister(&self, id: u64) -> Option<u64> {
|
||||
self.routing.map.remove(&id);
|
||||
self.hot_binding.map.remove(&id);
|
||||
let mut binding = self.binding.inner.lock().await;
|
||||
binding.meta.remove(&id);
|
||||
if let Some(writer_id) = binding.writer_for_conn.remove(&id) {
|
||||
@@ -325,13 +349,20 @@ impl ConnRegistry {
|
||||
}
|
||||
|
||||
binding.meta.insert(conn_id, meta.clone());
|
||||
binding.last_meta_for_writer.insert(writer_id, meta);
|
||||
binding.last_meta_for_writer.insert(writer_id, meta.clone());
|
||||
binding.writer_idle_since_epoch_secs.remove(&writer_id);
|
||||
binding
|
||||
.conns_for_writer
|
||||
.entry(writer_id)
|
||||
.or_insert_with(HashSet::new)
|
||||
.insert(conn_id);
|
||||
self.hot_binding.map.insert(
|
||||
conn_id,
|
||||
HotConnBinding {
|
||||
writer_id,
|
||||
meta,
|
||||
},
|
||||
);
|
||||
true
|
||||
}
|
||||
|
||||
@@ -392,39 +423,12 @@ impl ConnRegistry {
|
||||
}
|
||||
|
||||
pub async fn get_writer(&self, conn_id: u64) -> Option<ConnWriter> {
|
||||
let mut binding = self.binding.inner.lock().await;
|
||||
// ROUTING IS THE SOURCE OF TRUTH:
|
||||
// stale bindings are ignored and lazily cleaned when routing no longer
|
||||
// contains the connection.
|
||||
if !self.routing.map.contains_key(&conn_id) {
|
||||
binding.meta.remove(&conn_id);
|
||||
if let Some(stale_writer_id) = binding.writer_for_conn.remove(&conn_id)
|
||||
&& let Some(conns) = binding.conns_for_writer.get_mut(&stale_writer_id)
|
||||
{
|
||||
conns.remove(&conn_id);
|
||||
if conns.is_empty() {
|
||||
binding
|
||||
.writer_idle_since_epoch_secs
|
||||
.insert(stale_writer_id, Self::now_epoch_secs());
|
||||
}
|
||||
}
|
||||
return None;
|
||||
}
|
||||
|
||||
let writer_id = binding.writer_for_conn.get(&conn_id).copied()?;
|
||||
let Some(writer) = binding.writers.get(&writer_id).cloned() else {
|
||||
binding.writer_for_conn.remove(&conn_id);
|
||||
binding.meta.remove(&conn_id);
|
||||
if let Some(conns) = binding.conns_for_writer.get_mut(&writer_id) {
|
||||
conns.remove(&conn_id);
|
||||
if conns.is_empty() {
|
||||
binding
|
||||
.writer_idle_since_epoch_secs
|
||||
.insert(writer_id, Self::now_epoch_secs());
|
||||
}
|
||||
}
|
||||
return None;
|
||||
};
|
||||
let writer_id = self.hot_binding.map.get(&conn_id).map(|entry| entry.writer_id)?;
|
||||
let writer = self.writers.map.get(&writer_id).map(|entry| entry.value().clone())?;
|
||||
Some(ConnWriter {
|
||||
writer_id,
|
||||
tx: writer,
|
||||
@@ -439,6 +443,7 @@ impl ConnRegistry {
|
||||
pub async fn writer_lost(&self, writer_id: u64) -> Vec<BoundConn> {
|
||||
let mut binding = self.binding.inner.lock().await;
|
||||
binding.writers.remove(&writer_id);
|
||||
self.writers.map.remove(&writer_id);
|
||||
binding.last_meta_for_writer.remove(&writer_id);
|
||||
binding.writer_idle_since_epoch_secs.remove(&writer_id);
|
||||
let conns = binding
|
||||
@@ -454,6 +459,15 @@ impl ConnRegistry {
|
||||
continue;
|
||||
}
|
||||
binding.writer_for_conn.remove(&conn_id);
|
||||
let remove_hot = self
|
||||
.hot_binding
|
||||
.map
|
||||
.get(&conn_id)
|
||||
.map(|hot| hot.writer_id == writer_id)
|
||||
.unwrap_or(false);
|
||||
if remove_hot {
|
||||
self.hot_binding.map.remove(&conn_id);
|
||||
}
|
||||
if let Some(m) = binding.meta.get(&conn_id) {
|
||||
out.push(BoundConn {
|
||||
conn_id,
|
||||
@@ -466,8 +480,10 @@ impl ConnRegistry {
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn get_meta(&self, conn_id: u64) -> Option<ConnMeta> {
|
||||
let binding = self.binding.inner.lock().await;
|
||||
binding.meta.get(&conn_id).cloned()
|
||||
self.hot_binding
|
||||
.map
|
||||
.get(&conn_id)
|
||||
.map(|entry| entry.meta.clone())
|
||||
}
|
||||
|
||||
pub async fn is_writer_empty(&self, writer_id: u64) -> bool {
|
||||
@@ -491,6 +507,7 @@ impl ConnRegistry {
|
||||
}
|
||||
|
||||
binding.writers.remove(&writer_id);
|
||||
self.writers.map.remove(&writer_id);
|
||||
binding.last_meta_for_writer.remove(&writer_id);
|
||||
binding.writer_idle_since_epoch_secs.remove(&writer_id);
|
||||
binding.conns_for_writer.remove(&writer_id);
|
||||
|
||||
Reference in New Issue
Block a user