mirror of
https://github.com/telemt/telemt.git
synced 2026-05-23 20:21:44 +03:00
Compare commits
19 Commits
3.3.27
...
4f55d08c51
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f55d08c51 | ||
|
|
93caab1aec | ||
|
|
0c6bb3a641 | ||
|
|
b2e15327fe | ||
|
|
2e8be87ccf | ||
|
|
d78360982c | ||
|
|
822bcbf7a5 | ||
|
|
b25ec97a43 | ||
|
|
8821e38013 | ||
|
|
a1caebbe6f | ||
|
|
e0d821c6b6 | ||
|
|
205fc88718 | ||
|
|
e4a50f9286 | ||
|
|
213ce4555a | ||
|
|
5a16e68487 | ||
|
|
6ffbc51fb0 | ||
|
|
dcab19a64f | ||
|
|
f10ca192fa | ||
|
|
2bd9036908 |
15
.cargo/deny.toml
Normal file
15
.cargo/deny.toml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
[bans]
|
||||||
|
multiple-versions = "deny"
|
||||||
|
wildcards = "allow"
|
||||||
|
highlight = "all"
|
||||||
|
|
||||||
|
# Explicitly flag the weak cryptography so the agent is forced to justify its existence
|
||||||
|
[[bans.skip]]
|
||||||
|
name = "md-5"
|
||||||
|
version = "*"
|
||||||
|
reason = "MUST VERIFY: Only allowed for legacy checksums, never for security."
|
||||||
|
|
||||||
|
[[bans.skip]]
|
||||||
|
name = "sha1"
|
||||||
|
version = "*"
|
||||||
|
reason = "MUST VERIFY: Only allowed for backwards compatibility."
|
||||||
16
AGENTS.md
16
AGENTS.md
@@ -5,6 +5,22 @@ Your responses are precise, minimal, and architecturally sound. You are working
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
### Context: The Telemt Project
|
||||||
|
|
||||||
|
You are working on **Telemt**, a high-performance, production-grade Telegram MTProxy implementation written in Rust. It is explicitly designed to operate in highly hostile network environments and evade advanced network censorship.
|
||||||
|
|
||||||
|
**Adversarial Threat Model:**
|
||||||
|
The proxy operates under constant surveillance by DPI (Deep Packet Inspection) systems and active scanners (state firewalls, mobile operator fraud controls). These entities actively probe IPs, analyze protocol handshakes, and look for known proxy signatures to block or throttle traffic.
|
||||||
|
|
||||||
|
**Core Architectural Pillars:**
|
||||||
|
1. **TLS-Fronting (TLS-F) & TCP-Splitting (TCP-S):** To the outside world, Telemt looks like a standard TLS server. If a client presents a valid MTProxy key, the connection is handled internally. If a censor's scanner, web browser, or unauthorized crawler connects, Telemt seamlessly splices the TCP connection (L4) to a real, legitimate HTTPS fallback server (e.g., Nginx) without modifying the `ClientHello` or terminating the TLS handshake.
|
||||||
|
2. **Middle-End (ME) Orchestration:** A highly concurrent, generation-based pool managing upstream connections to Telegram Datacenters (DCs). It utilizes an **Adaptive Floor** (dynamically scaling writer connections based on traffic), **Hardswaps** (zero-downtime pool reconfiguration), and **STUN/NAT** reflection mechanisms.
|
||||||
|
3. **Strict KDF Routing:** Cryptographic Key Derivation Functions (KDF) in this protocol strictly rely on the exact pairing of Source IP/Port and Destination IP/Port. Deviations or missing port logic will silently break the MTProto handshake.
|
||||||
|
4. **Data Plane vs. Control Plane Isolation:** The Data Plane (readers, writers, payload relay, TCP splicing) must remain strictly non-blocking, zero-allocation in hot paths, and highly resilient to network backpressure. The Control Plane (API, metrics, pool generation swaps, config reloads) orchestrates the state asynchronously without stalling the Data Plane.
|
||||||
|
|
||||||
|
Any modification you make must preserve Telemt's invisibility to censors, its strict memory-safety invariants, and its hot-path throughput.
|
||||||
|
|
||||||
|
|
||||||
### 0. Priority Resolution — Scope Control
|
### 0. Priority Resolution — Scope Control
|
||||||
|
|
||||||
This section resolves conflicts between code quality enforcement and scope limitation.
|
This section resolves conflicts between code quality enforcement and scope limitation.
|
||||||
|
|||||||
10
Cargo.lock
generated
10
Cargo.lock
generated
@@ -2025,6 +2025,12 @@ version = "1.2.1"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
|
checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "static_assertions"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "subtle"
|
name = "subtle"
|
||||||
version = "2.6.1"
|
version = "2.6.1"
|
||||||
@@ -2087,7 +2093,7 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "telemt"
|
name = "telemt"
|
||||||
version = "3.3.15"
|
version = "3.3.19"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aes",
|
"aes",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
@@ -2127,6 +2133,8 @@ dependencies = [
|
|||||||
"sha1",
|
"sha1",
|
||||||
"sha2",
|
"sha2",
|
||||||
"socket2 0.5.10",
|
"socket2 0.5.10",
|
||||||
|
"static_assertions",
|
||||||
|
"subtle",
|
||||||
"thiserror 2.0.18",
|
"thiserror 2.0.18",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-rustls",
|
"tokio-rustls",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "telemt"
|
name = "telemt"
|
||||||
version = "3.3.19"
|
version = "3.3.20"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
@@ -22,6 +22,8 @@ hmac = "0.12"
|
|||||||
crc32fast = "1.4"
|
crc32fast = "1.4"
|
||||||
crc32c = "0.6"
|
crc32c = "0.6"
|
||||||
zeroize = { version = "1.8", features = ["derive"] }
|
zeroize = { version = "1.8", features = ["derive"] }
|
||||||
|
subtle = "2.6"
|
||||||
|
static_assertions = "1.1"
|
||||||
|
|
||||||
# Network
|
# Network
|
||||||
socket2 = { version = "0.5", features = ["all"] }
|
socket2 = { version = "0.5", features = ["all"] }
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ use crate::config::{
|
|||||||
};
|
};
|
||||||
use super::load::{LoadedConfig, ProxyConfig};
|
use super::load::{LoadedConfig, ProxyConfig};
|
||||||
|
|
||||||
const HOT_RELOAD_STABLE_SNAPSHOTS: u8 = 2;
|
|
||||||
const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50);
|
const HOT_RELOAD_DEBOUNCE: Duration = Duration::from_millis(50);
|
||||||
|
|
||||||
// ── Hot fields ────────────────────────────────────────────────────────────────
|
// ── Hot fields ────────────────────────────────────────────────────────────────
|
||||||
@@ -329,41 +328,19 @@ impl WatchManifest {
|
|||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
struct ReloadState {
|
struct ReloadState {
|
||||||
applied_snapshot_hash: Option<u64>,
|
applied_snapshot_hash: Option<u64>,
|
||||||
candidate_snapshot_hash: Option<u64>,
|
|
||||||
candidate_hits: u8,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ReloadState {
|
impl ReloadState {
|
||||||
fn new(applied_snapshot_hash: Option<u64>) -> Self {
|
fn new(applied_snapshot_hash: Option<u64>) -> Self {
|
||||||
Self {
|
Self { applied_snapshot_hash }
|
||||||
applied_snapshot_hash,
|
|
||||||
candidate_snapshot_hash: None,
|
|
||||||
candidate_hits: 0,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_applied(&self, hash: u64) -> bool {
|
fn is_applied(&self, hash: u64) -> bool {
|
||||||
self.applied_snapshot_hash == Some(hash)
|
self.applied_snapshot_hash == Some(hash)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn observe_candidate(&mut self, hash: u64) -> u8 {
|
|
||||||
if self.candidate_snapshot_hash == Some(hash) {
|
|
||||||
self.candidate_hits = self.candidate_hits.saturating_add(1);
|
|
||||||
} else {
|
|
||||||
self.candidate_snapshot_hash = Some(hash);
|
|
||||||
self.candidate_hits = 1;
|
|
||||||
}
|
|
||||||
self.candidate_hits
|
|
||||||
}
|
|
||||||
|
|
||||||
fn reset_candidate(&mut self) {
|
|
||||||
self.candidate_snapshot_hash = None;
|
|
||||||
self.candidate_hits = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn mark_applied(&mut self, hash: u64) {
|
fn mark_applied(&mut self, hash: u64) {
|
||||||
self.applied_snapshot_hash = Some(hash);
|
self.applied_snapshot_hash = Some(hash);
|
||||||
self.reset_candidate();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1138,7 +1115,6 @@ fn reload_config(
|
|||||||
let loaded = match ProxyConfig::load_with_metadata(config_path) {
|
let loaded = match ProxyConfig::load_with_metadata(config_path) {
|
||||||
Ok(loaded) => loaded,
|
Ok(loaded) => loaded,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
reload_state.reset_candidate();
|
|
||||||
error!("config reload: failed to parse {:?}: {}", config_path, e);
|
error!("config reload: failed to parse {:?}: {}", config_path, e);
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
@@ -1151,7 +1127,6 @@ fn reload_config(
|
|||||||
let next_manifest = WatchManifest::from_source_files(&source_files);
|
let next_manifest = WatchManifest::from_source_files(&source_files);
|
||||||
|
|
||||||
if let Err(e) = new_cfg.validate() {
|
if let Err(e) = new_cfg.validate() {
|
||||||
reload_state.reset_candidate();
|
|
||||||
error!("config reload: validation failed: {}; keeping old config", e);
|
error!("config reload: validation failed: {}; keeping old config", e);
|
||||||
return Some(next_manifest);
|
return Some(next_manifest);
|
||||||
}
|
}
|
||||||
@@ -1160,17 +1135,6 @@ fn reload_config(
|
|||||||
return Some(next_manifest);
|
return Some(next_manifest);
|
||||||
}
|
}
|
||||||
|
|
||||||
let candidate_hits = reload_state.observe_candidate(rendered_hash);
|
|
||||||
if candidate_hits < HOT_RELOAD_STABLE_SNAPSHOTS {
|
|
||||||
info!(
|
|
||||||
snapshot_hash = rendered_hash,
|
|
||||||
candidate_hits,
|
|
||||||
required_hits = HOT_RELOAD_STABLE_SNAPSHOTS,
|
|
||||||
"config reload: candidate snapshot observed but not stable yet"
|
|
||||||
);
|
|
||||||
return Some(next_manifest);
|
|
||||||
}
|
|
||||||
|
|
||||||
let old_cfg = config_tx.borrow().clone();
|
let old_cfg = config_tx.borrow().clone();
|
||||||
let applied_cfg = overlay_hot_fields(&old_cfg, &new_cfg);
|
let applied_cfg = overlay_hot_fields(&old_cfg, &new_cfg);
|
||||||
let old_hot = HotFields::from_config(&old_cfg);
|
let old_hot = HotFields::from_config(&old_cfg);
|
||||||
@@ -1190,7 +1154,6 @@ fn reload_config(
|
|||||||
if old_hot.dns_overrides != applied_hot.dns_overrides
|
if old_hot.dns_overrides != applied_hot.dns_overrides
|
||||||
&& let Err(e) = crate::network::dns_overrides::install_entries(&applied_hot.dns_overrides)
|
&& let Err(e) = crate::network::dns_overrides::install_entries(&applied_hot.dns_overrides)
|
||||||
{
|
{
|
||||||
reload_state.reset_candidate();
|
|
||||||
error!(
|
error!(
|
||||||
"config reload: invalid network.dns_overrides: {}; keeping old config",
|
"config reload: invalid network.dns_overrides: {}; keeping old config",
|
||||||
e
|
e
|
||||||
@@ -1334,14 +1297,28 @@ pub fn spawn_config_watcher(
|
|||||||
tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await;
|
tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await;
|
||||||
while notify_rx.try_recv().is_ok() {}
|
while notify_rx.try_recv().is_ok() {}
|
||||||
|
|
||||||
if let Some(next_manifest) = reload_config(
|
let mut next_manifest = reload_config(
|
||||||
&config_path,
|
&config_path,
|
||||||
&config_tx,
|
&config_tx,
|
||||||
&log_tx,
|
&log_tx,
|
||||||
detected_ip_v4,
|
detected_ip_v4,
|
||||||
detected_ip_v6,
|
detected_ip_v6,
|
||||||
&mut reload_state,
|
&mut reload_state,
|
||||||
) {
|
);
|
||||||
|
if next_manifest.is_none() {
|
||||||
|
tokio::time::sleep(HOT_RELOAD_DEBOUNCE).await;
|
||||||
|
while notify_rx.try_recv().is_ok() {}
|
||||||
|
next_manifest = reload_config(
|
||||||
|
&config_path,
|
||||||
|
&config_tx,
|
||||||
|
&log_tx,
|
||||||
|
detected_ip_v4,
|
||||||
|
detected_ip_v6,
|
||||||
|
&mut reload_state,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(next_manifest) = next_manifest {
|
||||||
apply_watch_manifest(
|
apply_watch_manifest(
|
||||||
inotify_watcher.as_mut(),
|
inotify_watcher.as_mut(),
|
||||||
poll_watcher.as_mut(),
|
poll_watcher.as_mut(),
|
||||||
@@ -1466,7 +1443,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn reload_requires_stable_snapshot_before_hot_apply() {
|
fn reload_applies_hot_change_on_first_observed_snapshot() {
|
||||||
let initial_tag = "11111111111111111111111111111111";
|
let initial_tag = "11111111111111111111111111111111";
|
||||||
let final_tag = "22222222222222222222222222222222";
|
let final_tag = "22222222222222222222222222222222";
|
||||||
let path = temp_config_path("telemt_hot_reload_stable");
|
let path = temp_config_path("telemt_hot_reload_stable");
|
||||||
@@ -1478,20 +1455,7 @@ mod tests {
|
|||||||
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
|
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
|
||||||
let mut reload_state = ReloadState::new(Some(initial_hash));
|
let mut reload_state = ReloadState::new(Some(initial_hash));
|
||||||
|
|
||||||
write_reload_config(&path, None, None);
|
|
||||||
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
|
||||||
assert_eq!(
|
|
||||||
config_tx.borrow().general.ad_tag.as_deref(),
|
|
||||||
Some(initial_tag)
|
|
||||||
);
|
|
||||||
|
|
||||||
write_reload_config(&path, Some(final_tag), None);
|
write_reload_config(&path, Some(final_tag), None);
|
||||||
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
|
||||||
assert_eq!(
|
|
||||||
config_tx.borrow().general.ad_tag.as_deref(),
|
|
||||||
Some(initial_tag)
|
|
||||||
);
|
|
||||||
|
|
||||||
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||||
assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag));
|
assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag));
|
||||||
|
|
||||||
@@ -1513,7 +1477,6 @@ mod tests {
|
|||||||
|
|
||||||
write_reload_config(&path, Some(final_tag), Some(initial_cfg.server.port + 1));
|
write_reload_config(&path, Some(final_tag), Some(initial_cfg.server.port + 1));
|
||||||
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||||
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
|
||||||
|
|
||||||
let applied = config_tx.borrow().clone();
|
let applied = config_tx.borrow().clone();
|
||||||
assert_eq!(applied.general.ad_tag.as_deref(), Some(final_tag));
|
assert_eq!(applied.general.ad_tag.as_deref(), Some(final_tag));
|
||||||
@@ -1521,4 +1484,31 @@ mod tests {
|
|||||||
|
|
||||||
let _ = std::fs::remove_file(path);
|
let _ = std::fs::remove_file(path);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn reload_recovers_after_parse_error_on_next_attempt() {
|
||||||
|
let initial_tag = "cccccccccccccccccccccccccccccccc";
|
||||||
|
let final_tag = "dddddddddddddddddddddddddddddddd";
|
||||||
|
let path = temp_config_path("telemt_hot_reload_parse_recovery");
|
||||||
|
|
||||||
|
write_reload_config(&path, Some(initial_tag), None);
|
||||||
|
let initial_cfg = Arc::new(ProxyConfig::load(&path).unwrap());
|
||||||
|
let initial_hash = ProxyConfig::load_with_metadata(&path).unwrap().rendered_hash;
|
||||||
|
let (config_tx, _config_rx) = watch::channel(initial_cfg.clone());
|
||||||
|
let (log_tx, _log_rx) = watch::channel(initial_cfg.general.log_level.clone());
|
||||||
|
let mut reload_state = ReloadState::new(Some(initial_hash));
|
||||||
|
|
||||||
|
std::fs::write(&path, "[access.users\nuser = \"broken\"\n").unwrap();
|
||||||
|
assert!(reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).is_none());
|
||||||
|
assert_eq!(
|
||||||
|
config_tx.borrow().general.ad_tag.as_deref(),
|
||||||
|
Some(initial_tag)
|
||||||
|
);
|
||||||
|
|
||||||
|
write_reload_config(&path, Some(final_tag), None);
|
||||||
|
reload_config(&path, &config_tx, &log_tx, None, None, &mut reload_state).unwrap();
|
||||||
|
assert_eq!(config_tx.borrow().general.ad_tag.as_deref(), Some(final_tag));
|
||||||
|
|
||||||
|
let _ = std::fs::remove_file(path);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1156,6 +1156,13 @@ pub struct ServerConfig {
|
|||||||
#[serde(default = "default_proxy_protocol_header_timeout_ms")]
|
#[serde(default = "default_proxy_protocol_header_timeout_ms")]
|
||||||
pub proxy_protocol_header_timeout_ms: u64,
|
pub proxy_protocol_header_timeout_ms: u64,
|
||||||
|
|
||||||
|
/// Trusted source CIDRs allowed to send incoming PROXY protocol headers.
|
||||||
|
///
|
||||||
|
/// When non-empty, connections from addresses outside this allowlist are
|
||||||
|
/// rejected before `src_addr` is applied.
|
||||||
|
#[serde(default)]
|
||||||
|
pub proxy_protocol_trusted_cidrs: Vec<IpNetwork>,
|
||||||
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub metrics_port: Option<u16>,
|
pub metrics_port: Option<u16>,
|
||||||
|
|
||||||
@@ -1185,6 +1192,7 @@ impl Default for ServerConfig {
|
|||||||
listen_tcp: None,
|
listen_tcp: None,
|
||||||
proxy_protocol: false,
|
proxy_protocol: false,
|
||||||
proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(),
|
proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(),
|
||||||
|
proxy_protocol_trusted_cidrs: Vec::new(),
|
||||||
metrics_port: None,
|
metrics_port: None,
|
||||||
metrics_whitelist: default_metrics_whitelist(),
|
metrics_whitelist: default_metrics_whitelist(),
|
||||||
api: ApiConfig::default(),
|
api: ApiConfig::default(),
|
||||||
|
|||||||
@@ -10,6 +10,16 @@ use crate::transport::middle_proxy::{
|
|||||||
ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache,
|
ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub(crate) fn resolve_runtime_config_path(config_path_cli: &str, startup_cwd: &std::path::Path) -> PathBuf {
|
||||||
|
let raw = PathBuf::from(config_path_cli);
|
||||||
|
let absolute = if raw.is_absolute() {
|
||||||
|
raw
|
||||||
|
} else {
|
||||||
|
startup_cwd.join(raw)
|
||||||
|
};
|
||||||
|
absolute.canonicalize().unwrap_or(absolute)
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn parse_cli() -> (String, Option<PathBuf>, bool, Option<String>) {
|
pub(crate) fn parse_cli() -> (String, Option<PathBuf>, bool, Option<String>) {
|
||||||
let mut config_path = "config.toml".to_string();
|
let mut config_path = "config.toml".to_string();
|
||||||
let mut data_path: Option<PathBuf> = None;
|
let mut data_path: Option<PathBuf> = None;
|
||||||
@@ -96,6 +106,44 @@ pub(crate) fn parse_cli() -> (String, Option<PathBuf>, bool, Option<String>) {
|
|||||||
(config_path, data_path, silent, log_level)
|
(config_path, data_path, silent, log_level)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::resolve_runtime_config_path;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_runtime_config_path_anchors_relative_to_startup_cwd() {
|
||||||
|
let nonce = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_nanos();
|
||||||
|
let startup_cwd = std::env::temp_dir().join(format!("telemt_cfg_path_{nonce}"));
|
||||||
|
std::fs::create_dir_all(&startup_cwd).unwrap();
|
||||||
|
let target = startup_cwd.join("config.toml");
|
||||||
|
std::fs::write(&target, " ").unwrap();
|
||||||
|
|
||||||
|
let resolved = resolve_runtime_config_path("config.toml", &startup_cwd);
|
||||||
|
assert_eq!(resolved, target.canonicalize().unwrap());
|
||||||
|
|
||||||
|
let _ = std::fs::remove_file(&target);
|
||||||
|
let _ = std::fs::remove_dir(&startup_cwd);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resolve_runtime_config_path_keeps_absolute_for_missing_file() {
|
||||||
|
let nonce = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap()
|
||||||
|
.as_nanos();
|
||||||
|
let startup_cwd = std::env::temp_dir().join(format!("telemt_cfg_path_missing_{nonce}"));
|
||||||
|
std::fs::create_dir_all(&startup_cwd).unwrap();
|
||||||
|
|
||||||
|
let resolved = resolve_runtime_config_path("missing.toml", &startup_cwd);
|
||||||
|
assert_eq!(resolved, startup_cwd.join("missing.toml"));
|
||||||
|
|
||||||
|
let _ = std::fs::remove_dir(&startup_cwd);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) {
|
pub(crate) fn print_proxy_links(host: &str, port: u16, config: &ProxyConfig) {
|
||||||
info!(target: "telemt::links", "--- Proxy Links ({}) ---", host);
|
info!(target: "telemt::links", "--- Proxy Links ({}) ---", host);
|
||||||
for user_name in config.general.links.show.resolve_users(&config.access.users) {
|
for user_name in config.general.links.show.resolve_users(&config.access.users) {
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ use crate::startup::{
|
|||||||
use crate::stream::BufferPool;
|
use crate::stream::BufferPool;
|
||||||
use crate::transport::middle_proxy::MePool;
|
use crate::transport::middle_proxy::MePool;
|
||||||
use crate::transport::UpstreamManager;
|
use crate::transport::UpstreamManager;
|
||||||
use helpers::parse_cli;
|
use helpers::{parse_cli, resolve_runtime_config_path};
|
||||||
|
|
||||||
/// Runs the full telemt runtime startup pipeline and blocks until shutdown.
|
/// Runs the full telemt runtime startup pipeline and blocks until shutdown.
|
||||||
pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||||
@@ -58,18 +58,26 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
|||||||
startup_tracker
|
startup_tracker
|
||||||
.start_component(COMPONENT_CONFIG_LOAD, Some("load and validate config".to_string()))
|
.start_component(COMPONENT_CONFIG_LOAD, Some("load and validate config".to_string()))
|
||||||
.await;
|
.await;
|
||||||
let (config_path, data_path, cli_silent, cli_log_level) = parse_cli();
|
let (config_path_cli, data_path, cli_silent, cli_log_level) = parse_cli();
|
||||||
|
let startup_cwd = match std::env::current_dir() {
|
||||||
|
Ok(cwd) => cwd,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("[telemt] Can't read current_dir: {}", e);
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let config_path = resolve_runtime_config_path(&config_path_cli, &startup_cwd);
|
||||||
|
|
||||||
let mut config = match ProxyConfig::load(&config_path) {
|
let mut config = match ProxyConfig::load(&config_path) {
|
||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
if std::path::Path::new(&config_path).exists() {
|
if config_path.exists() {
|
||||||
eprintln!("[telemt] Error: {}", e);
|
eprintln!("[telemt] Error: {}", e);
|
||||||
std::process::exit(1);
|
std::process::exit(1);
|
||||||
} else {
|
} else {
|
||||||
let default = ProxyConfig::default();
|
let default = ProxyConfig::default();
|
||||||
std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
|
std::fs::write(&config_path, toml::to_string_pretty(&default).unwrap()).unwrap();
|
||||||
eprintln!("[telemt] Created default config at {}", config_path);
|
eprintln!("[telemt] Created default config at {}", config_path.display());
|
||||||
default
|
default
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -258,7 +266,7 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
|||||||
let route_runtime_api = route_runtime.clone();
|
let route_runtime_api = route_runtime.clone();
|
||||||
let config_rx_api = api_config_rx.clone();
|
let config_rx_api = api_config_rx.clone();
|
||||||
let admission_rx_api = admission_rx.clone();
|
let admission_rx_api = admission_rx.clone();
|
||||||
let config_path_api = std::path::PathBuf::from(&config_path);
|
let config_path_api = config_path.clone();
|
||||||
let startup_tracker_api = startup_tracker.clone();
|
let startup_tracker_api = startup_tracker.clone();
|
||||||
let detected_ips_rx_api = detected_ips_rx.clone();
|
let detected_ips_rx_api = detected_ips_rx.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use std::net::IpAddr;
|
use std::net::IpAddr;
|
||||||
use std::path::PathBuf;
|
use std::path::Path;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use tokio::sync::{mpsc, watch};
|
use tokio::sync::{mpsc, watch};
|
||||||
@@ -32,7 +32,7 @@ pub(crate) struct RuntimeWatches {
|
|||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub(crate) async fn spawn_runtime_tasks(
|
pub(crate) async fn spawn_runtime_tasks(
|
||||||
config: &Arc<ProxyConfig>,
|
config: &Arc<ProxyConfig>,
|
||||||
config_path: &str,
|
config_path: &Path,
|
||||||
probe: &NetworkProbe,
|
probe: &NetworkProbe,
|
||||||
prefer_ipv6: bool,
|
prefer_ipv6: bool,
|
||||||
decision_ipv4_dc: bool,
|
decision_ipv4_dc: bool,
|
||||||
@@ -83,7 +83,7 @@ pub(crate) async fn spawn_runtime_tasks(
|
|||||||
watch::Receiver<Arc<ProxyConfig>>,
|
watch::Receiver<Arc<ProxyConfig>>,
|
||||||
watch::Receiver<LogLevel>,
|
watch::Receiver<LogLevel>,
|
||||||
) = spawn_config_watcher(
|
) = spawn_config_watcher(
|
||||||
PathBuf::from(config_path),
|
config_path.to_path_buf(),
|
||||||
config.clone(),
|
config.clone(),
|
||||||
detected_ip_v4,
|
detected_ip_v4,
|
||||||
detected_ip_v6,
|
detected_ip_v6,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ use super::constants::*;
|
|||||||
use std::time::{SystemTime, UNIX_EPOCH};
|
use std::time::{SystemTime, UNIX_EPOCH};
|
||||||
use num_bigint::BigUint;
|
use num_bigint::BigUint;
|
||||||
use num_traits::One;
|
use num_traits::One;
|
||||||
|
use subtle::ConstantTimeEq;
|
||||||
|
|
||||||
// ============= Public Constants =============
|
// ============= Public Constants =============
|
||||||
|
|
||||||
@@ -28,6 +29,8 @@ pub const TLS_DIGEST_HALF_LEN: usize = 16;
|
|||||||
/// Time skew limits for anti-replay (in seconds)
|
/// Time skew limits for anti-replay (in seconds)
|
||||||
pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before
|
pub const TIME_SKEW_MIN: i64 = -20 * 60; // 20 minutes before
|
||||||
pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after
|
pub const TIME_SKEW_MAX: i64 = 10 * 60; // 10 minutes after
|
||||||
|
/// Maximum accepted boot-time timestamp (seconds) before skew checks are enforced.
|
||||||
|
pub const BOOT_TIME_MAX_SECS: u32 = 7 * 24 * 60 * 60;
|
||||||
|
|
||||||
// ============= Private Constants =============
|
// ============= Private Constants =============
|
||||||
|
|
||||||
@@ -125,7 +128,7 @@ impl TlsExtensionBuilder {
|
|||||||
// protocol name length (1 byte)
|
// protocol name length (1 byte)
|
||||||
// protocol name bytes
|
// protocol name bytes
|
||||||
let proto_len = proto.len() as u8;
|
let proto_len = proto.len() as u8;
|
||||||
let list_len: u16 = 1 + proto_len as u16;
|
let list_len: u16 = 1 + u16::from(proto_len);
|
||||||
let ext_len: u16 = 2 + list_len;
|
let ext_len: u16 = 2 + list_len;
|
||||||
|
|
||||||
self.extensions.extend_from_slice(&ext_len.to_be_bytes());
|
self.extensions.extend_from_slice(&ext_len.to_be_bytes());
|
||||||
@@ -273,13 +276,86 @@ impl ServerHelloBuilder {
|
|||||||
|
|
||||||
// ============= Public Functions =============
|
// ============= Public Functions =============
|
||||||
|
|
||||||
/// Validate TLS ClientHello against user secrets
|
/// Validate TLS ClientHello against user secrets.
|
||||||
///
|
///
|
||||||
/// Returns validation result if a matching user is found.
|
/// Returns validation result if a matching user is found.
|
||||||
|
/// The result **must** be used — ignoring it silently bypasses authentication.
|
||||||
|
#[must_use]
|
||||||
pub fn validate_tls_handshake(
|
pub fn validate_tls_handshake(
|
||||||
handshake: &[u8],
|
handshake: &[u8],
|
||||||
secrets: &[(String, Vec<u8>)],
|
secrets: &[(String, Vec<u8>)],
|
||||||
ignore_time_skew: bool,
|
ignore_time_skew: bool,
|
||||||
|
) -> Option<TlsValidation> {
|
||||||
|
validate_tls_handshake_with_replay_window(
|
||||||
|
handshake,
|
||||||
|
secrets,
|
||||||
|
ignore_time_skew,
|
||||||
|
u64::from(BOOT_TIME_MAX_SECS),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate TLS ClientHello and cap the boot-time bypass by replay-cache TTL.
|
||||||
|
///
|
||||||
|
/// A boot-time timestamp is only accepted when it falls below both
|
||||||
|
/// `BOOT_TIME_MAX_SECS` and the configured replay window, preventing timestamp
|
||||||
|
/// reuse outside replay cache coverage.
|
||||||
|
#[must_use]
|
||||||
|
pub fn validate_tls_handshake_with_replay_window(
|
||||||
|
handshake: &[u8],
|
||||||
|
secrets: &[(String, Vec<u8>)],
|
||||||
|
ignore_time_skew: bool,
|
||||||
|
replay_window_secs: u64,
|
||||||
|
) -> Option<TlsValidation> {
|
||||||
|
// Only pay the clock syscall when we will actually compare against it.
|
||||||
|
// If `ignore_time_skew` is set, a broken or unavailable system clock
|
||||||
|
// must not block legitimate clients — that would be a DoS via clock failure.
|
||||||
|
let now = if !ignore_time_skew {
|
||||||
|
system_time_to_unix_secs(SystemTime::now())?
|
||||||
|
} else {
|
||||||
|
0_i64
|
||||||
|
};
|
||||||
|
|
||||||
|
let replay_window_u32 = u32::try_from(replay_window_secs).unwrap_or(u32::MAX);
|
||||||
|
let boot_time_cap_secs = BOOT_TIME_MAX_SECS.min(replay_window_u32);
|
||||||
|
|
||||||
|
validate_tls_handshake_at_time_with_boot_cap(
|
||||||
|
handshake,
|
||||||
|
secrets,
|
||||||
|
ignore_time_skew,
|
||||||
|
now,
|
||||||
|
boot_time_cap_secs,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn system_time_to_unix_secs(now: SystemTime) -> Option<i64> {
|
||||||
|
// `try_from` rejects values that overflow i64 (> ~292 billion years CE),
|
||||||
|
// whereas `as i64` would silently wrap to a negative timestamp and corrupt
|
||||||
|
// every subsequent time-skew comparison.
|
||||||
|
let d = now.duration_since(UNIX_EPOCH).ok()?;
|
||||||
|
i64::try_from(d.as_secs()).ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_tls_handshake_at_time(
|
||||||
|
handshake: &[u8],
|
||||||
|
secrets: &[(String, Vec<u8>)],
|
||||||
|
ignore_time_skew: bool,
|
||||||
|
now: i64,
|
||||||
|
) -> Option<TlsValidation> {
|
||||||
|
validate_tls_handshake_at_time_with_boot_cap(
|
||||||
|
handshake,
|
||||||
|
secrets,
|
||||||
|
ignore_time_skew,
|
||||||
|
now,
|
||||||
|
BOOT_TIME_MAX_SECS,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_tls_handshake_at_time_with_boot_cap(
|
||||||
|
handshake: &[u8],
|
||||||
|
secrets: &[(String, Vec<u8>)],
|
||||||
|
ignore_time_skew: bool,
|
||||||
|
now: i64,
|
||||||
|
boot_time_cap_secs: u32,
|
||||||
) -> Option<TlsValidation> {
|
) -> Option<TlsValidation> {
|
||||||
if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 {
|
if handshake.len() < TLS_DIGEST_POS + TLS_DIGEST_LEN + 1 {
|
||||||
return None;
|
return None;
|
||||||
@@ -305,50 +381,56 @@ pub fn validate_tls_handshake(
|
|||||||
let mut msg = handshake.to_vec();
|
let mut msg = handshake.to_vec();
|
||||||
msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
|
msg[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN].fill(0);
|
||||||
|
|
||||||
// Get current time
|
let mut first_match: Option<(&String, u32)> = None;
|
||||||
let now = SystemTime::now()
|
|
||||||
.duration_since(UNIX_EPOCH)
|
|
||||||
.unwrap()
|
|
||||||
.as_secs() as i64;
|
|
||||||
|
|
||||||
for (user, secret) in secrets {
|
for (user, secret) in secrets {
|
||||||
let computed = sha256_hmac(secret, &msg);
|
let computed = sha256_hmac(secret, &msg);
|
||||||
|
|
||||||
// XOR digests
|
// Constant-time equality check on the 28-byte HMAC window.
|
||||||
let xored: Vec<u8> = digest.iter()
|
// A variable-time short-circuit here lets an active censor measure how many
|
||||||
.zip(computed.iter())
|
// bytes matched, enabling secret brute-force via timing side-channels.
|
||||||
.map(|(a, b)| a ^ b)
|
// Direct comparison on the original arrays avoids a heap allocation and
|
||||||
.collect();
|
// removes the `try_into().unwrap()` that the intermediate Vec would require.
|
||||||
|
if !bool::from(digest[..28].ct_eq(&computed[..28])) {
|
||||||
// Check that first 28 bytes are zeros (timestamp in last 4)
|
|
||||||
if !xored[..28].iter().all(|&b| b == 0) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract timestamp
|
// The last 4 bytes encode the timestamp as XOR(digest[28..32], computed[28..32]).
|
||||||
let timestamp = u32::from_le_bytes(xored[28..32].try_into().unwrap());
|
// Inline array construction is infallible: both slices are [u8; 32] by construction.
|
||||||
let time_diff = now - timestamp as i64;
|
let timestamp = u32::from_le_bytes([
|
||||||
|
digest[28] ^ computed[28],
|
||||||
// Check time skew
|
digest[29] ^ computed[29],
|
||||||
|
digest[30] ^ computed[30],
|
||||||
|
digest[31] ^ computed[31],
|
||||||
|
]);
|
||||||
|
|
||||||
|
// time_diff is only meaningful (and `now` is only valid) when we are
|
||||||
|
// actually checking the window. Keep both inside the guard to make
|
||||||
|
// the dead-code path explicit and prevent accidental future use of
|
||||||
|
// a sentinel `now` value outside its intended scope.
|
||||||
if !ignore_time_skew {
|
if !ignore_time_skew {
|
||||||
// Allow very small timestamps (boot time instead of unix time)
|
// Allow very small timestamps (boot time instead of unix time)
|
||||||
// This is a quirk in some clients that use uptime instead of real time
|
// This is a quirk in some clients that use uptime instead of real time
|
||||||
let is_boot_time = timestamp < 60 * 60 * 24 * 1000; // < ~2.7 years in seconds
|
let is_boot_time = timestamp < boot_time_cap_secs;
|
||||||
|
if !is_boot_time {
|
||||||
if !is_boot_time && !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) {
|
let time_diff = now - i64::from(timestamp);
|
||||||
continue;
|
if !(TIME_SKEW_MIN..=TIME_SKEW_MAX).contains(&time_diff) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Some(TlsValidation {
|
if first_match.is_none() {
|
||||||
user: user.clone(),
|
first_match = Some((user, timestamp));
|
||||||
session_id,
|
}
|
||||||
digest,
|
|
||||||
timestamp,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
None
|
first_match.map(|(user, timestamp)| TlsValidation {
|
||||||
|
user: user.clone(),
|
||||||
|
session_id,
|
||||||
|
digest,
|
||||||
|
timestamp,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn curve25519_prime() -> BigUint {
|
fn curve25519_prime() -> BigUint {
|
||||||
@@ -528,7 +610,9 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
|
|||||||
if name_type == 0 && name_len > 0
|
if name_type == 0 && name_len > 0
|
||||||
&& let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len])
|
&& let Ok(host) = std::str::from_utf8(&handshake[sn_pos..sn_pos + name_len])
|
||||||
{
|
{
|
||||||
return Some(host.to_string());
|
if is_valid_sni_hostname(host) {
|
||||||
|
return Some(host.to_string());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
sn_pos += name_len;
|
sn_pos += name_len;
|
||||||
}
|
}
|
||||||
@@ -539,6 +623,35 @@ pub fn extract_sni_from_client_hello(handshake: &[u8]) -> Option<String> {
|
|||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn is_valid_sni_hostname(host: &str) -> bool {
|
||||||
|
if host.is_empty() || host.len() > 253 {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if host.starts_with('.') || host.ends_with('.') {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if host.parse::<std::net::IpAddr>().is_ok() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
for label in host.split('.') {
|
||||||
|
if label.is_empty() || label.len() > 63 {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if label.starts_with('-') || label.ends_with('-') {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if !label
|
||||||
|
.bytes()
|
||||||
|
.all(|b| b.is_ascii_alphanumeric() || b == b'-')
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
/// Extract ALPN protocol list from ClientHello, return in offered order.
|
/// Extract ALPN protocol list from ClientHello, return in offered order.
|
||||||
pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec<Vec<u8>> {
|
pub fn extract_alpn_from_client_hello(handshake: &[u8]) -> Vec<Vec<u8>> {
|
||||||
let mut pos = 5; // after record header
|
let mut pos = 5; // after record header
|
||||||
@@ -667,291 +780,29 @@ fn validate_server_hello_structure(data: &[u8]) -> Result<(), ProxyError> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
// ============= Compile-time Security Invariants =============
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_is_tls_handshake() {
|
|
||||||
assert!(is_tls_handshake(&[0x16, 0x03, 0x01]));
|
|
||||||
assert!(is_tls_handshake(&[0x16, 0x03, 0x01, 0x02, 0x00]));
|
|
||||||
assert!(!is_tls_handshake(&[0x17, 0x03, 0x01])); // Application data
|
|
||||||
assert!(!is_tls_handshake(&[0x16, 0x03, 0x02])); // Wrong version
|
|
||||||
assert!(!is_tls_handshake(&[0x16, 0x03])); // Too short
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_parse_tls_record_header() {
|
|
||||||
let header = [0x16, 0x03, 0x01, 0x02, 0x00];
|
|
||||||
let result = parse_tls_record_header(&header).unwrap();
|
|
||||||
assert_eq!(result.0, TLS_RECORD_HANDSHAKE);
|
|
||||||
assert_eq!(result.1, 512);
|
|
||||||
|
|
||||||
let header = [0x17, 0x03, 0x03, 0x40, 0x00];
|
|
||||||
let result = parse_tls_record_header(&header).unwrap();
|
|
||||||
assert_eq!(result.0, TLS_RECORD_APPLICATION);
|
|
||||||
assert_eq!(result.1, 16384);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_gen_fake_x25519_key() {
|
|
||||||
let rng = SecureRandom::new();
|
|
||||||
let key1 = gen_fake_x25519_key(&rng);
|
|
||||||
let key2 = gen_fake_x25519_key(&rng);
|
|
||||||
|
|
||||||
assert_eq!(key1.len(), 32);
|
|
||||||
assert_eq!(key2.len(), 32);
|
|
||||||
assert_ne!(key1, key2); // Should be random
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
/// Compile-time checks that enforce invariants the rest of the code relies on.
|
||||||
fn test_fake_x25519_key_is_quadratic_residue() {
|
/// Using `static_assertions` ensures these can never silently break across
|
||||||
let rng = SecureRandom::new();
|
/// refactors without a compile error.
|
||||||
let key = gen_fake_x25519_key(&rng);
|
mod compile_time_security_checks {
|
||||||
let p = curve25519_prime();
|
use super::{TLS_DIGEST_LEN, TLS_DIGEST_HALF_LEN};
|
||||||
let k_num = BigUint::from_bytes_le(&key);
|
use static_assertions::const_assert;
|
||||||
let exponent = (&p - BigUint::one()) >> 1;
|
|
||||||
let legendre = k_num.modpow(&exponent, &p);
|
|
||||||
assert_eq!(legendre, BigUint::one());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_tls_extension_builder() {
|
|
||||||
let key = [0x42u8; 32];
|
|
||||||
|
|
||||||
let mut builder = TlsExtensionBuilder::new();
|
|
||||||
builder.add_key_share(&key);
|
|
||||||
builder.add_supported_versions(0x0304);
|
|
||||||
|
|
||||||
let result = builder.build();
|
|
||||||
|
|
||||||
// Check length prefix
|
|
||||||
let len = u16::from_be_bytes([result[0], result[1]]) as usize;
|
|
||||||
assert_eq!(len, result.len() - 2);
|
|
||||||
|
|
||||||
// Check key_share extension is present
|
|
||||||
assert!(result.len() > 40); // At least key share
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_server_hello_builder() {
|
|
||||||
let session_id = vec![0x01, 0x02, 0x03, 0x04];
|
|
||||||
let key = [0x55u8; 32];
|
|
||||||
|
|
||||||
let builder = ServerHelloBuilder::new(session_id.clone())
|
|
||||||
.with_x25519_key(&key)
|
|
||||||
.with_tls13_version();
|
|
||||||
|
|
||||||
let record = builder.build_record();
|
|
||||||
|
|
||||||
// Validate structure
|
|
||||||
validate_server_hello_structure(&record).expect("Invalid ServerHello structure");
|
|
||||||
|
|
||||||
// Check record type
|
|
||||||
assert_eq!(record[0], TLS_RECORD_HANDSHAKE);
|
|
||||||
|
|
||||||
// Check version
|
|
||||||
assert_eq!(&record[1..3], &TLS_VERSION);
|
|
||||||
|
|
||||||
// Check message type (ServerHello = 0x02)
|
|
||||||
assert_eq!(record[5], 0x02);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_build_server_hello_structure() {
|
|
||||||
let secret = b"test secret";
|
|
||||||
let client_digest = [0x42u8; 32];
|
|
||||||
let session_id = vec![0xAA; 32];
|
|
||||||
|
|
||||||
let rng = SecureRandom::new();
|
|
||||||
let response = build_server_hello(secret, &client_digest, &session_id, 2048, &rng, None, 0);
|
|
||||||
|
|
||||||
// Should have at least 3 records
|
|
||||||
assert!(response.len() > 100);
|
|
||||||
|
|
||||||
// First record should be ServerHello
|
|
||||||
assert_eq!(response[0], TLS_RECORD_HANDSHAKE);
|
|
||||||
|
|
||||||
// Validate ServerHello structure
|
|
||||||
validate_server_hello_structure(&response).expect("Invalid ServerHello");
|
|
||||||
|
|
||||||
// Find Change Cipher Spec
|
|
||||||
let server_hello_len = 5 + u16::from_be_bytes([response[3], response[4]]) as usize;
|
|
||||||
let ccs_start = server_hello_len;
|
|
||||||
|
|
||||||
assert!(response.len() > ccs_start + 6);
|
|
||||||
assert_eq!(response[ccs_start], TLS_RECORD_CHANGE_CIPHER);
|
|
||||||
|
|
||||||
// Find Application Data
|
|
||||||
let ccs_len = 5 + u16::from_be_bytes([response[ccs_start + 3], response[ccs_start + 4]]) as usize;
|
|
||||||
let app_start = ccs_start + ccs_len;
|
|
||||||
|
|
||||||
assert!(response.len() > app_start + 5);
|
|
||||||
assert_eq!(response[app_start], TLS_RECORD_APPLICATION);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_build_server_hello_digest() {
|
|
||||||
let secret = b"test secret key here";
|
|
||||||
let client_digest = [0x42u8; 32];
|
|
||||||
let session_id = vec![0xAA; 32];
|
|
||||||
|
|
||||||
let rng = SecureRandom::new();
|
|
||||||
let response1 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0);
|
|
||||||
let response2 = build_server_hello(secret, &client_digest, &session_id, 1024, &rng, None, 0);
|
|
||||||
|
|
||||||
// Digest position should have non-zero data
|
|
||||||
let digest1 = &response1[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
|
|
||||||
assert!(!digest1.iter().all(|&b| b == 0));
|
|
||||||
|
|
||||||
// Different calls should have different digests (due to random cert)
|
|
||||||
let digest2 = &response2[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN];
|
|
||||||
assert_ne!(digest1, digest2);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_server_hello_extensions_length() {
|
|
||||||
let session_id = vec![0x01; 32];
|
|
||||||
let key = [0x55u8; 32];
|
|
||||||
|
|
||||||
let builder = ServerHelloBuilder::new(session_id)
|
|
||||||
.with_x25519_key(&key)
|
|
||||||
.with_tls13_version();
|
|
||||||
|
|
||||||
let record = builder.build_record();
|
|
||||||
|
|
||||||
// Parse to find extensions
|
|
||||||
let msg_start = 5; // After record header
|
|
||||||
let msg_len = u32::from_be_bytes([0, record[6], record[7], record[8]]) as usize;
|
|
||||||
|
|
||||||
// Skip to session ID
|
|
||||||
let session_id_pos = msg_start + 4 + 2 + 32; // header(4) + version(2) + random(32)
|
|
||||||
let session_id_len = record[session_id_pos] as usize;
|
|
||||||
|
|
||||||
// Skip to extensions
|
|
||||||
let ext_len_pos = session_id_pos + 1 + session_id_len + 2 + 1; // session_id + cipher(2) + compression(1)
|
|
||||||
let ext_len = u16::from_be_bytes([record[ext_len_pos], record[ext_len_pos + 1]]) as usize;
|
|
||||||
|
|
||||||
// Verify extensions length matches actual data
|
|
||||||
let extensions_data = &record[ext_len_pos + 2..msg_start + 4 + msg_len];
|
|
||||||
assert_eq!(ext_len, extensions_data.len(),
|
|
||||||
"Extension length mismatch: declared {}, actual {}", ext_len, extensions_data.len());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_validate_tls_handshake_format() {
|
|
||||||
// Build a minimal ClientHello-like structure
|
|
||||||
let mut handshake = vec![0u8; 100];
|
|
||||||
|
|
||||||
// Put a valid-looking digest at position 11
|
|
||||||
handshake[TLS_DIGEST_POS..TLS_DIGEST_POS + TLS_DIGEST_LEN]
|
|
||||||
.copy_from_slice(&[0x42; 32]);
|
|
||||||
|
|
||||||
// Session ID length
|
|
||||||
handshake[TLS_DIGEST_POS + TLS_DIGEST_LEN] = 32;
|
|
||||||
|
|
||||||
// This won't validate (wrong HMAC) but shouldn't panic
|
|
||||||
let secrets = vec![("test".to_string(), b"secret".to_vec())];
|
|
||||||
let result = validate_tls_handshake(&handshake, &secrets, true);
|
|
||||||
|
|
||||||
// Should return None (no match) but not panic
|
|
||||||
assert!(result.is_none());
|
|
||||||
}
|
|
||||||
|
|
||||||
fn build_client_hello_with_exts(exts: Vec<(u16, Vec<u8>)>, host: &str) -> Vec<u8> {
|
// The digest must be exactly one SHA-256 output.
|
||||||
let mut body = Vec::new();
|
const_assert!(TLS_DIGEST_LEN == 32);
|
||||||
body.extend_from_slice(&TLS_VERSION); // legacy version
|
|
||||||
body.extend_from_slice(&[0u8; 32]); // random
|
|
||||||
body.push(0); // session id len
|
|
||||||
body.extend_from_slice(&2u16.to_be_bytes()); // cipher suites len
|
|
||||||
body.extend_from_slice(&[0x13, 0x01]); // TLS_AES_128_GCM_SHA256
|
|
||||||
body.push(1); // compression len
|
|
||||||
body.push(0); // null compression
|
|
||||||
|
|
||||||
// Build SNI extension
|
// Replay-dedup stores the first half; verify it is literally half.
|
||||||
let host_bytes = host.as_bytes();
|
const_assert!(TLS_DIGEST_HALF_LEN * 2 == TLS_DIGEST_LEN);
|
||||||
let mut sni_ext = Vec::new();
|
|
||||||
sni_ext.extend_from_slice(&(host_bytes.len() as u16 + 3).to_be_bytes());
|
|
||||||
sni_ext.push(0);
|
|
||||||
sni_ext.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes());
|
|
||||||
sni_ext.extend_from_slice(host_bytes);
|
|
||||||
|
|
||||||
let mut ext_blob = Vec::new();
|
// The HMAC check window (28 bytes) plus the embedded timestamp (4 bytes)
|
||||||
for (typ, data) in exts {
|
// must exactly fill the digest. If TLS_DIGEST_LEN ever changes, these
|
||||||
ext_blob.extend_from_slice(&typ.to_be_bytes());
|
// assertions will catch the mismatch before any timing-oracle fix is broke.
|
||||||
ext_blob.extend_from_slice(&(data.len() as u16).to_be_bytes());
|
const_assert!(28 + 4 == TLS_DIGEST_LEN);
|
||||||
ext_blob.extend_from_slice(&data);
|
|
||||||
}
|
|
||||||
// SNI last
|
|
||||||
ext_blob.extend_from_slice(&0x0000u16.to_be_bytes());
|
|
||||||
ext_blob.extend_from_slice(&(sni_ext.len() as u16).to_be_bytes());
|
|
||||||
ext_blob.extend_from_slice(&sni_ext);
|
|
||||||
|
|
||||||
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
|
|
||||||
body.extend_from_slice(&ext_blob);
|
|
||||||
|
|
||||||
let mut handshake = Vec::new();
|
|
||||||
handshake.push(0x01); // ClientHello
|
|
||||||
let len_bytes = (body.len() as u32).to_be_bytes();
|
|
||||||
handshake.extend_from_slice(&len_bytes[1..4]);
|
|
||||||
handshake.extend_from_slice(&body);
|
|
||||||
|
|
||||||
let mut record = Vec::new();
|
|
||||||
record.push(TLS_RECORD_HANDSHAKE);
|
|
||||||
record.extend_from_slice(&[0x03, 0x01]);
|
|
||||||
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
|
|
||||||
record.extend_from_slice(&handshake);
|
|
||||||
record
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_extract_sni_with_grease_extension() {
|
|
||||||
// GREASE type 0x0a0a with zero length before SNI
|
|
||||||
let ch = build_client_hello_with_exts(vec![(0x0a0a, Vec::new())], "example.com");
|
|
||||||
let sni = extract_sni_from_client_hello(&ch);
|
|
||||||
assert_eq!(sni.as_deref(), Some("example.com"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_extract_sni_tolerates_empty_unknown_extension() {
|
|
||||||
let ch = build_client_hello_with_exts(vec![(0x1234, Vec::new())], "test.local");
|
|
||||||
let sni = extract_sni_from_client_hello(&ch);
|
|
||||||
assert_eq!(sni.as_deref(), Some("test.local"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_extract_alpn_single() {
|
|
||||||
let mut alpn_data = Vec::new();
|
|
||||||
// list length = 3 (1 length byte + "h2")
|
|
||||||
alpn_data.extend_from_slice(&3u16.to_be_bytes());
|
|
||||||
alpn_data.push(2);
|
|
||||||
alpn_data.extend_from_slice(b"h2");
|
|
||||||
let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test");
|
|
||||||
let alpn = extract_alpn_from_client_hello(&ch);
|
|
||||||
let alpn_str: Vec<String> = alpn
|
|
||||||
.iter()
|
|
||||||
.map(|p| std::str::from_utf8(p).unwrap().to_string())
|
|
||||||
.collect();
|
|
||||||
assert_eq!(alpn_str, vec!["h2"]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_extract_alpn_multiple() {
|
|
||||||
let mut alpn_data = Vec::new();
|
|
||||||
// list length = 11 (sum of per-proto lengths including length bytes)
|
|
||||||
alpn_data.extend_from_slice(&11u16.to_be_bytes());
|
|
||||||
alpn_data.push(2);
|
|
||||||
alpn_data.extend_from_slice(b"h2");
|
|
||||||
alpn_data.push(4);
|
|
||||||
alpn_data.extend_from_slice(b"spdy");
|
|
||||||
alpn_data.push(2);
|
|
||||||
alpn_data.extend_from_slice(b"h3");
|
|
||||||
let ch = build_client_hello_with_exts(vec![(0x0010, alpn_data)], "alpn.test");
|
|
||||||
let alpn = extract_alpn_from_client_hello(&ch);
|
|
||||||
let alpn_str: Vec<String> = alpn
|
|
||||||
.iter()
|
|
||||||
.map(|p| std::str::from_utf8(p).unwrap().to_string())
|
|
||||||
.collect();
|
|
||||||
assert_eq!(alpn_str, vec!["h2", "spdy", "h3"]);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============= Security-focused regression tests =============
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[path = "tls_security_tests.rs"]
|
||||||
|
mod security_tests;
|
||||||
|
|||||||
1396
src/protocol/tls_security_tests.rs
Normal file
1396
src/protocol/tls_security_tests.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,10 @@ use std::future::Future;
|
|||||||
use std::net::{IpAddr, SocketAddr};
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::sync::OnceLock;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
use ipnetwork::IpNetwork;
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
@@ -23,7 +26,7 @@ enum HandshakeOutcome {
|
|||||||
|
|
||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
use crate::crypto::SecureRandom;
|
use crate::crypto::SecureRandom;
|
||||||
use crate::error::{HandshakeResult, ProxyError, Result};
|
use crate::error::{HandshakeResult, ProxyError, Result, StreamError};
|
||||||
use crate::ip_tracker::UserIpTracker;
|
use crate::ip_tracker::UserIpTracker;
|
||||||
use crate::protocol::constants::*;
|
use crate::protocol::constants::*;
|
||||||
use crate::protocol::tls;
|
use crate::protocol::tls;
|
||||||
@@ -63,14 +66,30 @@ fn record_handshake_failure_class(
|
|||||||
peer_ip: IpAddr,
|
peer_ip: IpAddr,
|
||||||
error: &ProxyError,
|
error: &ProxyError,
|
||||||
) {
|
) {
|
||||||
let class = if error.to_string().contains("expected 64 bytes, got 0") {
|
let class = match error {
|
||||||
"expected_64_got_0"
|
ProxyError::Io(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {
|
||||||
} else {
|
"expected_64_got_0"
|
||||||
"other"
|
}
|
||||||
|
ProxyError::Stream(StreamError::UnexpectedEof) => "expected_64_got_0",
|
||||||
|
_ => "other",
|
||||||
};
|
};
|
||||||
record_beobachten_class(beobachten, config, peer_ip, class);
|
record_beobachten_class(beobachten, config, peer_ip, class);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool {
|
||||||
|
if trusted.is_empty() {
|
||||||
|
static EMPTY_PROXY_TRUST_WARNED: OnceLock<AtomicBool> = OnceLock::new();
|
||||||
|
let warned = EMPTY_PROXY_TRUST_WARNED.get_or_init(|| AtomicBool::new(false));
|
||||||
|
if !warned.swap(true, Ordering::Relaxed) {
|
||||||
|
warn!(
|
||||||
|
"PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers by default"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
trusted.iter().any(|cidr| cidr.contains(peer_ip))
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn handle_client_stream<S>(
|
pub async fn handle_client_stream<S>(
|
||||||
mut stream: S,
|
mut stream: S,
|
||||||
peer: SocketAddr,
|
peer: SocketAddr,
|
||||||
@@ -104,6 +123,17 @@ where
|
|||||||
);
|
);
|
||||||
match timeout(proxy_header_timeout, parse_proxy_protocol(&mut stream, peer)).await {
|
match timeout(proxy_header_timeout, parse_proxy_protocol(&mut stream, peer)).await {
|
||||||
Ok(Ok(info)) => {
|
Ok(Ok(info)) => {
|
||||||
|
if !is_trusted_proxy_source(peer.ip(), &config.server.proxy_protocol_trusted_cidrs)
|
||||||
|
{
|
||||||
|
stats.increment_connects_bad();
|
||||||
|
warn!(
|
||||||
|
peer = %peer,
|
||||||
|
trusted = ?config.server.proxy_protocol_trusted_cidrs,
|
||||||
|
"Rejecting PROXY protocol header from untrusted source"
|
||||||
|
);
|
||||||
|
record_beobachten_class(&beobachten, &config, peer.ip(), "other");
|
||||||
|
return Err(ProxyError::InvalidProxyProtocol);
|
||||||
|
}
|
||||||
debug!(
|
debug!(
|
||||||
peer = %peer,
|
peer = %peer,
|
||||||
client = %info.src_addr,
|
client = %info.src_addr,
|
||||||
@@ -149,8 +179,13 @@ where
|
|||||||
if is_tls {
|
if is_tls {
|
||||||
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
let tls_len = u16::from_be_bytes([first_bytes[3], first_bytes[4]]) as usize;
|
||||||
|
|
||||||
if tls_len < 512 {
|
// RFC 8446 §5.1 mandates that TLSPlaintext records must not exceed 2^14
|
||||||
debug!(peer = %real_peer, tls_len = tls_len, "TLS handshake too short");
|
// bytes (16_384). A client claiming a larger record is non-compliant and
|
||||||
|
// may be an active probe attempting to force large allocations.
|
||||||
|
//
|
||||||
|
// Also enforce a minimum record size to avoid trivial/garbage probes.
|
||||||
|
if !(512..=MAX_TLS_RECORD_SIZE).contains(&tls_len) {
|
||||||
|
debug!(peer = %real_peer, tls_len = tls_len, max_tls_len = MAX_TLS_RECORD_SIZE, "TLS handshake length out of bounds");
|
||||||
stats.increment_connects_bad();
|
stats.increment_connects_bad();
|
||||||
let (reader, writer) = tokio::io::split(stream);
|
let (reader, writer) = tokio::io::split(stream);
|
||||||
handle_bad_client(
|
handle_bad_client(
|
||||||
@@ -204,9 +239,19 @@ where
|
|||||||
&config, &replay_checker, true, Some(tls_user.as_str()),
|
&config, &replay_checker, true, Some(tls_user.as_str()),
|
||||||
).await {
|
).await {
|
||||||
HandshakeResult::Success(result) => result,
|
HandshakeResult::Success(result) => result,
|
||||||
HandshakeResult::BadClient { reader: _, writer: _ } => {
|
HandshakeResult::BadClient { reader, writer } => {
|
||||||
stats.increment_connects_bad();
|
stats.increment_connects_bad();
|
||||||
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
|
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
|
||||||
|
handle_bad_client(
|
||||||
|
reader,
|
||||||
|
writer,
|
||||||
|
&mtproto_handshake,
|
||||||
|
real_peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
return Ok(HandshakeOutcome::Handled);
|
return Ok(HandshakeOutcome::Handled);
|
||||||
}
|
}
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
HandshakeResult::Error(e) => return Err(e),
|
||||||
@@ -445,6 +490,24 @@ impl RunningClientHandler {
|
|||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(Ok(info)) => {
|
Ok(Ok(info)) => {
|
||||||
|
if !is_trusted_proxy_source(
|
||||||
|
self.peer.ip(),
|
||||||
|
&self.config.server.proxy_protocol_trusted_cidrs,
|
||||||
|
) {
|
||||||
|
self.stats.increment_connects_bad();
|
||||||
|
warn!(
|
||||||
|
peer = %self.peer,
|
||||||
|
trusted = ?self.config.server.proxy_protocol_trusted_cidrs,
|
||||||
|
"Rejecting PROXY protocol header from untrusted source"
|
||||||
|
);
|
||||||
|
record_beobachten_class(
|
||||||
|
&self.beobachten,
|
||||||
|
&self.config,
|
||||||
|
self.peer.ip(),
|
||||||
|
"other",
|
||||||
|
);
|
||||||
|
return Err(ProxyError::InvalidProxyProtocol);
|
||||||
|
}
|
||||||
debug!(
|
debug!(
|
||||||
peer = %self.peer,
|
peer = %self.peer,
|
||||||
client = %info.src_addr,
|
client = %info.src_addr,
|
||||||
@@ -513,8 +576,10 @@ impl RunningClientHandler {
|
|||||||
|
|
||||||
debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake");
|
debug!(peer = %peer, tls_len = tls_len, "Reading TLS handshake");
|
||||||
|
|
||||||
if tls_len < 512 {
|
// See RFC 8446 §5.1: TLSPlaintext records must not exceed 16_384 bytes.
|
||||||
debug!(peer = %peer, tls_len = tls_len, "TLS handshake too short");
|
// Treat too-small or too-large lengths as active probes and mask them.
|
||||||
|
if !(512..=MAX_TLS_RECORD_SIZE).contains(&tls_len) {
|
||||||
|
debug!(peer = %peer, tls_len = tls_len, max_tls_len = MAX_TLS_RECORD_SIZE, "TLS handshake length out of bounds");
|
||||||
self.stats.increment_connects_bad();
|
self.stats.increment_connects_bad();
|
||||||
let (reader, writer) = self.stream.into_split();
|
let (reader, writer) = self.stream.into_split();
|
||||||
handle_bad_client(
|
handle_bad_client(
|
||||||
@@ -590,12 +655,19 @@ impl RunningClientHandler {
|
|||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
HandshakeResult::Success(result) => result,
|
HandshakeResult::Success(result) => result,
|
||||||
HandshakeResult::BadClient {
|
HandshakeResult::BadClient { reader, writer } => {
|
||||||
reader: _,
|
|
||||||
writer: _,
|
|
||||||
} => {
|
|
||||||
stats.increment_connects_bad();
|
stats.increment_connects_bad();
|
||||||
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
|
debug!(peer = %peer, "Valid TLS but invalid MTProto handshake");
|
||||||
|
handle_bad_client(
|
||||||
|
reader,
|
||||||
|
writer,
|
||||||
|
&mtproto_handshake,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&self.beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
return Ok(HandshakeOutcome::Handled);
|
return Ok(HandshakeOutcome::Handled);
|
||||||
}
|
}
|
||||||
HandshakeResult::Error(e) => return Err(e),
|
HandshakeResult::Error(e) => return Err(e),
|
||||||
@@ -742,7 +814,7 @@ impl RunningClientHandler {
|
|||||||
client_writer,
|
client_writer,
|
||||||
success,
|
success,
|
||||||
pool.clone(),
|
pool.clone(),
|
||||||
stats,
|
stats.clone(),
|
||||||
config,
|
config,
|
||||||
buffer_pool,
|
buffer_pool,
|
||||||
local_addr,
|
local_addr,
|
||||||
@@ -759,7 +831,7 @@ impl RunningClientHandler {
|
|||||||
client_writer,
|
client_writer,
|
||||||
success,
|
success,
|
||||||
upstream_manager,
|
upstream_manager,
|
||||||
stats,
|
stats.clone(),
|
||||||
config,
|
config,
|
||||||
buffer_pool,
|
buffer_pool,
|
||||||
rng,
|
rng,
|
||||||
@@ -776,7 +848,7 @@ impl RunningClientHandler {
|
|||||||
client_writer,
|
client_writer,
|
||||||
success,
|
success,
|
||||||
upstream_manager,
|
upstream_manager,
|
||||||
stats,
|
stats.clone(),
|
||||||
config,
|
config,
|
||||||
buffer_pool,
|
buffer_pool,
|
||||||
rng,
|
rng,
|
||||||
@@ -787,6 +859,7 @@ impl RunningClientHandler {
|
|||||||
.await
|
.await
|
||||||
};
|
};
|
||||||
|
|
||||||
|
stats.decrement_user_curr_connects(&user);
|
||||||
ip_tracker.remove_ip(&user, peer_addr.ip()).await;
|
ip_tracker.remove_ip(&user, peer_addr.ip()).await;
|
||||||
relay_result
|
relay_result
|
||||||
}
|
}
|
||||||
@@ -806,9 +879,29 @@ impl RunningClientHandler {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let ip_reserved = match ip_tracker.check_and_add(user, peer_addr.ip()).await {
|
if let Some(quota) = config.access.user_data_quota.get(user)
|
||||||
Ok(()) => true,
|
&& stats.get_user_total_octets(user) >= *quota
|
||||||
|
{
|
||||||
|
return Err(ProxyError::DataQuotaExceeded {
|
||||||
|
user: user.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let limit = config
|
||||||
|
.access
|
||||||
|
.user_max_tcp_conns
|
||||||
|
.get(user)
|
||||||
|
.map(|v| *v as u64);
|
||||||
|
if !stats.try_acquire_user_curr_connects(user, limit) {
|
||||||
|
return Err(ProxyError::ConnectionLimitExceeded {
|
||||||
|
user: user.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
|
||||||
|
Ok(()) => {}
|
||||||
Err(reason) => {
|
Err(reason) => {
|
||||||
|
stats.decrement_user_curr_connects(user);
|
||||||
warn!(
|
warn!(
|
||||||
user = %user,
|
user = %user,
|
||||||
ip = %peer_addr.ip(),
|
ip = %peer_addr.ip(),
|
||||||
@@ -819,33 +912,12 @@ impl RunningClientHandler {
|
|||||||
user: user.to_string(),
|
user: user.to_string(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
|
||||||
// IP limit check
|
|
||||||
|
|
||||||
if let Some(limit) = config.access.user_max_tcp_conns.get(user)
|
|
||||||
&& stats.get_user_curr_connects(user) >= *limit as u64
|
|
||||||
{
|
|
||||||
if ip_reserved {
|
|
||||||
ip_tracker.remove_ip(user, peer_addr.ip()).await;
|
|
||||||
stats.increment_ip_reservation_rollback_tcp_limit_total();
|
|
||||||
}
|
|
||||||
return Err(ProxyError::ConnectionLimitExceeded {
|
|
||||||
user: user.to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(quota) = config.access.user_data_quota.get(user)
|
|
||||||
&& stats.get_user_total_octets(user) >= *quota
|
|
||||||
{
|
|
||||||
if ip_reserved {
|
|
||||||
ip_tracker.remove_ip(user, peer_addr.ip()).await;
|
|
||||||
stats.increment_ip_reservation_rollback_quota_limit_total();
|
|
||||||
}
|
|
||||||
return Err(ProxyError::DataQuotaExceeded {
|
|
||||||
user: user.to_string(),
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[path = "client_security_tests.rs"]
|
||||||
|
mod security_tests;
|
||||||
|
|||||||
2071
src/proxy/client_security_tests.rs
Normal file
2071
src/proxy/client_security_tests.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,8 @@ use std::fs::OpenOptions;
|
|||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::sync::{Mutex, OnceLock};
|
||||||
|
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
@@ -22,6 +24,45 @@ use crate::stats::Stats;
|
|||||||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
||||||
use crate::transport::UpstreamManager;
|
use crate::transport::UpstreamManager;
|
||||||
|
|
||||||
|
const UNKNOWN_DC_LOG_DISTINCT_LIMIT: usize = 1024;
|
||||||
|
static LOGGED_UNKNOWN_DCS: OnceLock<Mutex<HashSet<i16>>> = OnceLock::new();
|
||||||
|
|
||||||
|
// In tests, this function shares global mutable state. Callers that also use
|
||||||
|
// cache-reset helpers must hold `unknown_dc_test_lock()` to keep assertions
|
||||||
|
// deterministic under parallel execution.
|
||||||
|
fn should_log_unknown_dc(dc_idx: i16) -> bool {
|
||||||
|
let set = LOGGED_UNKNOWN_DCS.get_or_init(|| Mutex::new(HashSet::new()));
|
||||||
|
match set.lock() {
|
||||||
|
Ok(mut guard) => {
|
||||||
|
if guard.contains(&dc_idx) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if guard.len() >= UNKNOWN_DC_LOG_DISTINCT_LIMIT {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
guard.insert(dc_idx)
|
||||||
|
}
|
||||||
|
// If the lock is poisoned, keep logging rather than silently dropping
|
||||||
|
// operator-visible diagnostics.
|
||||||
|
Err(_) => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn clear_unknown_dc_log_cache_for_testing() {
|
||||||
|
if let Some(set) = LOGGED_UNKNOWN_DCS.get()
|
||||||
|
&& let Ok(mut guard) = set.lock()
|
||||||
|
{
|
||||||
|
guard.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn unknown_dc_test_lock() -> &'static Mutex<()> {
|
||||||
|
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||||
|
TEST_LOCK.get_or_init(|| Mutex::new(()))
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) async fn handle_via_direct<R, W>(
|
pub(crate) async fn handle_via_direct<R, W>(
|
||||||
client_reader: CryptoReader<R>,
|
client_reader: CryptoReader<R>,
|
||||||
client_writer: CryptoWriter<W>,
|
client_writer: CryptoWriter<W>,
|
||||||
@@ -64,7 +105,6 @@ where
|
|||||||
debug!(peer = %success.peer, "TG handshake complete, starting relay");
|
debug!(peer = %success.peer, "TG handshake complete, starting relay");
|
||||||
|
|
||||||
stats.increment_user_connects(user);
|
stats.increment_user_connects(user);
|
||||||
stats.increment_user_curr_connects(user);
|
|
||||||
stats.increment_current_connections_direct();
|
stats.increment_current_connections_direct();
|
||||||
|
|
||||||
let relay_result = relay_bidirectional(
|
let relay_result = relay_bidirectional(
|
||||||
@@ -109,7 +149,6 @@ where
|
|||||||
};
|
};
|
||||||
|
|
||||||
stats.decrement_current_connections_direct();
|
stats.decrement_current_connections_direct();
|
||||||
stats.decrement_user_curr_connects(user);
|
|
||||||
|
|
||||||
match &relay_result {
|
match &relay_result {
|
||||||
Ok(()) => debug!(user = %user, "Direct relay completed"),
|
Ok(()) => debug!(user = %user, "Direct relay completed"),
|
||||||
@@ -160,6 +199,7 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
|
|||||||
warn!(dc_idx = dc_idx, "Requested non-standard DC with no override; falling back to default cluster");
|
warn!(dc_idx = dc_idx, "Requested non-standard DC with no override; falling back to default cluster");
|
||||||
if config.general.unknown_dc_file_log_enabled
|
if config.general.unknown_dc_file_log_enabled
|
||||||
&& let Some(path) = &config.general.unknown_dc_log_path
|
&& let Some(path) = &config.general.unknown_dc_log_path
|
||||||
|
&& should_log_unknown_dc(dc_idx)
|
||||||
&& let Ok(handle) = tokio::runtime::Handle::try_current()
|
&& let Ok(handle) = tokio::runtime::Handle::try_current()
|
||||||
{
|
{
|
||||||
let path = path.clone();
|
let path = path.clone();
|
||||||
@@ -175,7 +215,7 @@ fn get_dc_addr_static(dc_idx: i16, config: &ProxyConfig) -> Result<SocketAddr> {
|
|||||||
let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs {
|
let fallback_idx = if default_dc >= 1 && default_dc <= num_dcs {
|
||||||
default_dc - 1
|
default_dc - 1
|
||||||
} else {
|
} else {
|
||||||
1
|
0
|
||||||
};
|
};
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
@@ -203,8 +243,6 @@ async fn do_tg_handshake_static(
|
|||||||
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce(
|
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) = generate_tg_nonce(
|
||||||
success.proto_tag,
|
success.proto_tag,
|
||||||
success.dc_idx,
|
success.dc_idx,
|
||||||
&success.dec_key,
|
|
||||||
success.dec_iv,
|
|
||||||
&success.enc_key,
|
&success.enc_key,
|
||||||
success.enc_iv,
|
success.enc_iv,
|
||||||
rng,
|
rng,
|
||||||
@@ -230,3 +268,7 @@ async fn do_tg_handshake_static(
|
|||||||
CryptoWriter::new(write_half, tg_encryptor, max_pending),
|
CryptoWriter::new(write_half, tg_encryptor, max_pending),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[path = "direct_relay_security_tests.rs"]
|
||||||
|
mod security_tests;
|
||||||
|
|||||||
51
src/proxy/direct_relay_security_tests.rs
Normal file
51
src/proxy/direct_relay_security_tests.rs
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unknown_dc_log_is_deduplicated_per_dc_idx() {
|
||||||
|
let _guard = unknown_dc_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("unknown dc test lock must be available");
|
||||||
|
clear_unknown_dc_log_cache_for_testing();
|
||||||
|
|
||||||
|
assert!(should_log_unknown_dc(777));
|
||||||
|
assert!(
|
||||||
|
!should_log_unknown_dc(777),
|
||||||
|
"same unknown dc_idx must not be logged repeatedly"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
should_log_unknown_dc(778),
|
||||||
|
"different unknown dc_idx must still be loggable"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unknown_dc_log_respects_distinct_limit() {
|
||||||
|
let _guard = unknown_dc_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("unknown dc test lock must be available");
|
||||||
|
clear_unknown_dc_log_cache_for_testing();
|
||||||
|
|
||||||
|
for dc in 1..=UNKNOWN_DC_LOG_DISTINCT_LIMIT {
|
||||||
|
assert!(
|
||||||
|
should_log_unknown_dc(dc as i16),
|
||||||
|
"expected first-time unknown dc_idx to be loggable"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!should_log_unknown_dc(i16::MAX),
|
||||||
|
"distinct unknown dc_idx entries above limit must not be logged"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn fallback_dc_never_panics_with_single_dc_list() {
|
||||||
|
let mut cfg = ProxyConfig::default();
|
||||||
|
cfg.network.prefer = 6;
|
||||||
|
cfg.network.ipv6 = Some(true);
|
||||||
|
cfg.default_dc = Some(42);
|
||||||
|
|
||||||
|
let addr = get_dc_addr_static(999, &cfg).expect("fallback dc must resolve safely");
|
||||||
|
let expected = SocketAddr::new(TG_DATACENTERS_V6[0], TG_DATACENTER_PORT);
|
||||||
|
assert_eq!(addr, expected);
|
||||||
|
}
|
||||||
@@ -3,8 +3,15 @@
|
|||||||
#![allow(dead_code)]
|
#![allow(dead_code)]
|
||||||
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::net::{IpAddr, Ipv6Addr};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::sync::{Mutex, OnceLock};
|
||||||
|
use std::collections::hash_map::DefaultHasher;
|
||||||
|
use std::hash::{Hash, Hasher};
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use dashmap::mapref::entry::Entry;
|
||||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||||
use tracing::{debug, warn, trace};
|
use tracing::{debug, warn, trace};
|
||||||
use zeroize::Zeroize;
|
use zeroize::Zeroize;
|
||||||
@@ -19,6 +26,272 @@ use crate::stats::ReplayChecker;
|
|||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
use crate::tls_front::{TlsFrontCache, emulator};
|
use crate::tls_front::{TlsFrontCache, emulator};
|
||||||
|
|
||||||
|
const ACCESS_SECRET_BYTES: usize = 16;
|
||||||
|
static INVALID_SECRET_WARNED: OnceLock<Mutex<HashSet<(String, String)>>> = OnceLock::new();
|
||||||
|
|
||||||
|
const AUTH_PROBE_TRACK_RETENTION_SECS: u64 = 10 * 60;
|
||||||
|
#[cfg(test)]
|
||||||
|
const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 256;
|
||||||
|
#[cfg(not(test))]
|
||||||
|
const AUTH_PROBE_TRACK_MAX_ENTRIES: usize = 65_536;
|
||||||
|
const AUTH_PROBE_PRUNE_SCAN_LIMIT: usize = 1_024;
|
||||||
|
const AUTH_PROBE_BACKOFF_START_FAILS: u32 = 4;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 1;
|
||||||
|
#[cfg(not(test))]
|
||||||
|
const AUTH_PROBE_BACKOFF_BASE_MS: u64 = 25;
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 16;
|
||||||
|
#[cfg(not(test))]
|
||||||
|
const AUTH_PROBE_BACKOFF_MAX_MS: u64 = 1_000;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
struct AuthProbeState {
|
||||||
|
fail_streak: u32,
|
||||||
|
blocked_until: Instant,
|
||||||
|
last_seen: Instant,
|
||||||
|
}
|
||||||
|
|
||||||
|
static AUTH_PROBE_STATE: OnceLock<DashMap<IpAddr, AuthProbeState>> = OnceLock::new();
|
||||||
|
|
||||||
|
fn auth_probe_state_map() -> &'static DashMap<IpAddr, AuthProbeState> {
|
||||||
|
AUTH_PROBE_STATE.get_or_init(DashMap::new)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize_auth_probe_ip(peer_ip: IpAddr) -> IpAddr {
|
||||||
|
match peer_ip {
|
||||||
|
IpAddr::V4(ip) => IpAddr::V4(ip),
|
||||||
|
IpAddr::V6(ip) => {
|
||||||
|
let [a, b, c, d, _, _, _, _] = ip.segments();
|
||||||
|
IpAddr::V6(Ipv6Addr::new(a, b, c, d, 0, 0, 0, 0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn auth_probe_backoff(fail_streak: u32) -> Duration {
|
||||||
|
if fail_streak < AUTH_PROBE_BACKOFF_START_FAILS {
|
||||||
|
return Duration::ZERO;
|
||||||
|
}
|
||||||
|
let shift = (fail_streak - AUTH_PROBE_BACKOFF_START_FAILS).min(10);
|
||||||
|
let multiplier = 1u64.checked_shl(shift).unwrap_or(u64::MAX);
|
||||||
|
let ms = AUTH_PROBE_BACKOFF_BASE_MS
|
||||||
|
.saturating_mul(multiplier)
|
||||||
|
.min(AUTH_PROBE_BACKOFF_MAX_MS);
|
||||||
|
Duration::from_millis(ms)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn auth_probe_state_expired(state: &AuthProbeState, now: Instant) -> bool {
|
||||||
|
let retention = Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS);
|
||||||
|
now.duration_since(state.last_seen) > retention
|
||||||
|
}
|
||||||
|
|
||||||
|
fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize {
|
||||||
|
let mut hasher = DefaultHasher::new();
|
||||||
|
peer_ip.hash(&mut hasher);
|
||||||
|
now.hash(&mut hasher);
|
||||||
|
hasher.finish() as usize
|
||||||
|
}
|
||||||
|
|
||||||
|
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
|
||||||
|
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||||
|
let state = auth_probe_state_map();
|
||||||
|
let Some(entry) = state.get(&peer_ip) else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
if auth_probe_state_expired(&entry, now) {
|
||||||
|
drop(entry);
|
||||||
|
state.remove(&peer_ip);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
now < entry.blocked_until
|
||||||
|
}
|
||||||
|
|
||||||
|
fn auth_probe_record_failure(peer_ip: IpAddr, now: Instant) {
|
||||||
|
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||||
|
let state = auth_probe_state_map();
|
||||||
|
auth_probe_record_failure_with_state(state, peer_ip, now);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn auth_probe_record_failure_with_state(
|
||||||
|
state: &DashMap<IpAddr, AuthProbeState>,
|
||||||
|
peer_ip: IpAddr,
|
||||||
|
now: Instant,
|
||||||
|
) {
|
||||||
|
let make_new_state = || AuthProbeState {
|
||||||
|
fail_streak: 1,
|
||||||
|
blocked_until: now + auth_probe_backoff(1),
|
||||||
|
last_seen: now,
|
||||||
|
};
|
||||||
|
|
||||||
|
let update_existing = |entry: &mut AuthProbeState| {
|
||||||
|
if auth_probe_state_expired(entry, now) {
|
||||||
|
*entry = make_new_state();
|
||||||
|
} else {
|
||||||
|
entry.fail_streak = entry.fail_streak.saturating_add(1);
|
||||||
|
entry.last_seen = now;
|
||||||
|
entry.blocked_until = now + auth_probe_backoff(entry.fail_streak);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match state.entry(peer_ip) {
|
||||||
|
Entry::Occupied(mut entry) => {
|
||||||
|
update_existing(entry.get_mut());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Entry::Vacant(_) => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||||
|
let mut stale_keys = Vec::new();
|
||||||
|
let mut eviction_candidates = Vec::new();
|
||||||
|
for entry in state.iter().take(AUTH_PROBE_PRUNE_SCAN_LIMIT) {
|
||||||
|
eviction_candidates.push(*entry.key());
|
||||||
|
if auth_probe_state_expired(entry.value(), now) {
|
||||||
|
stale_keys.push(*entry.key());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for stale_key in stale_keys {
|
||||||
|
state.remove(&stale_key);
|
||||||
|
}
|
||||||
|
if state.len() >= AUTH_PROBE_TRACK_MAX_ENTRIES {
|
||||||
|
if eviction_candidates.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let idx = auth_probe_eviction_offset(peer_ip, now) % eviction_candidates.len();
|
||||||
|
let evict_key = eviction_candidates[idx];
|
||||||
|
state.remove(&evict_key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match state.entry(peer_ip) {
|
||||||
|
Entry::Occupied(mut entry) => {
|
||||||
|
update_existing(entry.get_mut());
|
||||||
|
}
|
||||||
|
Entry::Vacant(entry) => {
|
||||||
|
entry.insert(make_new_state());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn auth_probe_record_success(peer_ip: IpAddr) {
|
||||||
|
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||||
|
let state = auth_probe_state_map();
|
||||||
|
state.remove(&peer_ip);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn clear_auth_probe_state_for_testing() {
|
||||||
|
if let Some(state) = AUTH_PROBE_STATE.get() {
|
||||||
|
state.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn auth_probe_fail_streak_for_testing(peer_ip: IpAddr) -> Option<u32> {
|
||||||
|
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||||
|
let state = AUTH_PROBE_STATE.get()?;
|
||||||
|
state.get(&peer_ip).map(|entry| entry.fail_streak)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn auth_probe_is_throttled_for_testing(peer_ip: IpAddr) -> bool {
|
||||||
|
auth_probe_is_throttled(peer_ip, Instant::now())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn auth_probe_test_lock() -> &'static Mutex<()> {
|
||||||
|
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||||
|
TEST_LOCK.get_or_init(|| Mutex::new(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn clear_warned_secrets_for_testing() {
|
||||||
|
if let Some(warned) = INVALID_SECRET_WARNED.get()
|
||||||
|
&& let Ok(mut guard) = warned.lock()
|
||||||
|
{
|
||||||
|
guard.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn warned_secrets_test_lock() -> &'static Mutex<()> {
|
||||||
|
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||||
|
TEST_LOCK.get_or_init(|| Mutex::new(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn warn_invalid_secret_once(name: &str, reason: &str, expected: usize, got: Option<usize>) {
|
||||||
|
let key = (name.to_string(), reason.to_string());
|
||||||
|
let warned = INVALID_SECRET_WARNED.get_or_init(|| Mutex::new(HashSet::new()));
|
||||||
|
let should_warn = match warned.lock() {
|
||||||
|
Ok(mut guard) => guard.insert(key),
|
||||||
|
Err(_) => true,
|
||||||
|
};
|
||||||
|
|
||||||
|
if !should_warn {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
match got {
|
||||||
|
Some(actual) => {
|
||||||
|
warn!(
|
||||||
|
user = %name,
|
||||||
|
expected = expected,
|
||||||
|
got = actual,
|
||||||
|
"Skipping user: access secret has unexpected length"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
warn!(
|
||||||
|
user = %name,
|
||||||
|
"Skipping user: access secret is not valid hex"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decode_user_secret(name: &str, secret_hex: &str) -> Option<Vec<u8>> {
|
||||||
|
match hex::decode(secret_hex) {
|
||||||
|
Ok(bytes) if bytes.len() == ACCESS_SECRET_BYTES => Some(bytes),
|
||||||
|
Ok(bytes) => {
|
||||||
|
warn_invalid_secret_once(
|
||||||
|
name,
|
||||||
|
"invalid_length",
|
||||||
|
ACCESS_SECRET_BYTES,
|
||||||
|
Some(bytes.len()),
|
||||||
|
);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
warn_invalid_secret_once(name, "invalid_hex", ACCESS_SECRET_BYTES, None);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decide whether a client-supplied proto tag is allowed given the configured
|
||||||
|
// proxy modes and the transport that carried the handshake.
|
||||||
|
//
|
||||||
|
// A common mistake is to treat `modes.tls` and `modes.secure` as interchangeable
|
||||||
|
// even though they correspond to different transport profiles: `modes.tls` is
|
||||||
|
// for the TLS-fronted (EE-TLS) path, while `modes.secure` is for direct MTProto
|
||||||
|
// over TCP (DD). Enforcing this separation prevents an attacker from using a
|
||||||
|
// TLS-capable client to bypass the operator intent for the direct MTProto mode,
|
||||||
|
// and vice versa.
|
||||||
|
fn mode_enabled_for_proto(config: &ProxyConfig, proto_tag: ProtoTag, is_tls: bool) -> bool {
|
||||||
|
match proto_tag {
|
||||||
|
ProtoTag::Secure => {
|
||||||
|
if is_tls {
|
||||||
|
config.general.modes.tls
|
||||||
|
} else {
|
||||||
|
config.general.modes.secure
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn decode_user_secrets(
|
fn decode_user_secrets(
|
||||||
config: &ProxyConfig,
|
config: &ProxyConfig,
|
||||||
preferred_user: Option<&str>,
|
preferred_user: Option<&str>,
|
||||||
@@ -27,7 +300,7 @@ fn decode_user_secrets(
|
|||||||
|
|
||||||
if let Some(preferred) = preferred_user
|
if let Some(preferred) = preferred_user
|
||||||
&& let Some(secret_hex) = config.access.users.get(preferred)
|
&& let Some(secret_hex) = config.access.users.get(preferred)
|
||||||
&& let Ok(bytes) = hex::decode(secret_hex)
|
&& let Some(bytes) = decode_user_secret(preferred, secret_hex)
|
||||||
{
|
{
|
||||||
secrets.push((preferred.to_string(), bytes));
|
secrets.push((preferred.to_string(), bytes));
|
||||||
}
|
}
|
||||||
@@ -36,7 +309,7 @@ fn decode_user_secrets(
|
|||||||
if preferred_user.is_some_and(|preferred| preferred == name.as_str()) {
|
if preferred_user.is_some_and(|preferred| preferred == name.as_str()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if let Ok(bytes) = hex::decode(secret_hex) {
|
if let Some(bytes) = decode_user_secret(name, secret_hex) {
|
||||||
secrets.push((name.clone(), bytes));
|
secrets.push((name.clone(), bytes));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -48,7 +321,7 @@ fn decode_user_secrets(
|
|||||||
///
|
///
|
||||||
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
|
/// Key material (`dec_key`, `dec_iv`, `enc_key`, `enc_iv`) is
|
||||||
/// zeroized on drop.
|
/// zeroized on drop.
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug)]
|
||||||
pub struct HandshakeSuccess {
|
pub struct HandshakeSuccess {
|
||||||
/// Authenticated user name
|
/// Authenticated user name
|
||||||
pub user: String,
|
pub user: String,
|
||||||
@@ -94,28 +367,27 @@ where
|
|||||||
{
|
{
|
||||||
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
|
debug!(peer = %peer, handshake_len = handshake.len(), "Processing TLS handshake");
|
||||||
|
|
||||||
|
if auth_probe_is_throttled(peer.ip(), Instant::now()) {
|
||||||
|
debug!(peer = %peer, "TLS handshake rejected by pre-auth probe throttle");
|
||||||
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
|
}
|
||||||
|
|
||||||
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
|
if handshake.len() < tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 {
|
||||||
debug!(peer = %peer, "TLS handshake too short");
|
debug!(peer = %peer, "TLS handshake too short");
|
||||||
return HandshakeResult::BadClient { reader, writer };
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
}
|
}
|
||||||
|
|
||||||
let digest = &handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN];
|
|
||||||
let digest_half = &digest[..tls::TLS_DIGEST_HALF_LEN];
|
|
||||||
|
|
||||||
if replay_checker.check_and_add_tls_digest(digest_half) {
|
|
||||||
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
|
|
||||||
return HandshakeResult::BadClient { reader, writer };
|
|
||||||
}
|
|
||||||
|
|
||||||
let secrets = decode_user_secrets(config, None);
|
let secrets = decode_user_secrets(config, None);
|
||||||
|
|
||||||
let validation = match tls::validate_tls_handshake(
|
let validation = match tls::validate_tls_handshake_with_replay_window(
|
||||||
handshake,
|
handshake,
|
||||||
&secrets,
|
&secrets,
|
||||||
config.access.ignore_time_skew,
|
config.access.ignore_time_skew,
|
||||||
|
config.access.replay_window_secs,
|
||||||
) {
|
) {
|
||||||
Some(v) => v,
|
Some(v) => v,
|
||||||
None => {
|
None => {
|
||||||
|
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||||
debug!(
|
debug!(
|
||||||
peer = %peer,
|
peer = %peer,
|
||||||
ignore_time_skew = config.access.ignore_time_skew,
|
ignore_time_skew = config.access.ignore_time_skew,
|
||||||
@@ -125,6 +397,15 @@ where
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Replay tracking is applied only after successful authentication to avoid
|
||||||
|
// letting unauthenticated probes evict valid entries from the replay cache.
|
||||||
|
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
|
||||||
|
if replay_checker.check_and_add_tls_digest(digest_half) {
|
||||||
|
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||||
|
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
|
||||||
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
|
}
|
||||||
|
|
||||||
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
|
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
|
||||||
Some((_, s)) => s,
|
Some((_, s)) => s,
|
||||||
None => return HandshakeResult::BadClient { reader, writer },
|
None => return HandshakeResult::BadClient { reader, writer },
|
||||||
@@ -166,6 +447,9 @@ where
|
|||||||
Some(b"h2".to_vec())
|
Some(b"h2".to_vec())
|
||||||
} else if alpn_list.iter().any(|p| p == b"http/1.1") {
|
} else if alpn_list.iter().any(|p| p == b"http/1.1") {
|
||||||
Some(b"http/1.1".to_vec())
|
Some(b"http/1.1".to_vec())
|
||||||
|
} else if !alpn_list.is_empty() {
|
||||||
|
debug!(peer = %peer, "Client ALPN list has no supported protocol; using masking fallback");
|
||||||
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
@@ -228,6 +512,8 @@ where
|
|||||||
"TLS handshake successful"
|
"TLS handshake successful"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
auth_probe_record_success(peer.ip());
|
||||||
|
|
||||||
HandshakeResult::Success((
|
HandshakeResult::Success((
|
||||||
FakeTlsReader::new(reader),
|
FakeTlsReader::new(reader),
|
||||||
FakeTlsWriter::new(writer),
|
FakeTlsWriter::new(writer),
|
||||||
@@ -252,13 +538,13 @@ where
|
|||||||
{
|
{
|
||||||
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
|
trace!(peer = %peer, handshake = ?hex::encode(handshake), "MTProto handshake bytes");
|
||||||
|
|
||||||
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
|
if auth_probe_is_throttled(peer.ip(), Instant::now()) {
|
||||||
|
debug!(peer = %peer, "MTProto handshake rejected by pre-auth probe throttle");
|
||||||
if replay_checker.check_and_add_handshake(dec_prekey_iv) {
|
|
||||||
warn!(peer = %peer, "MTProto replay attack detected");
|
|
||||||
return HandshakeResult::BadClient { reader, writer };
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let dec_prekey_iv = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
|
||||||
|
|
||||||
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
|
let enc_prekey_iv: Vec<u8> = dec_prekey_iv.iter().rev().copied().collect();
|
||||||
|
|
||||||
let decoded_users = decode_user_secrets(config, preferred_user);
|
let decoded_users = decode_user_secrets(config, preferred_user);
|
||||||
@@ -273,39 +559,33 @@ where
|
|||||||
dec_key_input.extend_from_slice(&secret);
|
dec_key_input.extend_from_slice(&secret);
|
||||||
let dec_key = sha256(&dec_key_input);
|
let dec_key = sha256(&dec_key_input);
|
||||||
|
|
||||||
let dec_iv = u128::from_be_bytes(dec_iv_bytes.try_into().unwrap());
|
let mut dec_iv_arr = [0u8; IV_LEN];
|
||||||
|
dec_iv_arr.copy_from_slice(dec_iv_bytes);
|
||||||
|
let dec_iv = u128::from_be_bytes(dec_iv_arr);
|
||||||
|
|
||||||
let mut decryptor = AesCtr::new(&dec_key, dec_iv);
|
let mut decryptor = AesCtr::new(&dec_key, dec_iv);
|
||||||
let decrypted = decryptor.decrypt(handshake);
|
let decrypted = decryptor.decrypt(handshake);
|
||||||
|
|
||||||
let tag_bytes: [u8; 4] = decrypted[PROTO_TAG_POS..PROTO_TAG_POS + 4]
|
let tag_bytes: [u8; 4] = [
|
||||||
.try_into()
|
decrypted[PROTO_TAG_POS],
|
||||||
.unwrap();
|
decrypted[PROTO_TAG_POS + 1],
|
||||||
|
decrypted[PROTO_TAG_POS + 2],
|
||||||
|
decrypted[PROTO_TAG_POS + 3],
|
||||||
|
];
|
||||||
|
|
||||||
let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
|
let proto_tag = match ProtoTag::from_bytes(tag_bytes) {
|
||||||
Some(tag) => tag,
|
Some(tag) => tag,
|
||||||
None => continue,
|
None => continue,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mode_ok = match proto_tag {
|
let mode_ok = mode_enabled_for_proto(config, proto_tag, is_tls);
|
||||||
ProtoTag::Secure => {
|
|
||||||
if is_tls {
|
|
||||||
config.general.modes.tls || config.general.modes.secure
|
|
||||||
} else {
|
|
||||||
config.general.modes.secure || config.general.modes.tls
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ProtoTag::Intermediate | ProtoTag::Abridged => config.general.modes.classic,
|
|
||||||
};
|
|
||||||
|
|
||||||
if !mode_ok {
|
if !mode_ok {
|
||||||
debug!(peer = %peer, user = %user, proto = ?proto_tag, "Mode not enabled");
|
debug!(peer = %peer, user = %user, proto = ?proto_tag, "Mode not enabled");
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let dc_idx = i16::from_le_bytes(
|
let dc_idx = i16::from_le_bytes([decrypted[DC_IDX_POS], decrypted[DC_IDX_POS + 1]]);
|
||||||
decrypted[DC_IDX_POS..DC_IDX_POS + 2].try_into().unwrap()
|
|
||||||
);
|
|
||||||
|
|
||||||
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
|
let enc_prekey = &enc_prekey_iv[..PREKEY_LEN];
|
||||||
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
|
let enc_iv_bytes = &enc_prekey_iv[PREKEY_LEN..];
|
||||||
@@ -315,10 +595,24 @@ where
|
|||||||
enc_key_input.extend_from_slice(&secret);
|
enc_key_input.extend_from_slice(&secret);
|
||||||
let enc_key = sha256(&enc_key_input);
|
let enc_key = sha256(&enc_key_input);
|
||||||
|
|
||||||
let enc_iv = u128::from_be_bytes(enc_iv_bytes.try_into().unwrap());
|
let mut enc_iv_arr = [0u8; IV_LEN];
|
||||||
|
enc_iv_arr.copy_from_slice(enc_iv_bytes);
|
||||||
|
let enc_iv = u128::from_be_bytes(enc_iv_arr);
|
||||||
|
|
||||||
let encryptor = AesCtr::new(&enc_key, enc_iv);
|
let encryptor = AesCtr::new(&enc_key, enc_iv);
|
||||||
|
|
||||||
|
// Apply replay tracking only after successful authentication.
|
||||||
|
//
|
||||||
|
// This ordering prevents an attacker from producing invalid handshakes that
|
||||||
|
// still collide with a valid handshake's replay slot and thus evict a valid
|
||||||
|
// entry from the cache. We accept the cost of performing the full
|
||||||
|
// authentication check first to avoid poisoning the replay cache.
|
||||||
|
if replay_checker.check_and_add_handshake(dec_prekey_iv) {
|
||||||
|
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||||
|
warn!(peer = %peer, user = %user, "MTProto replay attack detected");
|
||||||
|
return HandshakeResult::BadClient { reader, writer };
|
||||||
|
}
|
||||||
|
|
||||||
let success = HandshakeSuccess {
|
let success = HandshakeSuccess {
|
||||||
user: user.clone(),
|
user: user.clone(),
|
||||||
dc_idx,
|
dc_idx,
|
||||||
@@ -340,6 +634,8 @@ where
|
|||||||
"MTProto handshake successful"
|
"MTProto handshake successful"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
auth_probe_record_success(peer.ip());
|
||||||
|
|
||||||
let max_pending = config.general.crypto_pending_buffer;
|
let max_pending = config.general.crypto_pending_buffer;
|
||||||
return HandshakeResult::Success((
|
return HandshakeResult::Success((
|
||||||
CryptoReader::new(reader, decryptor),
|
CryptoReader::new(reader, decryptor),
|
||||||
@@ -348,6 +644,7 @@ where
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||||
debug!(peer = %peer, "MTProto handshake: no matching user found");
|
debug!(peer = %peer, "MTProto handshake: no matching user found");
|
||||||
HandshakeResult::BadClient { reader, writer }
|
HandshakeResult::BadClient { reader, writer }
|
||||||
}
|
}
|
||||||
@@ -356,8 +653,6 @@ where
|
|||||||
pub fn generate_tg_nonce(
|
pub fn generate_tg_nonce(
|
||||||
proto_tag: ProtoTag,
|
proto_tag: ProtoTag,
|
||||||
dc_idx: i16,
|
dc_idx: i16,
|
||||||
_client_dec_key: &[u8; 32],
|
|
||||||
_client_dec_iv: u128,
|
|
||||||
client_enc_key: &[u8; 32],
|
client_enc_key: &[u8; 32],
|
||||||
client_enc_iv: u128,
|
client_enc_iv: u128,
|
||||||
rng: &SecureRandom,
|
rng: &SecureRandom,
|
||||||
@@ -365,14 +660,16 @@ pub fn generate_tg_nonce(
|
|||||||
) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) {
|
) -> ([u8; HANDSHAKE_LEN], [u8; 32], u128, [u8; 32], u128) {
|
||||||
loop {
|
loop {
|
||||||
let bytes = rng.bytes(HANDSHAKE_LEN);
|
let bytes = rng.bytes(HANDSHAKE_LEN);
|
||||||
let mut nonce: [u8; HANDSHAKE_LEN] = bytes.try_into().unwrap();
|
let Ok(mut nonce): Result<[u8; HANDSHAKE_LEN], _> = bytes.try_into() else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; }
|
if RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]) { continue; }
|
||||||
|
|
||||||
let first_four: [u8; 4] = nonce[..4].try_into().unwrap();
|
let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]];
|
||||||
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; }
|
if RESERVED_NONCE_BEGINNINGS.contains(&first_four) { continue; }
|
||||||
|
|
||||||
let continue_four: [u8; 4] = nonce[4..8].try_into().unwrap();
|
let continue_four: [u8; 4] = [nonce[4], nonce[5], nonce[6], nonce[7]];
|
||||||
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; }
|
if RESERVED_NONCE_CONTINUES.contains(&continue_four) { continue; }
|
||||||
|
|
||||||
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
|
nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
|
||||||
@@ -390,11 +687,17 @@ pub fn generate_tg_nonce(
|
|||||||
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
||||||
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
|
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
|
||||||
|
|
||||||
let tg_enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
|
let mut tg_enc_key = [0u8; 32];
|
||||||
let tg_enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
|
tg_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]);
|
||||||
|
let mut tg_enc_iv_arr = [0u8; IV_LEN];
|
||||||
|
tg_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]);
|
||||||
|
let tg_enc_iv = u128::from_be_bytes(tg_enc_iv_arr);
|
||||||
|
|
||||||
let tg_dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
|
let mut tg_dec_key = [0u8; 32];
|
||||||
let tg_dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
|
tg_dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]);
|
||||||
|
let mut tg_dec_iv_arr = [0u8; IV_LEN];
|
||||||
|
tg_dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]);
|
||||||
|
let tg_dec_iv = u128::from_be_bytes(tg_dec_iv_arr);
|
||||||
|
|
||||||
return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv);
|
return (nonce, tg_enc_key, tg_enc_iv, tg_dec_key, tg_dec_iv);
|
||||||
}
|
}
|
||||||
@@ -405,11 +708,17 @@ pub fn encrypt_tg_nonce_with_ciphers(nonce: &[u8; HANDSHAKE_LEN]) -> (Vec<u8>, A
|
|||||||
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
|
||||||
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
|
let dec_key_iv: Vec<u8> = enc_key_iv.iter().rev().copied().collect();
|
||||||
|
|
||||||
let enc_key: [u8; 32] = enc_key_iv[..KEY_LEN].try_into().unwrap();
|
let mut enc_key = [0u8; 32];
|
||||||
let enc_iv = u128::from_be_bytes(enc_key_iv[KEY_LEN..].try_into().unwrap());
|
enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]);
|
||||||
|
let mut enc_iv_arr = [0u8; IV_LEN];
|
||||||
|
enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]);
|
||||||
|
let enc_iv = u128::from_be_bytes(enc_iv_arr);
|
||||||
|
|
||||||
let dec_key: [u8; 32] = dec_key_iv[..KEY_LEN].try_into().unwrap();
|
let mut dec_key = [0u8; 32];
|
||||||
let dec_iv = u128::from_be_bytes(dec_key_iv[KEY_LEN..].try_into().unwrap());
|
dec_key.copy_from_slice(&dec_key_iv[..KEY_LEN]);
|
||||||
|
let mut dec_iv_arr = [0u8; IV_LEN];
|
||||||
|
dec_iv_arr.copy_from_slice(&dec_key_iv[KEY_LEN..]);
|
||||||
|
let dec_iv = u128::from_be_bytes(dec_iv_arr);
|
||||||
|
|
||||||
let mut encryptor = AesCtr::new(&enc_key, enc_iv);
|
let mut encryptor = AesCtr::new(&enc_key, enc_iv);
|
||||||
let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4
|
let encrypted_full = encryptor.encrypt(nonce); // counter: 0 → 4
|
||||||
@@ -429,80 +738,15 @@ pub fn encrypt_tg_nonce(nonce: &[u8; HANDSHAKE_LEN]) -> Vec<u8> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
#[path = "handshake_security_tests.rs"]
|
||||||
use super::*;
|
mod security_tests;
|
||||||
|
|
||||||
#[test]
|
/// Compile-time guard: HandshakeSuccess holds cryptographic key material and
|
||||||
fn test_generate_tg_nonce() {
|
/// must never be Copy. A Copy impl would allow silent key duplication,
|
||||||
let client_dec_key = [0x42u8; 32];
|
/// undermining the zeroize-on-drop guarantee.
|
||||||
let client_dec_iv = 12345u128;
|
mod compile_time_security_checks {
|
||||||
let client_enc_key = [0x24u8; 32];
|
use super::HandshakeSuccess;
|
||||||
let client_enc_iv = 54321u128;
|
use static_assertions::assert_not_impl_all;
|
||||||
|
|
||||||
let rng = SecureRandom::new();
|
assert_not_impl_all!(HandshakeSuccess: Copy, Clone);
|
||||||
let (nonce, _tg_enc_key, _tg_enc_iv, _tg_dec_key, _tg_dec_iv) =
|
|
||||||
generate_tg_nonce(
|
|
||||||
ProtoTag::Secure,
|
|
||||||
2,
|
|
||||||
&client_dec_key,
|
|
||||||
client_dec_iv,
|
|
||||||
&client_enc_key,
|
|
||||||
client_enc_iv,
|
|
||||||
&rng,
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
|
|
||||||
assert_eq!(nonce.len(), HANDSHAKE_LEN);
|
|
||||||
|
|
||||||
let tag_bytes: [u8; 4] = nonce[PROTO_TAG_POS..PROTO_TAG_POS + 4].try_into().unwrap();
|
|
||||||
assert_eq!(ProtoTag::from_bytes(tag_bytes), Some(ProtoTag::Secure));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_encrypt_tg_nonce() {
|
|
||||||
let client_dec_key = [0x42u8; 32];
|
|
||||||
let client_dec_iv = 12345u128;
|
|
||||||
let client_enc_key = [0x24u8; 32];
|
|
||||||
let client_enc_iv = 54321u128;
|
|
||||||
|
|
||||||
let rng = SecureRandom::new();
|
|
||||||
let (nonce, _, _, _, _) =
|
|
||||||
generate_tg_nonce(
|
|
||||||
ProtoTag::Secure,
|
|
||||||
2,
|
|
||||||
&client_dec_key,
|
|
||||||
client_dec_iv,
|
|
||||||
&client_enc_key,
|
|
||||||
client_enc_iv,
|
|
||||||
&rng,
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
|
|
||||||
let encrypted = encrypt_tg_nonce(&nonce);
|
|
||||||
|
|
||||||
assert_eq!(encrypted.len(), HANDSHAKE_LEN);
|
|
||||||
assert_eq!(&encrypted[..PROTO_TAG_POS], &nonce[..PROTO_TAG_POS]);
|
|
||||||
assert_ne!(&encrypted[PROTO_TAG_POS..], &nonce[PROTO_TAG_POS..]);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_handshake_success_zeroize_on_drop() {
|
|
||||||
let success = HandshakeSuccess {
|
|
||||||
user: "test".to_string(),
|
|
||||||
dc_idx: 2,
|
|
||||||
proto_tag: ProtoTag::Secure,
|
|
||||||
dec_key: [0xAA; 32],
|
|
||||||
dec_iv: 0xBBBBBBBB,
|
|
||||||
enc_key: [0xCC; 32],
|
|
||||||
enc_iv: 0xDDDDDDDD,
|
|
||||||
peer: "127.0.0.1:1234".parse().unwrap(),
|
|
||||||
is_tls: true,
|
|
||||||
};
|
|
||||||
|
|
||||||
assert_eq!(success.dec_key, [0xAA; 32]);
|
|
||||||
assert_eq!(success.enc_key, [0xCC; 32]);
|
|
||||||
|
|
||||||
drop(success);
|
|
||||||
// Drop impl zeroizes key material without panic
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
1110
src/proxy/handshake_security_tests.rs
Normal file
1110
src/proxy/handshake_security_tests.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -14,12 +14,41 @@ use crate::network::dns_overrides::resolve_socket_addr;
|
|||||||
use crate::stats::beobachten::BeobachtenStore;
|
use crate::stats::beobachten::BeobachtenStore;
|
||||||
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
|
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
|
||||||
|
|
||||||
|
#[cfg(not(test))]
|
||||||
const MASK_TIMEOUT: Duration = Duration::from_secs(5);
|
const MASK_TIMEOUT: Duration = Duration::from_secs(5);
|
||||||
|
#[cfg(test)]
|
||||||
|
const MASK_TIMEOUT: Duration = Duration::from_millis(50);
|
||||||
/// Maximum duration for the entire masking relay.
|
/// Maximum duration for the entire masking relay.
|
||||||
/// Limits resource consumption from slow-loris attacks and port scanners.
|
/// Limits resource consumption from slow-loris attacks and port scanners.
|
||||||
|
#[cfg(not(test))]
|
||||||
const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60);
|
const MASK_RELAY_TIMEOUT: Duration = Duration::from_secs(60);
|
||||||
|
#[cfg(test)]
|
||||||
|
const MASK_RELAY_TIMEOUT: Duration = Duration::from_millis(200);
|
||||||
const MASK_BUFFER_SIZE: usize = 8192;
|
const MASK_BUFFER_SIZE: usize = 8192;
|
||||||
|
|
||||||
|
async fn write_proxy_header_with_timeout<W>(mask_write: &mut W, header: &[u8]) -> bool
|
||||||
|
where
|
||||||
|
W: AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
match timeout(MASK_TIMEOUT, mask_write.write_all(header)).await {
|
||||||
|
Ok(Ok(())) => true,
|
||||||
|
Ok(Err(_)) => false,
|
||||||
|
Err(_) => {
|
||||||
|
debug!("Timeout writing proxy protocol header to mask backend");
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn consume_client_data_with_timeout<R>(reader: R)
|
||||||
|
where
|
||||||
|
R: AsyncRead + Unpin,
|
||||||
|
{
|
||||||
|
if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader)).await.is_err() {
|
||||||
|
debug!("Timed out while consuming client data on masking fallback path");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Detect client type based on initial data
|
/// Detect client type based on initial data
|
||||||
fn detect_client_type(data: &[u8]) -> &'static str {
|
fn detect_client_type(data: &[u8]) -> &'static str {
|
||||||
// Check for HTTP request
|
// Check for HTTP request
|
||||||
@@ -71,7 +100,7 @@ where
|
|||||||
|
|
||||||
if !config.censorship.mask {
|
if !config.censorship.mask {
|
||||||
// Masking disabled, just consume data
|
// Masking disabled, just consume data
|
||||||
consume_client_data(reader).await;
|
consume_client_data_with_timeout(reader).await;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,7 +136,7 @@ where
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
if let Some(header) = proxy_header {
|
if let Some(header) = proxy_header {
|
||||||
if mask_write.write_all(&header).await.is_err() {
|
if !write_proxy_header_with_timeout(&mut mask_write, &header).await {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -117,11 +146,11 @@ where
|
|||||||
}
|
}
|
||||||
Ok(Err(e)) => {
|
Ok(Err(e)) => {
|
||||||
debug!(error = %e, "Failed to connect to mask unix socket");
|
debug!(error = %e, "Failed to connect to mask unix socket");
|
||||||
consume_client_data(reader).await;
|
consume_client_data_with_timeout(reader).await;
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
debug!("Timeout connecting to mask unix socket");
|
debug!("Timeout connecting to mask unix socket");
|
||||||
consume_client_data(reader).await;
|
consume_client_data_with_timeout(reader).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
@@ -166,7 +195,7 @@ where
|
|||||||
|
|
||||||
let (mask_read, mut mask_write) = stream.into_split();
|
let (mask_read, mut mask_write) = stream.into_split();
|
||||||
if let Some(header) = proxy_header {
|
if let Some(header) = proxy_header {
|
||||||
if mask_write.write_all(&header).await.is_err() {
|
if !write_proxy_header_with_timeout(&mut mask_write, &header).await {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -176,11 +205,11 @@ where
|
|||||||
}
|
}
|
||||||
Ok(Err(e)) => {
|
Ok(Err(e)) => {
|
||||||
debug!(error = %e, "Failed to connect to mask host");
|
debug!(error = %e, "Failed to connect to mask host");
|
||||||
consume_client_data(reader).await;
|
consume_client_data_with_timeout(reader).await;
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
debug!("Timeout connecting to mask host");
|
debug!("Timeout connecting to mask host");
|
||||||
consume_client_data(reader).await;
|
consume_client_data_with_timeout(reader).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -203,47 +232,20 @@ where
|
|||||||
if mask_write.write_all(initial_data).await.is_err() {
|
if mask_write.write_all(initial_data).await.is_err() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if mask_write.flush().await.is_err() {
|
||||||
// Relay traffic
|
return;
|
||||||
let c2m = tokio::spawn(async move {
|
|
||||||
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
|
|
||||||
loop {
|
|
||||||
match reader.read(&mut buf).await {
|
|
||||||
Ok(0) | Err(_) => {
|
|
||||||
let _ = mask_write.shutdown().await;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
Ok(n) => {
|
|
||||||
if mask_write.write_all(&buf[..n]).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
let m2c = tokio::spawn(async move {
|
|
||||||
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
|
|
||||||
loop {
|
|
||||||
match mask_read.read(&mut buf).await {
|
|
||||||
Ok(0) | Err(_) => {
|
|
||||||
let _ = writer.shutdown().await;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
Ok(n) => {
|
|
||||||
if writer.write_all(&buf[..n]).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Wait for either to complete
|
|
||||||
tokio::select! {
|
|
||||||
_ = c2m => {}
|
|
||||||
_ = m2c => {}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let _ = tokio::join!(
|
||||||
|
async {
|
||||||
|
let _ = tokio::io::copy(&mut reader, &mut mask_write).await;
|
||||||
|
let _ = mask_write.shutdown().await;
|
||||||
|
},
|
||||||
|
async {
|
||||||
|
let _ = tokio::io::copy(&mut mask_read, &mut writer).await;
|
||||||
|
let _ = writer.shutdown().await;
|
||||||
|
}
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Just consume all data from client without responding
|
/// Just consume all data from client without responding
|
||||||
@@ -255,3 +257,7 @@ async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
#[path = "masking_security_tests.rs"]
|
||||||
|
mod security_tests;
|
||||||
|
|||||||
731
src/proxy/masking_security_tests.rs
Normal file
731
src/proxy/masking_security_tests.rs
Normal file
@@ -0,0 +1,731 @@
|
|||||||
|
use super::*;
|
||||||
|
use crate::config::ProxyConfig;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::task::{Context, Poll};
|
||||||
|
use tokio::io::{duplex, AsyncBufReadExt, BufReader};
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
#[cfg(unix)]
|
||||||
|
use tokio::net::UnixListener;
|
||||||
|
use tokio::time::{sleep, timeout, Duration};
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn bad_client_probe_is_forwarded_verbatim_to_mask_backend() {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let backend_addr = listener.local_addr().unwrap();
|
||||||
|
let probe = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
|
||||||
|
let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
|
||||||
|
|
||||||
|
let accept_task = tokio::spawn({
|
||||||
|
let probe = probe.clone();
|
||||||
|
let backend_reply = backend_reply.clone();
|
||||||
|
async move {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
let mut received = vec![0u8; probe.len()];
|
||||||
|
stream.read_exact(&mut received).await.unwrap();
|
||||||
|
assert_eq!(received, probe);
|
||||||
|
stream.write_all(&backend_reply).await.unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = backend_addr.port();
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "203.0.113.10:42424".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
let (client_reader, _client_writer) = duplex(256);
|
||||||
|
let (mut client_visible_reader, client_visible_writer) = duplex(2048);
|
||||||
|
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
&probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let mut observed = vec![0u8; backend_reply.len()];
|
||||||
|
client_visible_reader.read_exact(&mut observed).await.unwrap();
|
||||||
|
assert_eq!(observed, backend_reply);
|
||||||
|
accept_task.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn tls_scanner_probe_keeps_http_like_fallback_surface() {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let backend_addr = listener.local_addr().unwrap();
|
||||||
|
let probe = vec![0x16, 0x03, 0x01, 0x00, 0x10, 0x01, 0x02, 0x03, 0x04];
|
||||||
|
let backend_reply = b"HTTP/1.1 400 Bad Request\r\nContent-Length: 0\r\n\r\n".to_vec();
|
||||||
|
|
||||||
|
let accept_task = tokio::spawn({
|
||||||
|
let probe = probe.clone();
|
||||||
|
let backend_reply = backend_reply.clone();
|
||||||
|
async move {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
let mut received = vec![0u8; probe.len()];
|
||||||
|
stream.read_exact(&mut received).await.unwrap();
|
||||||
|
assert_eq!(received, probe);
|
||||||
|
stream.write_all(&backend_reply).await.unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = true;
|
||||||
|
config.general.beobachten_minutes = 1;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = backend_addr.port();
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "198.51.100.44:55221".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
let (client_reader, _client_writer) = duplex(256);
|
||||||
|
let (mut client_visible_reader, client_visible_writer) = duplex(2048);
|
||||||
|
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
&probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let mut observed = vec![0u8; backend_reply.len()];
|
||||||
|
client_visible_reader.read_exact(&mut observed).await.unwrap();
|
||||||
|
assert_eq!(observed, backend_reply);
|
||||||
|
|
||||||
|
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
|
||||||
|
assert!(snapshot.contains("[TLS-scanner]"));
|
||||||
|
assert!(snapshot.contains("198.51.100.44-1"));
|
||||||
|
accept_task.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn detect_client_type_covers_ssh_port_scanner_and_unknown() {
|
||||||
|
assert_eq!(detect_client_type(b"SSH-2.0-OpenSSH_9.7"), "SSH");
|
||||||
|
assert_eq!(detect_client_type(b"\x01\x02\x03"), "port-scanner");
|
||||||
|
assert_eq!(detect_client_type(b"random-binary-payload"), "unknown");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn detect_client_type_len_boundary_9_vs_10_bytes() {
|
||||||
|
assert_eq!(detect_client_type(b"123456789"), "port-scanner");
|
||||||
|
assert_eq!(detect_client_type(b"1234567890"), "unknown");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn beobachten_records_scanner_class_when_mask_is_disabled() {
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = true;
|
||||||
|
config.general.beobachten_minutes = 1;
|
||||||
|
config.censorship.mask = false;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "203.0.113.99:41234".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
let initial = b"SSH-2.0-probe";
|
||||||
|
|
||||||
|
let (mut client_reader_side, client_reader) = duplex(256);
|
||||||
|
let (_client_visible_reader, client_visible_writer) = duplex(256);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let task = tokio::spawn(async move {
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
initial,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
beobachten
|
||||||
|
});
|
||||||
|
|
||||||
|
client_reader_side.write_all(b"noise").await.unwrap();
|
||||||
|
drop(client_reader_side);
|
||||||
|
|
||||||
|
let beobachten = timeout(Duration::from_secs(3), task).await.unwrap().unwrap();
|
||||||
|
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
|
||||||
|
assert!(snapshot.contains("[SSH]"));
|
||||||
|
assert!(snapshot.contains("203.0.113.99-1"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn backend_unavailable_falls_back_to_silent_consume() {
|
||||||
|
let temp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let unused_port = temp_listener.local_addr().unwrap().port();
|
||||||
|
drop(temp_listener);
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = unused_port;
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "203.0.113.11:42425".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
let probe = b"GET /probe HTTP/1.1\r\nHost: x\r\n\r\n";
|
||||||
|
|
||||||
|
let (mut client_reader_side, client_reader) = duplex(256);
|
||||||
|
let (mut client_visible_reader, client_visible_writer) = duplex(256);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let task = tokio::spawn(async move {
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
|
||||||
|
client_reader_side.write_all(b"noise").await.unwrap();
|
||||||
|
drop(client_reader_side);
|
||||||
|
|
||||||
|
timeout(Duration::from_secs(3), task).await.unwrap().unwrap();
|
||||||
|
|
||||||
|
let mut buf = [0u8; 1];
|
||||||
|
let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf))
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(n, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn mask_disabled_consumes_client_data_without_response() {
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = false;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "198.51.100.12:45454".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
let initial = b"scanner";
|
||||||
|
|
||||||
|
let (mut client_reader_side, client_reader) = duplex(256);
|
||||||
|
let (mut client_visible_reader, client_visible_writer) = duplex(256);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let task = tokio::spawn(async move {
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
initial,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
|
||||||
|
client_reader_side.write_all(b"untrusted payload").await.unwrap();
|
||||||
|
drop(client_reader_side);
|
||||||
|
|
||||||
|
timeout(Duration::from_secs(3), task).await.unwrap().unwrap();
|
||||||
|
|
||||||
|
let mut buf = [0u8; 1];
|
||||||
|
let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf))
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(n, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn proxy_protocol_v1_header_is_sent_before_probe() {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let backend_addr = listener.local_addr().unwrap();
|
||||||
|
let probe = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
|
||||||
|
let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
|
||||||
|
|
||||||
|
let accept_task = tokio::spawn({
|
||||||
|
let probe = probe.clone();
|
||||||
|
let backend_reply = backend_reply.clone();
|
||||||
|
async move {
|
||||||
|
let (stream, _) = listener.accept().await.unwrap();
|
||||||
|
let mut reader = BufReader::new(stream);
|
||||||
|
|
||||||
|
let mut header_line = Vec::new();
|
||||||
|
reader.read_until(b'\n', &mut header_line).await.unwrap();
|
||||||
|
let header_text = String::from_utf8(header_line.clone()).unwrap();
|
||||||
|
assert!(header_text.starts_with("PROXY TCP4 "));
|
||||||
|
assert!(header_text.ends_with("\r\n"));
|
||||||
|
|
||||||
|
let mut received_probe = vec![0u8; probe.len()];
|
||||||
|
reader.read_exact(&mut received_probe).await.unwrap();
|
||||||
|
assert_eq!(received_probe, probe);
|
||||||
|
|
||||||
|
let mut stream = reader.into_inner();
|
||||||
|
stream.write_all(&backend_reply).await.unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = backend_addr.port();
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 1;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "203.0.113.15:50001".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
let (client_reader, _client_writer) = duplex(256);
|
||||||
|
let (mut client_visible_reader, client_visible_writer) = duplex(2048);
|
||||||
|
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
&probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let mut observed = vec![0u8; backend_reply.len()];
|
||||||
|
client_visible_reader.read_exact(&mut observed).await.unwrap();
|
||||||
|
assert_eq!(observed, backend_reply);
|
||||||
|
accept_task.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn proxy_protocol_v2_header_is_sent_before_probe() {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let backend_addr = listener.local_addr().unwrap();
|
||||||
|
let probe = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
|
||||||
|
let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n".to_vec();
|
||||||
|
|
||||||
|
let accept_task = tokio::spawn({
|
||||||
|
let probe = probe.clone();
|
||||||
|
let backend_reply = backend_reply.clone();
|
||||||
|
async move {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
|
||||||
|
let mut sig = [0u8; 12];
|
||||||
|
stream.read_exact(&mut sig).await.unwrap();
|
||||||
|
assert_eq!(&sig, b"\r\n\r\n\0\r\nQUIT\n");
|
||||||
|
|
||||||
|
let mut fixed = [0u8; 4];
|
||||||
|
stream.read_exact(&mut fixed).await.unwrap();
|
||||||
|
let addr_len = u16::from_be_bytes([fixed[2], fixed[3]]) as usize;
|
||||||
|
|
||||||
|
let mut addr_block = vec![0u8; addr_len];
|
||||||
|
stream.read_exact(&mut addr_block).await.unwrap();
|
||||||
|
|
||||||
|
let mut received_probe = vec![0u8; probe.len()];
|
||||||
|
stream.read_exact(&mut received_probe).await.unwrap();
|
||||||
|
assert_eq!(received_probe, probe);
|
||||||
|
|
||||||
|
stream.write_all(&backend_reply).await.unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = backend_addr.port();
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 2;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "203.0.113.18:50004".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
let (client_reader, _client_writer) = duplex(256);
|
||||||
|
let (mut client_visible_reader, client_visible_writer) = duplex(2048);
|
||||||
|
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
&probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let mut observed = vec![0u8; backend_reply.len()];
|
||||||
|
client_visible_reader.read_exact(&mut observed).await.unwrap();
|
||||||
|
assert_eq!(observed, backend_reply);
|
||||||
|
accept_task.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn proxy_protocol_v1_mixed_family_falls_back_to_unknown_header() {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let backend_addr = listener.local_addr().unwrap();
|
||||||
|
let probe = b"GET /mix HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
|
||||||
|
let backend_reply = b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\n\r\n".to_vec();
|
||||||
|
|
||||||
|
let accept_task = tokio::spawn({
|
||||||
|
let probe = probe.clone();
|
||||||
|
let backend_reply = backend_reply.clone();
|
||||||
|
async move {
|
||||||
|
let (stream, _) = listener.accept().await.unwrap();
|
||||||
|
let mut reader = BufReader::new(stream);
|
||||||
|
|
||||||
|
let mut header_line = Vec::new();
|
||||||
|
reader.read_until(b'\n', &mut header_line).await.unwrap();
|
||||||
|
let header_text = String::from_utf8(header_line).unwrap();
|
||||||
|
assert_eq!(header_text, "PROXY UNKNOWN\r\n");
|
||||||
|
|
||||||
|
let mut received_probe = vec![0u8; probe.len()];
|
||||||
|
reader.read_exact(&mut received_probe).await.unwrap();
|
||||||
|
assert_eq!(received_probe, probe);
|
||||||
|
|
||||||
|
let mut stream = reader.into_inner();
|
||||||
|
stream.write_all(&backend_reply).await.unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = backend_addr.port();
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 1;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "203.0.113.20:50006".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "[::1]:443".parse().unwrap();
|
||||||
|
|
||||||
|
let (client_reader, _client_writer) = duplex(256);
|
||||||
|
let (mut client_visible_reader, client_visible_writer) = duplex(2048);
|
||||||
|
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
&probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let mut observed = vec![0u8; backend_reply.len()];
|
||||||
|
client_visible_reader.read_exact(&mut observed).await.unwrap();
|
||||||
|
assert_eq!(observed, backend_reply);
|
||||||
|
accept_task.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn unix_socket_mask_path_forwards_probe_and_response() {
|
||||||
|
let sock_path = format!("/tmp/telemt-mask-test-{}-{}.sock", std::process::id(), rand::random::<u64>());
|
||||||
|
let _ = std::fs::remove_file(&sock_path);
|
||||||
|
|
||||||
|
let listener = UnixListener::bind(&sock_path).unwrap();
|
||||||
|
let probe = b"GET /unix HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
|
||||||
|
let backend_reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
|
||||||
|
|
||||||
|
let accept_task = tokio::spawn({
|
||||||
|
let probe = probe.clone();
|
||||||
|
let backend_reply = backend_reply.clone();
|
||||||
|
async move {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
let mut received = vec![0u8; probe.len()];
|
||||||
|
stream.read_exact(&mut received).await.unwrap();
|
||||||
|
assert_eq!(received, probe);
|
||||||
|
stream.write_all(&backend_reply).await.unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_unix_sock = Some(sock_path.clone());
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "203.0.113.30:50010".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
let (client_reader, _client_writer) = duplex(256);
|
||||||
|
let (mut client_visible_reader, client_visible_writer) = duplex(2048);
|
||||||
|
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
&probe,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let mut observed = vec![0u8; backend_reply.len()];
|
||||||
|
client_visible_reader.read_exact(&mut observed).await.unwrap();
|
||||||
|
assert_eq!(observed, backend_reply);
|
||||||
|
|
||||||
|
accept_task.await.unwrap();
|
||||||
|
let _ = std::fs::remove_file(sock_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn mask_disabled_slowloris_connection_is_closed_by_consume_timeout() {
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = false;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "198.51.100.33:45455".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
let (_client_reader_side, client_reader) = duplex(256);
|
||||||
|
let (_client_visible_reader, client_visible_writer) = duplex(256);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let task = tokio::spawn(async move {
|
||||||
|
handle_bad_client(
|
||||||
|
client_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
b"slowloris",
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
|
||||||
|
timeout(Duration::from_secs(1), task).await.unwrap().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
struct PendingWriter;
|
||||||
|
|
||||||
|
impl tokio::io::AsyncWrite for PendingWriter {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
_buf: &[u8],
|
||||||
|
) -> Poll<std::io::Result<usize>> {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct DropTrackedPendingReader {
|
||||||
|
dropped: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl tokio::io::AsyncRead for DropTrackedPendingReader {
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
_buf: &mut tokio::io::ReadBuf<'_>,
|
||||||
|
) -> Poll<std::io::Result<()>> {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for DropTrackedPendingReader {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.dropped.store(true, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct DropTrackedPendingWriter {
|
||||||
|
dropped: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl tokio::io::AsyncWrite for DropTrackedPendingWriter {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut Context<'_>,
|
||||||
|
_buf: &[u8],
|
||||||
|
) -> Poll<std::io::Result<usize>> {
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for DropTrackedPendingWriter {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.dropped.store(true, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn proxy_header_write_timeout_returns_false() {
|
||||||
|
let mut writer = PendingWriter;
|
||||||
|
let ok = write_proxy_header_with_timeout(&mut writer, b"PROXY UNKNOWN\r\n").await;
|
||||||
|
assert!(!ok, "Proxy header writes that never complete must time out");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stalls() {
|
||||||
|
let (mut client_feed_writer, client_feed_reader) = duplex(64);
|
||||||
|
let (mut client_visible_reader, client_visible_writer) = duplex(64);
|
||||||
|
let (mut backend_feed_writer, backend_feed_reader) = duplex(64);
|
||||||
|
|
||||||
|
// Make client->mask direction immediately active so the c2m path blocks on PendingWriter.
|
||||||
|
client_feed_writer.write_all(b"X").await.unwrap();
|
||||||
|
|
||||||
|
let relay = tokio::spawn(async move {
|
||||||
|
relay_to_mask(
|
||||||
|
client_feed_reader,
|
||||||
|
client_visible_writer,
|
||||||
|
backend_feed_reader,
|
||||||
|
PendingWriter,
|
||||||
|
b"",
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Allow relay tasks to start, then emulate mask backend response.
|
||||||
|
sleep(Duration::from_millis(20)).await;
|
||||||
|
backend_feed_writer.write_all(b"HTTP/1.1 200 OK\r\n\r\n").await.unwrap();
|
||||||
|
backend_feed_writer.shutdown().await.unwrap();
|
||||||
|
|
||||||
|
let mut observed = vec![0u8; 19];
|
||||||
|
timeout(Duration::from_secs(1), client_visible_reader.read_exact(&mut observed))
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(observed, b"HTTP/1.1 200 OK\r\n\r\n");
|
||||||
|
|
||||||
|
relay.abort();
|
||||||
|
let _ = relay.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_to_mask_preserves_backend_response_after_client_half_close() {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let backend_addr = listener.local_addr().unwrap();
|
||||||
|
let request = b"GET / HTTP/1.1\r\nHost: front.example\r\n\r\n".to_vec();
|
||||||
|
let response = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK".to_vec();
|
||||||
|
|
||||||
|
let backend_task = tokio::spawn({
|
||||||
|
let request = request.clone();
|
||||||
|
let response = response.clone();
|
||||||
|
async move {
|
||||||
|
let (mut stream, _) = listener.accept().await.unwrap();
|
||||||
|
let mut observed_req = vec![0u8; request.len()];
|
||||||
|
stream.read_exact(&mut observed_req).await.unwrap();
|
||||||
|
assert_eq!(observed_req, request);
|
||||||
|
stream.write_all(&response).await.unwrap();
|
||||||
|
stream.shutdown().await.unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut config = ProxyConfig::default();
|
||||||
|
config.general.beobachten = false;
|
||||||
|
config.censorship.mask = true;
|
||||||
|
config.censorship.mask_host = Some("127.0.0.1".to_string());
|
||||||
|
config.censorship.mask_port = backend_addr.port();
|
||||||
|
config.censorship.mask_unix_sock = None;
|
||||||
|
config.censorship.mask_proxy_protocol = 0;
|
||||||
|
|
||||||
|
let peer: SocketAddr = "203.0.113.77:55001".parse().unwrap();
|
||||||
|
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||||
|
|
||||||
|
let (mut client_write, client_read) = duplex(1024);
|
||||||
|
let (mut client_visible_reader, client_visible_writer) = duplex(2048);
|
||||||
|
let beobachten = BeobachtenStore::new();
|
||||||
|
|
||||||
|
let fallback_task = tokio::spawn(async move {
|
||||||
|
handle_bad_client(
|
||||||
|
client_read,
|
||||||
|
client_visible_writer,
|
||||||
|
&request,
|
||||||
|
peer,
|
||||||
|
local_addr,
|
||||||
|
&config,
|
||||||
|
&beobachten,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
});
|
||||||
|
|
||||||
|
client_write.shutdown().await.unwrap();
|
||||||
|
|
||||||
|
let mut observed_resp = vec![0u8; response.len()];
|
||||||
|
timeout(Duration::from_secs(1), client_visible_reader.read_exact(&mut observed_resp))
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(observed_resp, response);
|
||||||
|
|
||||||
|
timeout(Duration::from_secs(1), fallback_task).await.unwrap().unwrap();
|
||||||
|
timeout(Duration::from_secs(1), backend_task).await.unwrap().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
|
||||||
|
let reader_dropped = Arc::new(AtomicBool::new(false));
|
||||||
|
let writer_dropped = Arc::new(AtomicBool::new(false));
|
||||||
|
let mask_reader_dropped = Arc::new(AtomicBool::new(false));
|
||||||
|
let mask_writer_dropped = Arc::new(AtomicBool::new(false));
|
||||||
|
|
||||||
|
let reader = DropTrackedPendingReader {
|
||||||
|
dropped: reader_dropped.clone(),
|
||||||
|
};
|
||||||
|
let writer = DropTrackedPendingWriter {
|
||||||
|
dropped: writer_dropped.clone(),
|
||||||
|
};
|
||||||
|
let mask_read = DropTrackedPendingReader {
|
||||||
|
dropped: mask_reader_dropped.clone(),
|
||||||
|
};
|
||||||
|
let mask_write = DropTrackedPendingWriter {
|
||||||
|
dropped: mask_writer_dropped.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let timed = timeout(
|
||||||
|
Duration::from_millis(40),
|
||||||
|
relay_to_mask(reader, writer, mask_read, mask_write, b""),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(timed.is_err(), "stalled relay must be bounded by timeout");
|
||||||
|
|
||||||
|
assert!(reader_dropped.load(Ordering::SeqCst));
|
||||||
|
assert!(writer_dropped.load(Ordering::SeqCst));
|
||||||
|
assert!(mask_reader_dropped.load(Ordering::SeqCst));
|
||||||
|
assert!(mask_writer_dropped.load(Ordering::SeqCst));
|
||||||
|
}
|
||||||
@@ -1,14 +1,16 @@
|
|||||||
use std::collections::HashMap;
|
|
||||||
use std::collections::hash_map::DefaultHasher;
|
use std::collections::hash_map::DefaultHasher;
|
||||||
use std::hash::{Hash, Hasher};
|
use std::hash::{Hash, Hasher};
|
||||||
use std::net::{IpAddr, SocketAddr};
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
use std::sync::{Arc, Mutex, OnceLock};
|
use std::sync::{Arc, OnceLock};
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
#[cfg(test)]
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
use bytes::Bytes;
|
use dashmap::DashMap;
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
use tokio::sync::{mpsc, oneshot, watch};
|
use tokio::sync::{mpsc, oneshot, watch};
|
||||||
|
use tokio::time::timeout;
|
||||||
use tracing::{debug, trace, warn};
|
use tracing::{debug, trace, warn};
|
||||||
|
|
||||||
use crate::config::ProxyConfig;
|
use crate::config::ProxyConfig;
|
||||||
@@ -21,22 +23,24 @@ use crate::proxy::route_mode::{
|
|||||||
cutover_stagger_delay,
|
cutover_stagger_delay,
|
||||||
};
|
};
|
||||||
use crate::stats::Stats;
|
use crate::stats::Stats;
|
||||||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter};
|
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
|
||||||
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
|
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
|
||||||
|
|
||||||
enum C2MeCommand {
|
enum C2MeCommand {
|
||||||
Data { payload: Bytes, flags: u32 },
|
Data { payload: PooledBuffer, flags: u32 },
|
||||||
Close,
|
Close,
|
||||||
}
|
}
|
||||||
|
|
||||||
const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60);
|
const DESYNC_DEDUP_WINDOW: Duration = Duration::from_secs(60);
|
||||||
|
const DESYNC_DEDUP_MAX_ENTRIES: usize = 65_536;
|
||||||
|
const DESYNC_DEDUP_PRUNE_SCAN_LIMIT: usize = 1024;
|
||||||
const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
|
const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
|
||||||
const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
|
const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
|
||||||
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
|
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
|
||||||
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
|
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
|
||||||
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
|
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
|
||||||
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
|
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
|
||||||
static DESYNC_DEDUP: OnceLock<Mutex<HashMap<u64, Instant>>> = OnceLock::new();
|
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
|
||||||
|
|
||||||
struct RelayForensicsState {
|
struct RelayForensicsState {
|
||||||
trace_id: u64,
|
trace_id: u64,
|
||||||
@@ -90,24 +94,55 @@ fn should_emit_full_desync(key: u64, all_full: bool, now: Instant) -> bool {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
let dedup = DESYNC_DEDUP.get_or_init(|| Mutex::new(HashMap::new()));
|
let dedup = DESYNC_DEDUP.get_or_init(DashMap::new);
|
||||||
let mut guard = dedup.lock().expect("desync dedup mutex poisoned");
|
|
||||||
guard.retain(|_, seen_at| now.duration_since(*seen_at) < DESYNC_DEDUP_WINDOW);
|
|
||||||
|
|
||||||
match guard.get_mut(&key) {
|
if let Some(mut seen_at) = dedup.get_mut(&key) {
|
||||||
Some(seen_at) => {
|
if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW {
|
||||||
if now.duration_since(*seen_at) >= DESYNC_DEDUP_WINDOW {
|
*seen_at = now;
|
||||||
*seen_at = now;
|
return true;
|
||||||
true
|
}
|
||||||
} else {
|
return false;
|
||||||
false
|
}
|
||||||
|
|
||||||
|
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
|
||||||
|
let mut stale_keys = Vec::new();
|
||||||
|
let mut eviction_candidate = None;
|
||||||
|
for entry in dedup.iter().take(DESYNC_DEDUP_PRUNE_SCAN_LIMIT) {
|
||||||
|
if eviction_candidate.is_none() {
|
||||||
|
eviction_candidate = Some(*entry.key());
|
||||||
|
}
|
||||||
|
if now.duration_since(*entry.value()) >= DESYNC_DEDUP_WINDOW {
|
||||||
|
stale_keys.push(*entry.key());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => {
|
for stale_key in stale_keys {
|
||||||
guard.insert(key, now);
|
dedup.remove(&stale_key);
|
||||||
true
|
}
|
||||||
|
if dedup.len() >= DESYNC_DEDUP_MAX_ENTRIES {
|
||||||
|
let Some(evict_key) = eviction_candidate else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
dedup.remove(&evict_key);
|
||||||
|
dedup.insert(key, now);
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dedup.insert(key, now);
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn clear_desync_dedup_for_testing() {
|
||||||
|
if let Some(dedup) = DESYNC_DEDUP.get() {
|
||||||
|
dedup.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn desync_dedup_test_lock() -> &'static Mutex<()> {
|
||||||
|
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||||
|
TEST_LOCK.get_or_init(|| Mutex::new(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn report_desync_frame_too_large(
|
fn report_desync_frame_too_large(
|
||||||
@@ -229,7 +264,7 @@ pub(crate) async fn handle_via_middle_proxy<R, W>(
|
|||||||
me_pool: Arc<MePool>,
|
me_pool: Arc<MePool>,
|
||||||
stats: Arc<Stats>,
|
stats: Arc<Stats>,
|
||||||
config: Arc<ProxyConfig>,
|
config: Arc<ProxyConfig>,
|
||||||
_buffer_pool: Arc<BufferPool>,
|
buffer_pool: Arc<BufferPool>,
|
||||||
local_addr: SocketAddr,
|
local_addr: SocketAddr,
|
||||||
rng: Arc<SecureRandom>,
|
rng: Arc<SecureRandom>,
|
||||||
mut route_rx: watch::Receiver<RouteCutoverState>,
|
mut route_rx: watch::Receiver<RouteCutoverState>,
|
||||||
@@ -271,7 +306,6 @@ where
|
|||||||
};
|
};
|
||||||
|
|
||||||
stats.increment_user_connects(&user);
|
stats.increment_user_connects(&user);
|
||||||
stats.increment_user_curr_connects(&user);
|
|
||||||
stats.increment_current_connections_me();
|
stats.increment_current_connections_me();
|
||||||
|
|
||||||
if let Some(cutover) = affected_cutover_state(
|
if let Some(cutover) = affected_cutover_state(
|
||||||
@@ -291,7 +325,6 @@ where
|
|||||||
let _ = me_pool.send_close(conn_id).await;
|
let _ = me_pool.send_close(conn_id).await;
|
||||||
me_pool.registry().unregister(conn_id).await;
|
me_pool.registry().unregister(conn_id).await;
|
||||||
stats.decrement_current_connections_me();
|
stats.decrement_current_connections_me();
|
||||||
stats.decrement_user_curr_connects(&user);
|
|
||||||
return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string()));
|
return Err(ProxyError::Proxy(ROUTE_SWITCH_ERROR_MSG.to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -557,6 +590,8 @@ where
|
|||||||
&mut crypto_reader,
|
&mut crypto_reader,
|
||||||
proto_tag,
|
proto_tag,
|
||||||
frame_limit,
|
frame_limit,
|
||||||
|
Duration::from_secs(config.timeouts.client_handshake.max(1)),
|
||||||
|
&buffer_pool,
|
||||||
&forensics,
|
&forensics,
|
||||||
&mut frame_counter,
|
&mut frame_counter,
|
||||||
&stats,
|
&stats,
|
||||||
@@ -638,7 +673,6 @@ where
|
|||||||
);
|
);
|
||||||
me_pool.registry().unregister(conn_id).await;
|
me_pool.registry().unregister(conn_id).await;
|
||||||
stats.decrement_current_connections_me();
|
stats.decrement_current_connections_me();
|
||||||
stats.decrement_user_curr_connects(&user);
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -646,30 +680,49 @@ async fn read_client_payload<R>(
|
|||||||
client_reader: &mut CryptoReader<R>,
|
client_reader: &mut CryptoReader<R>,
|
||||||
proto_tag: ProtoTag,
|
proto_tag: ProtoTag,
|
||||||
max_frame: usize,
|
max_frame: usize,
|
||||||
|
frame_read_timeout: Duration,
|
||||||
|
buffer_pool: &Arc<BufferPool>,
|
||||||
forensics: &RelayForensicsState,
|
forensics: &RelayForensicsState,
|
||||||
frame_counter: &mut u64,
|
frame_counter: &mut u64,
|
||||||
stats: &Stats,
|
stats: &Stats,
|
||||||
) -> Result<Option<(Bytes, bool)>>
|
) -> Result<Option<(PooledBuffer, bool)>>
|
||||||
where
|
where
|
||||||
R: AsyncRead + Unpin + Send + 'static,
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
{
|
{
|
||||||
|
async fn read_exact_with_timeout<R>(
|
||||||
|
client_reader: &mut CryptoReader<R>,
|
||||||
|
buf: &mut [u8],
|
||||||
|
frame_read_timeout: Duration,
|
||||||
|
) -> Result<()>
|
||||||
|
where
|
||||||
|
R: AsyncRead + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
match timeout(frame_read_timeout, client_reader.read_exact(buf)).await {
|
||||||
|
Ok(Ok(_)) => Ok(()),
|
||||||
|
Ok(Err(e)) => Err(ProxyError::Io(e)),
|
||||||
|
Err(_) => Err(ProxyError::Io(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::TimedOut,
|
||||||
|
"middle-relay client frame read timeout",
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let (len, quickack, raw_len_bytes) = match proto_tag {
|
let (len, quickack, raw_len_bytes) = match proto_tag {
|
||||||
ProtoTag::Abridged => {
|
ProtoTag::Abridged => {
|
||||||
let mut first = [0u8; 1];
|
let mut first = [0u8; 1];
|
||||||
match client_reader.read_exact(&mut first).await {
|
match read_exact_with_timeout(client_reader, &mut first, frame_read_timeout).await {
|
||||||
Ok(_) => {}
|
Ok(()) => {}
|
||||||
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
|
Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
|
||||||
Err(e) => return Err(ProxyError::Io(e)),
|
return Ok(None);
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e),
|
||||||
}
|
}
|
||||||
|
|
||||||
let quickack = (first[0] & 0x80) != 0;
|
let quickack = (first[0] & 0x80) != 0;
|
||||||
let len_words = if (first[0] & 0x7f) == 0x7f {
|
let len_words = if (first[0] & 0x7f) == 0x7f {
|
||||||
let mut ext = [0u8; 3];
|
let mut ext = [0u8; 3];
|
||||||
client_reader
|
read_exact_with_timeout(client_reader, &mut ext, frame_read_timeout).await?;
|
||||||
.read_exact(&mut ext)
|
|
||||||
.await
|
|
||||||
.map_err(ProxyError::Io)?;
|
|
||||||
u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize
|
u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize
|
||||||
} else {
|
} else {
|
||||||
(first[0] & 0x7f) as usize
|
(first[0] & 0x7f) as usize
|
||||||
@@ -682,10 +735,12 @@ where
|
|||||||
}
|
}
|
||||||
ProtoTag::Intermediate | ProtoTag::Secure => {
|
ProtoTag::Intermediate | ProtoTag::Secure => {
|
||||||
let mut len_buf = [0u8; 4];
|
let mut len_buf = [0u8; 4];
|
||||||
match client_reader.read_exact(&mut len_buf).await {
|
match read_exact_with_timeout(client_reader, &mut len_buf, frame_read_timeout).await {
|
||||||
Ok(_) => {}
|
Ok(()) => {}
|
||||||
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
|
Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
|
||||||
Err(e) => return Err(ProxyError::Io(e)),
|
return Ok(None);
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e),
|
||||||
}
|
}
|
||||||
let quickack = (len_buf[3] & 0x80) != 0;
|
let quickack = (len_buf[3] & 0x80) != 0;
|
||||||
(
|
(
|
||||||
@@ -737,18 +792,21 @@ where
|
|||||||
len
|
len
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut payload = vec![0u8; len];
|
let mut payload = buffer_pool.get();
|
||||||
client_reader
|
payload.clear();
|
||||||
.read_exact(&mut payload)
|
let current_cap = payload.capacity();
|
||||||
.await
|
if current_cap < len {
|
||||||
.map_err(ProxyError::Io)?;
|
payload.reserve(len - current_cap);
|
||||||
|
}
|
||||||
|
payload.resize(len, 0);
|
||||||
|
read_exact_with_timeout(client_reader, &mut payload[..len], frame_read_timeout).await?;
|
||||||
|
|
||||||
// Secure Intermediate: strip validated trailing padding bytes.
|
// Secure Intermediate: strip validated trailing padding bytes.
|
||||||
if proto_tag == ProtoTag::Secure {
|
if proto_tag == ProtoTag::Secure {
|
||||||
payload.truncate(secure_payload_len);
|
payload.truncate(secure_payload_len);
|
||||||
}
|
}
|
||||||
*frame_counter += 1;
|
*frame_counter += 1;
|
||||||
return Ok(Some((Bytes::from(payload), quickack)));
|
return Ok(Some((payload, quickack)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -940,82 +998,5 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
#[path = "middle_relay_security_tests.rs"]
|
||||||
use super::*;
|
mod security_tests;
|
||||||
use tokio::time::{Duration as TokioDuration, timeout};
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn should_yield_sender_only_on_budget_with_backlog() {
|
|
||||||
assert!(!should_yield_c2me_sender(0, true));
|
|
||||||
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true));
|
|
||||||
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false));
|
|
||||||
assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn enqueue_c2me_command_uses_try_send_fast_path() {
|
|
||||||
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(2);
|
|
||||||
enqueue_c2me_command(
|
|
||||||
&tx,
|
|
||||||
C2MeCommand::Data {
|
|
||||||
payload: Bytes::from_static(&[1, 2, 3]),
|
|
||||||
flags: 0,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let recv = timeout(TokioDuration::from_millis(50), rx.recv())
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.unwrap();
|
|
||||||
match recv {
|
|
||||||
C2MeCommand::Data { payload, flags } => {
|
|
||||||
assert_eq!(payload.as_ref(), &[1, 2, 3]);
|
|
||||||
assert_eq!(flags, 0);
|
|
||||||
}
|
|
||||||
C2MeCommand::Close => panic!("unexpected close command"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
|
|
||||||
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
|
|
||||||
tx.send(C2MeCommand::Data {
|
|
||||||
payload: Bytes::from_static(&[9]),
|
|
||||||
flags: 9,
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let tx2 = tx.clone();
|
|
||||||
let producer = tokio::spawn(async move {
|
|
||||||
enqueue_c2me_command(
|
|
||||||
&tx2,
|
|
||||||
C2MeCommand::Data {
|
|
||||||
payload: Bytes::from_static(&[7, 7]),
|
|
||||||
flags: 7,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
});
|
|
||||||
|
|
||||||
let _ = timeout(TokioDuration::from_millis(100), rx.recv())
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
producer.await.unwrap();
|
|
||||||
|
|
||||||
let recv = timeout(TokioDuration::from_millis(100), rx.recv())
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
.unwrap();
|
|
||||||
match recv {
|
|
||||||
C2MeCommand::Data { payload, flags } => {
|
|
||||||
assert_eq!(payload.as_ref(), &[7, 7]);
|
|
||||||
assert_eq!(flags, 7);
|
|
||||||
}
|
|
||||||
C2MeCommand::Close => panic!("unexpected close command"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
781
src/proxy/middle_relay_security_tests.rs
Normal file
781
src/proxy/middle_relay_security_tests.rs
Normal file
@@ -0,0 +1,781 @@
|
|||||||
|
use super::*;
|
||||||
|
use bytes::Bytes;
|
||||||
|
use crate::crypto::AesCtr;
|
||||||
|
use crate::crypto::SecureRandom;
|
||||||
|
use crate::stats::Stats;
|
||||||
|
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::AtomicU64;
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
use tokio::io::duplex;
|
||||||
|
use tokio::time::{Duration as TokioDuration, timeout};
|
||||||
|
|
||||||
|
fn make_pooled_payload(data: &[u8]) -> PooledBuffer {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4));
|
||||||
|
let mut payload = pool.get();
|
||||||
|
payload.resize(data.len(), 0);
|
||||||
|
payload[..data.len()].copy_from_slice(data);
|
||||||
|
payload
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_pooled_payload_from(pool: &Arc<BufferPool>, data: &[u8]) -> PooledBuffer {
|
||||||
|
let mut payload = pool.get();
|
||||||
|
payload.resize(data.len(), 0);
|
||||||
|
payload[..data.len()].copy_from_slice(data);
|
||||||
|
payload
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_yield_sender_only_on_budget_with_backlog() {
|
||||||
|
assert!(!should_yield_c2me_sender(0, true));
|
||||||
|
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true));
|
||||||
|
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false));
|
||||||
|
assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_c2me_command_uses_try_send_fast_path() {
|
||||||
|
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(2);
|
||||||
|
enqueue_c2me_command(
|
||||||
|
&tx,
|
||||||
|
C2MeCommand::Data {
|
||||||
|
payload: make_pooled_payload(&[1, 2, 3]),
|
||||||
|
flags: 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let recv = timeout(TokioDuration::from_millis(50), rx.recv())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
match recv {
|
||||||
|
C2MeCommand::Data { payload, flags } => {
|
||||||
|
assert_eq!(payload.as_ref(), &[1, 2, 3]);
|
||||||
|
assert_eq!(flags, 0);
|
||||||
|
}
|
||||||
|
C2MeCommand::Close => panic!("unexpected close command"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() {
|
||||||
|
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
|
||||||
|
tx.send(C2MeCommand::Data {
|
||||||
|
payload: make_pooled_payload(&[9]),
|
||||||
|
flags: 9,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let tx2 = tx.clone();
|
||||||
|
let producer = tokio::spawn(async move {
|
||||||
|
enqueue_c2me_command(
|
||||||
|
&tx2,
|
||||||
|
C2MeCommand::Data {
|
||||||
|
payload: make_pooled_payload(&[7, 7]),
|
||||||
|
flags: 7,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
let _ = timeout(TokioDuration::from_millis(100), rx.recv())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
producer.await.unwrap();
|
||||||
|
|
||||||
|
let recv = timeout(TokioDuration::from_millis(100), rx.recv())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
|
match recv {
|
||||||
|
C2MeCommand::Data { payload, flags } => {
|
||||||
|
assert_eq!(payload.as_ref(), &[7, 7]);
|
||||||
|
assert_eq!(flags, 7);
|
||||||
|
}
|
||||||
|
C2MeCommand::Close => panic!("unexpected close command"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_c2me_command_closed_channel_recycles_payload() {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(64, 4));
|
||||||
|
let payload = make_pooled_payload_from(&pool, &[1, 2, 3, 4]);
|
||||||
|
let (tx, rx) = mpsc::channel::<C2MeCommand>(1);
|
||||||
|
drop(rx);
|
||||||
|
|
||||||
|
let result = enqueue_c2me_command(
|
||||||
|
&tx,
|
||||||
|
C2MeCommand::Data {
|
||||||
|
payload,
|
||||||
|
flags: 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(result.is_err(), "closed queue must fail enqueue");
|
||||||
|
drop(result);
|
||||||
|
assert!(
|
||||||
|
pool.stats().pooled >= 1,
|
||||||
|
"payload must return to pool when enqueue fails on closed channel"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_c2me_command_full_then_closed_recycles_waiting_payload() {
|
||||||
|
let pool = Arc::new(BufferPool::with_config(64, 4));
|
||||||
|
let (tx, rx) = mpsc::channel::<C2MeCommand>(1);
|
||||||
|
|
||||||
|
tx.send(C2MeCommand::Data {
|
||||||
|
payload: make_pooled_payload_from(&pool, &[9]),
|
||||||
|
flags: 1,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let tx2 = tx.clone();
|
||||||
|
let pool2 = pool.clone();
|
||||||
|
let blocked_send = tokio::spawn(async move {
|
||||||
|
enqueue_c2me_command(
|
||||||
|
&tx2,
|
||||||
|
C2MeCommand::Data {
|
||||||
|
payload: make_pooled_payload_from(&pool2, &[7, 7, 7]),
|
||||||
|
flags: 2,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
});
|
||||||
|
|
||||||
|
tokio::time::sleep(TokioDuration::from_millis(10)).await;
|
||||||
|
drop(rx);
|
||||||
|
|
||||||
|
let result = timeout(TokioDuration::from_secs(1), blocked_send)
|
||||||
|
.await
|
||||||
|
.expect("blocked send task must finish")
|
||||||
|
.expect("blocked send task must not panic");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
result.is_err(),
|
||||||
|
"closing receiver while sender is blocked must fail enqueue"
|
||||||
|
);
|
||||||
|
drop(result);
|
||||||
|
assert!(
|
||||||
|
pool.stats().pooled >= 2,
|
||||||
|
"both queued and blocked payloads must return to pool after channel close"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn desync_dedup_cache_is_bounded() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("desync dedup test lock must be available");
|
||||||
|
clear_desync_dedup_for_testing();
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
|
||||||
|
assert!(
|
||||||
|
should_emit_full_desync(key, false, now),
|
||||||
|
"unique keys up to cap must be tracked"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!should_emit_full_desync(u64::MAX, false, now),
|
||||||
|
"new key above cap must remain suppressed to avoid log amplification"
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!should_emit_full_desync(7, false, now),
|
||||||
|
"already tracked key inside dedup window must stay suppressed"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn desync_dedup_full_cache_churn_stays_suppressed() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("desync dedup test lock must be available");
|
||||||
|
clear_desync_dedup_for_testing();
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 {
|
||||||
|
assert!(should_emit_full_desync(key, false, now));
|
||||||
|
}
|
||||||
|
|
||||||
|
for offset in 0..2048u64 {
|
||||||
|
assert!(
|
||||||
|
!should_emit_full_desync(u64::MAX - offset, false, now),
|
||||||
|
"fresh full-cache churn must remain suppressed under pressure"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_forensics_state() -> RelayForensicsState {
|
||||||
|
RelayForensicsState {
|
||||||
|
trace_id: 1,
|
||||||
|
conn_id: 2,
|
||||||
|
user: "test-user".to_string(),
|
||||||
|
peer: "127.0.0.1:50000".parse::<SocketAddr>().unwrap(),
|
||||||
|
peer_hash: 3,
|
||||||
|
started_at: Instant::now(),
|
||||||
|
bytes_c2me: 0,
|
||||||
|
bytes_me2c: Arc::new(AtomicU64::new(0)),
|
||||||
|
desync_all_full: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_crypto_reader(reader: tokio::io::DuplexStream) -> CryptoReader<tokio::io::DuplexStream> {
|
||||||
|
let key = [0u8; 32];
|
||||||
|
let iv = 0u128;
|
||||||
|
CryptoReader::new(reader, AesCtr::new(&key, iv))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_crypto_writer(writer: tokio::io::DuplexStream) -> CryptoWriter<tokio::io::DuplexStream> {
|
||||||
|
let key = [0u8; 32];
|
||||||
|
let iv = 0u128;
|
||||||
|
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
|
||||||
|
let key = [0u8; 32];
|
||||||
|
let iv = 0u128;
|
||||||
|
let mut cipher = AesCtr::new(&key, iv);
|
||||||
|
cipher.encrypt(plaintext)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_times_out_on_header_stall() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
let (reader, _writer) = duplex(1024);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
let result = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
1024,
|
||||||
|
TokioDuration::from_millis(25),
|
||||||
|
&buffer_pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut),
|
||||||
|
"stalled header read must time out"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_times_out_on_payload_stall() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
let (reader, mut writer) = duplex(1024);
|
||||||
|
let encrypted_len = encrypt_for_reader(&[8, 0, 0, 0]);
|
||||||
|
writer.write_all(&encrypted_len).await.unwrap();
|
||||||
|
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
let result = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
1024,
|
||||||
|
TokioDuration::from_millis(25),
|
||||||
|
&buffer_pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut),
|
||||||
|
"stalled payload body read must time out"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_large_intermediate_frame_is_exact() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
|
||||||
|
let (reader, mut writer) = duplex(262_144);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
let payload_len = buffer_pool.buffer_size().saturating_mul(3).max(65_537);
|
||||||
|
let mut plaintext = Vec::with_capacity(4 + payload_len);
|
||||||
|
plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes());
|
||||||
|
plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(31)));
|
||||||
|
|
||||||
|
let encrypted = encrypt_for_reader(&plaintext);
|
||||||
|
writer.write_all(&encrypted).await.unwrap();
|
||||||
|
|
||||||
|
let read = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
payload_len + 16,
|
||||||
|
TokioDuration::from_secs(1),
|
||||||
|
&buffer_pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("payload read must succeed")
|
||||||
|
.expect("frame must be present");
|
||||||
|
|
||||||
|
let (frame, quickack) = read;
|
||||||
|
assert!(!quickack, "quickack flag must be unset");
|
||||||
|
assert_eq!(frame.len(), payload_len, "payload size must match wire length");
|
||||||
|
for (idx, byte) in frame.iter().enumerate() {
|
||||||
|
assert_eq!(*byte, (idx as u8).wrapping_mul(31));
|
||||||
|
}
|
||||||
|
assert_eq!(frame_counter, 1, "exactly one frame must be counted");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_secure_strips_tail_padding_bytes() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
|
||||||
|
let (reader, mut writer) = duplex(1024);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
let payload = [0x11u8, 0x22, 0x33, 0x44, 0xaa, 0xbb, 0xcc, 0xdd];
|
||||||
|
let tail = [0xeeu8, 0xff, 0x99];
|
||||||
|
let wire_len = payload.len() + tail.len();
|
||||||
|
|
||||||
|
let mut plaintext = Vec::with_capacity(4 + wire_len);
|
||||||
|
plaintext.extend_from_slice(&(wire_len as u32).to_le_bytes());
|
||||||
|
plaintext.extend_from_slice(&payload);
|
||||||
|
plaintext.extend_from_slice(&tail);
|
||||||
|
let encrypted = encrypt_for_reader(&plaintext);
|
||||||
|
writer.write_all(&encrypted).await.unwrap();
|
||||||
|
|
||||||
|
let read = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Secure,
|
||||||
|
1024,
|
||||||
|
TokioDuration::from_secs(1),
|
||||||
|
&buffer_pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("secure payload read must succeed")
|
||||||
|
.expect("secure frame must be present");
|
||||||
|
|
||||||
|
let (frame, quickack) = read;
|
||||||
|
assert!(!quickack, "quickack flag must be unset");
|
||||||
|
assert_eq!(frame.as_ref(), &payload);
|
||||||
|
assert_eq!(frame_counter, 1, "one secure frame must be counted");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_secure_rejects_wire_len_below_4() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
|
||||||
|
let (reader, mut writer) = duplex(1024);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
let mut plaintext = Vec::with_capacity(7);
|
||||||
|
plaintext.extend_from_slice(&3u32.to_le_bytes());
|
||||||
|
plaintext.extend_from_slice(&[1u8, 2, 3]);
|
||||||
|
let encrypted = encrypt_for_reader(&plaintext);
|
||||||
|
writer.write_all(&encrypted).await.unwrap();
|
||||||
|
|
||||||
|
let result = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Secure,
|
||||||
|
1024,
|
||||||
|
TokioDuration::from_secs(1),
|
||||||
|
&buffer_pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
matches!(result, Err(ProxyError::Proxy(ref msg)) if msg.contains("Frame too small: 3")),
|
||||||
|
"secure wire length below 4 must be fail-closed by the frame-too-small guard"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_intermediate_skips_zero_len_frame() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
|
||||||
|
let (reader, mut writer) = duplex(1024);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
let payload = [7u8, 6, 5, 4, 3, 2, 1, 0];
|
||||||
|
let mut plaintext = Vec::with_capacity(4 + 4 + payload.len());
|
||||||
|
plaintext.extend_from_slice(&0u32.to_le_bytes());
|
||||||
|
plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes());
|
||||||
|
plaintext.extend_from_slice(&payload);
|
||||||
|
let encrypted = encrypt_for_reader(&plaintext);
|
||||||
|
writer.write_all(&encrypted).await.unwrap();
|
||||||
|
|
||||||
|
let read = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
1024,
|
||||||
|
TokioDuration::from_secs(1),
|
||||||
|
&buffer_pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("intermediate payload read must succeed")
|
||||||
|
.expect("frame must be present");
|
||||||
|
|
||||||
|
let (frame, quickack) = read;
|
||||||
|
assert!(!quickack, "quickack flag must be unset");
|
||||||
|
assert_eq!(frame.as_ref(), &payload);
|
||||||
|
assert_eq!(frame_counter, 1, "zero-length frame must be skipped");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_abridged_extended_len_sets_quickack() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
|
||||||
|
let (reader, mut writer) = duplex(4096);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let buffer_pool = Arc::new(BufferPool::new());
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
let payload_len = 4 * 130;
|
||||||
|
let len_words = (payload_len / 4) as u32;
|
||||||
|
let mut plaintext = Vec::with_capacity(1 + 3 + payload_len);
|
||||||
|
plaintext.push(0xff | 0x80);
|
||||||
|
let lw = len_words.to_le_bytes();
|
||||||
|
plaintext.extend_from_slice(&lw[..3]);
|
||||||
|
plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_add(17)));
|
||||||
|
|
||||||
|
let encrypted = encrypt_for_reader(&plaintext);
|
||||||
|
writer.write_all(&encrypted).await.unwrap();
|
||||||
|
|
||||||
|
let read = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Abridged,
|
||||||
|
payload_len + 16,
|
||||||
|
TokioDuration::from_secs(1),
|
||||||
|
&buffer_pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("abridged payload read must succeed")
|
||||||
|
.expect("frame must be present");
|
||||||
|
|
||||||
|
let (frame, quickack) = read;
|
||||||
|
assert!(quickack, "quickack bit must be propagated from abridged header");
|
||||||
|
assert_eq!(frame.len(), payload_len);
|
||||||
|
assert_eq!(frame_counter, 1, "one abridged frame must be counted");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_returns_buffer_to_pool_after_emit() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
|
||||||
|
let pool = Arc::new(BufferPool::with_config(64, 8));
|
||||||
|
pool.preallocate(1);
|
||||||
|
assert_eq!(pool.stats().pooled, 1, "precondition: one pooled buffer");
|
||||||
|
|
||||||
|
let (reader, mut writer) = duplex(4096);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
// Force growth beyond default pool buffer size to catch ownership-take regressions.
|
||||||
|
let payload_len = 257usize;
|
||||||
|
let mut plaintext = Vec::with_capacity(4 + payload_len);
|
||||||
|
plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes());
|
||||||
|
plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(13)));
|
||||||
|
|
||||||
|
let encrypted = encrypt_for_reader(&plaintext);
|
||||||
|
writer.write_all(&encrypted).await.unwrap();
|
||||||
|
|
||||||
|
let _ = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
payload_len + 8,
|
||||||
|
TokioDuration::from_secs(1),
|
||||||
|
&pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("payload read must succeed")
|
||||||
|
.expect("frame must be present");
|
||||||
|
|
||||||
|
assert_eq!(frame_counter, 1);
|
||||||
|
let pool_stats = pool.stats();
|
||||||
|
assert!(
|
||||||
|
pool_stats.pooled >= 1,
|
||||||
|
"emitted payload buffer must be returned to pool to avoid pool drain"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn read_client_payload_keeps_pool_buffer_checked_out_until_frame_drop() {
|
||||||
|
let _guard = desync_dedup_test_lock()
|
||||||
|
.lock()
|
||||||
|
.expect("middle relay test lock must be available");
|
||||||
|
|
||||||
|
let pool = Arc::new(BufferPool::with_config(64, 2));
|
||||||
|
pool.preallocate(1);
|
||||||
|
assert_eq!(pool.stats().pooled, 1, "one pooled buffer must be available");
|
||||||
|
|
||||||
|
let (reader, mut writer) = duplex(1024);
|
||||||
|
let mut crypto_reader = make_crypto_reader(reader);
|
||||||
|
let stats = Stats::new();
|
||||||
|
let forensics = make_forensics_state();
|
||||||
|
let mut frame_counter = 0;
|
||||||
|
|
||||||
|
let payload = [0x41u8, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48];
|
||||||
|
let mut plaintext = Vec::with_capacity(4 + payload.len());
|
||||||
|
plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes());
|
||||||
|
plaintext.extend_from_slice(&payload);
|
||||||
|
let encrypted = encrypt_for_reader(&plaintext);
|
||||||
|
writer.write_all(&encrypted).await.unwrap();
|
||||||
|
|
||||||
|
let (frame, quickack) = read_client_payload(
|
||||||
|
&mut crypto_reader,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
1024,
|
||||||
|
TokioDuration::from_secs(1),
|
||||||
|
&pool,
|
||||||
|
&forensics,
|
||||||
|
&mut frame_counter,
|
||||||
|
&stats,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("payload read must succeed")
|
||||||
|
.expect("frame must be present");
|
||||||
|
|
||||||
|
assert!(!quickack);
|
||||||
|
assert_eq!(frame.as_ref(), &payload);
|
||||||
|
assert_eq!(
|
||||||
|
pool.stats().pooled,
|
||||||
|
0,
|
||||||
|
"buffer must stay checked out while frame payload is alive"
|
||||||
|
);
|
||||||
|
|
||||||
|
drop(frame);
|
||||||
|
assert!(
|
||||||
|
pool.stats().pooled >= 1,
|
||||||
|
"buffer must return to pool only after frame drop"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_c2me_close_unblocks_after_queue_drain() {
|
||||||
|
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
|
||||||
|
tx.send(C2MeCommand::Data {
|
||||||
|
payload: make_pooled_payload(&[0x41]),
|
||||||
|
flags: 0,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let tx2 = tx.clone();
|
||||||
|
let close_task = tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await });
|
||||||
|
|
||||||
|
tokio::time::sleep(TokioDuration::from_millis(10)).await;
|
||||||
|
|
||||||
|
let first = timeout(TokioDuration::from_millis(100), rx.recv())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.expect("first queued item must be present");
|
||||||
|
assert!(matches!(first, C2MeCommand::Data { .. }));
|
||||||
|
|
||||||
|
close_task.await.unwrap().expect("close enqueue must succeed after drain");
|
||||||
|
|
||||||
|
let second = timeout(TokioDuration::from_millis(100), rx.recv())
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.expect("close command must follow after queue drain");
|
||||||
|
assert!(matches!(second, C2MeCommand::Close));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn enqueue_c2me_close_full_then_receiver_drop_fails_cleanly() {
|
||||||
|
let (tx, rx) = mpsc::channel::<C2MeCommand>(1);
|
||||||
|
tx.send(C2MeCommand::Data {
|
||||||
|
payload: make_pooled_payload(&[0x42]),
|
||||||
|
flags: 0,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let tx2 = tx.clone();
|
||||||
|
let close_task = tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await });
|
||||||
|
|
||||||
|
tokio::time::sleep(TokioDuration::from_millis(10)).await;
|
||||||
|
drop(rx);
|
||||||
|
|
||||||
|
let result = timeout(TokioDuration::from_secs(1), close_task)
|
||||||
|
.await
|
||||||
|
.expect("close task must finish")
|
||||||
|
.expect("close task must not panic");
|
||||||
|
assert!(
|
||||||
|
result.is_err(),
|
||||||
|
"close enqueue must fail cleanly when receiver is dropped under pressure"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn process_me_writer_response_ack_obeys_flush_policy() {
|
||||||
|
let (writer_side, _reader_side) = duplex(1024);
|
||||||
|
let mut writer = make_crypto_writer(writer_side);
|
||||||
|
let rng = SecureRandom::new();
|
||||||
|
let mut frame_buf = Vec::new();
|
||||||
|
let stats = Stats::new();
|
||||||
|
let bytes_me2c = AtomicU64::new(0);
|
||||||
|
|
||||||
|
let immediate = process_me_writer_response(
|
||||||
|
MeResponse::Ack(0x11223344),
|
||||||
|
&mut writer,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
&rng,
|
||||||
|
&mut frame_buf,
|
||||||
|
&stats,
|
||||||
|
"user",
|
||||||
|
&bytes_me2c,
|
||||||
|
77,
|
||||||
|
true,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("ack response must be processed");
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
immediate,
|
||||||
|
MeWriterResponseOutcome::Continue {
|
||||||
|
frames: 1,
|
||||||
|
bytes: 4,
|
||||||
|
flush_immediately: true,
|
||||||
|
}
|
||||||
|
));
|
||||||
|
|
||||||
|
let delayed = process_me_writer_response(
|
||||||
|
MeResponse::Ack(0x55667788),
|
||||||
|
&mut writer,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
&rng,
|
||||||
|
&mut frame_buf,
|
||||||
|
&stats,
|
||||||
|
"user",
|
||||||
|
&bytes_me2c,
|
||||||
|
77,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("ack response must be processed");
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
delayed,
|
||||||
|
MeWriterResponseOutcome::Continue {
|
||||||
|
frames: 1,
|
||||||
|
bytes: 4,
|
||||||
|
flush_immediately: false,
|
||||||
|
}
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn process_me_writer_response_data_updates_byte_accounting() {
|
||||||
|
let (writer_side, _reader_side) = duplex(1024);
|
||||||
|
let mut writer = make_crypto_writer(writer_side);
|
||||||
|
let rng = SecureRandom::new();
|
||||||
|
let mut frame_buf = Vec::new();
|
||||||
|
let stats = Stats::new();
|
||||||
|
let bytes_me2c = AtomicU64::new(0);
|
||||||
|
|
||||||
|
let payload = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||||
|
let outcome = process_me_writer_response(
|
||||||
|
MeResponse::Data {
|
||||||
|
flags: 0,
|
||||||
|
data: Bytes::from(payload.clone()),
|
||||||
|
},
|
||||||
|
&mut writer,
|
||||||
|
ProtoTag::Intermediate,
|
||||||
|
&rng,
|
||||||
|
&mut frame_buf,
|
||||||
|
&stats,
|
||||||
|
"user",
|
||||||
|
&bytes_me2c,
|
||||||
|
88,
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("data response must be processed");
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
outcome,
|
||||||
|
MeWriterResponseOutcome::Continue {
|
||||||
|
frames: 1,
|
||||||
|
bytes,
|
||||||
|
flush_immediately: false,
|
||||||
|
} if bytes == payload.len()
|
||||||
|
));
|
||||||
|
assert_eq!(
|
||||||
|
bytes_me2c.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
payload.len() as u64,
|
||||||
|
"ME->C byte accounting must increase by emitted payload size"
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -1256,6 +1256,33 @@ impl Stats {
|
|||||||
Self::touch_user_stats(stats.value());
|
Self::touch_user_stats(stats.value());
|
||||||
stats.curr_connects.fetch_add(1, Ordering::Relaxed);
|
stats.curr_connects.fetch_add(1, Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn try_acquire_user_curr_connects(&self, user: &str, limit: Option<u64>) -> bool {
|
||||||
|
if !self.telemetry_user_enabled() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.maybe_cleanup_user_stats();
|
||||||
|
let stats = self.user_stats.entry(user.to_string()).or_default();
|
||||||
|
Self::touch_user_stats(stats.value());
|
||||||
|
|
||||||
|
let counter = &stats.curr_connects;
|
||||||
|
let mut current = counter.load(Ordering::Relaxed);
|
||||||
|
loop {
|
||||||
|
if let Some(max) = limit && current >= max {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
match counter.compare_exchange_weak(
|
||||||
|
current,
|
||||||
|
current.saturating_add(1),
|
||||||
|
Ordering::Relaxed,
|
||||||
|
Ordering::Relaxed,
|
||||||
|
) {
|
||||||
|
Ok(_) => return true,
|
||||||
|
Err(actual) => current = actual,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn decrement_user_curr_connects(&self, user: &str) {
|
pub fn decrement_user_curr_connects(&self, user: &str) {
|
||||||
self.maybe_cleanup_user_stats();
|
self.maybe_cleanup_user_stats();
|
||||||
|
|||||||
@@ -513,6 +513,7 @@ impl FrameCodecTrait for SecureCodec {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use std::collections::HashSet;
|
||||||
use tokio_util::codec::{FramedRead, FramedWrite};
|
use tokio_util::codec::{FramedRead, FramedWrite};
|
||||||
use tokio::io::duplex;
|
use tokio::io::duplex;
|
||||||
use futures::{SinkExt, StreamExt};
|
use futures::{SinkExt, StreamExt};
|
||||||
@@ -630,4 +631,31 @@ mod tests {
|
|||||||
let result = codec.decode(&mut buf);
|
let result = codec.decode(&mut buf);
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn secure_codec_always_adds_padding_and_jitters_wire_length() {
|
||||||
|
let codec = SecureCodec::new(Arc::new(SecureRandom::new()));
|
||||||
|
let payload = Bytes::from_static(&[1, 2, 3, 4, 5, 6, 7, 8]);
|
||||||
|
let mut wire_lens = HashSet::new();
|
||||||
|
|
||||||
|
for _ in 0..64 {
|
||||||
|
let frame = Frame::new(payload.clone());
|
||||||
|
let mut out = BytesMut::new();
|
||||||
|
codec.encode(&frame, &mut out).unwrap();
|
||||||
|
|
||||||
|
assert!(out.len() >= 4 + payload.len() + 1);
|
||||||
|
let wire_len = u32::from_le_bytes([out[0], out[1], out[2], out[3]]) as usize;
|
||||||
|
assert!(
|
||||||
|
(payload.len() + 1..=payload.len() + 3).contains(&wire_len),
|
||||||
|
"Secure wire length must be payload+1..3, got {wire_len}"
|
||||||
|
);
|
||||||
|
assert_ne!(wire_len % 4, 0, "Secure wire length must be non-4-aligned");
|
||||||
|
wire_lens.insert(wire_len);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
wire_lens.len() >= 2,
|
||||||
|
"Secure padding should create observable wire-length jitter"
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -115,59 +115,109 @@ async fn reap_draining_writers(
|
|||||||
pool: &Arc<MePool>,
|
pool: &Arc<MePool>,
|
||||||
warn_next_allowed: &mut HashMap<u64, Instant>,
|
warn_next_allowed: &mut HashMap<u64, Instant>,
|
||||||
) {
|
) {
|
||||||
|
if pool.draining_active_runtime() == 0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
let now_epoch_secs = MePool::now_epoch_secs();
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
let drain_ttl_secs = pool.me_pool_drain_ttl_secs.load(std::sync::atomic::Ordering::Relaxed);
|
let drain_ttl_secs = pool.me_pool_drain_ttl_secs.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
let drain_threshold = pool
|
let drain_threshold = pool
|
||||||
.me_pool_drain_threshold
|
.me_pool_drain_threshold
|
||||||
.load(std::sync::atomic::Ordering::Relaxed);
|
.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
let writers = pool.writers.read().await.clone();
|
let mut draining_writers = {
|
||||||
let mut draining_writers = Vec::new();
|
let writers = pool.writers.read().await;
|
||||||
for writer in writers {
|
let mut draining_writers = Vec::<DrainingWriterSnapshot>::new();
|
||||||
if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) {
|
for writer in writers.iter() {
|
||||||
continue;
|
if !writer.draining.load(std::sync::atomic::Ordering::Relaxed) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
draining_writers.push(DrainingWriterSnapshot {
|
||||||
|
id: writer.id,
|
||||||
|
writer_dc: writer.writer_dc,
|
||||||
|
addr: writer.addr,
|
||||||
|
generation: writer.generation,
|
||||||
|
created_at: writer.created_at,
|
||||||
|
draining_started_at_epoch_secs: writer
|
||||||
|
.draining_started_at_epoch_secs
|
||||||
|
.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
drain_deadline_epoch_secs: writer
|
||||||
|
.drain_deadline_epoch_secs
|
||||||
|
.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
allow_drain_fallback: writer
|
||||||
|
.allow_drain_fallback
|
||||||
|
.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
let is_empty = pool.registry.is_writer_empty(writer.id).await;
|
draining_writers
|
||||||
if is_empty {
|
};
|
||||||
pool.remove_writer_and_close_clients(writer.id).await;
|
|
||||||
continue;
|
if draining_writers.is_empty() {
|
||||||
}
|
return;
|
||||||
draining_writers.push(writer);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize {
|
let draining_ids: Vec<u64> = draining_writers.iter().map(|writer| writer.id).collect();
|
||||||
draining_writers.sort_by(|left, right| {
|
let non_empty_writer_ids = pool.registry.non_empty_writer_ids(&draining_ids).await;
|
||||||
let left_started = left
|
let mut non_empty_draining_writers =
|
||||||
.draining_started_at_epoch_secs
|
Vec::<DrainingWriterSnapshot>::with_capacity(draining_writers.len());
|
||||||
.load(std::sync::atomic::Ordering::Relaxed);
|
for writer in draining_writers.drain(..) {
|
||||||
let right_started = right
|
if non_empty_writer_ids.contains(&writer.id) {
|
||||||
.draining_started_at_epoch_secs
|
non_empty_draining_writers.push(writer);
|
||||||
.load(std::sync::atomic::Ordering::Relaxed);
|
} else {
|
||||||
left_started
|
|
||||||
.cmp(&right_started)
|
|
||||||
.then_with(|| left.created_at.cmp(&right.created_at))
|
|
||||||
.then_with(|| left.id.cmp(&right.id))
|
|
||||||
});
|
|
||||||
let overflow = draining_writers.len().saturating_sub(drain_threshold as usize);
|
|
||||||
warn!(
|
|
||||||
draining_writers = draining_writers.len(),
|
|
||||||
me_pool_drain_threshold = drain_threshold,
|
|
||||||
removing_writers = overflow,
|
|
||||||
"ME draining writer threshold exceeded, force-closing oldest draining writers"
|
|
||||||
);
|
|
||||||
for writer in draining_writers.drain(..overflow) {
|
|
||||||
pool.stats.increment_pool_force_close_total();
|
|
||||||
pool.remove_writer_and_close_clients(writer.id).await;
|
pool.remove_writer_and_close_clients(writer.id).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
draining_writers = non_empty_draining_writers;
|
||||||
|
if draining_writers.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let overflow = if drain_threshold > 0 && draining_writers.len() > drain_threshold as usize {
|
||||||
|
draining_writers.len().saturating_sub(drain_threshold as usize)
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
let has_deadline_expired = draining_writers.iter().any(|writer| {
|
||||||
|
writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs
|
||||||
|
});
|
||||||
|
let can_drop_with_replacement = if overflow > 0 || has_deadline_expired {
|
||||||
|
pool.has_non_draining_writer_per_desired_dc_group().await
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
};
|
||||||
|
|
||||||
|
if overflow > 0 {
|
||||||
|
if can_drop_with_replacement {
|
||||||
|
draining_writers.sort_by(|left, right| {
|
||||||
|
left.draining_started_at_epoch_secs
|
||||||
|
.cmp(&right.draining_started_at_epoch_secs)
|
||||||
|
.then_with(|| left.created_at.cmp(&right.created_at))
|
||||||
|
.then_with(|| left.id.cmp(&right.id))
|
||||||
|
});
|
||||||
|
warn!(
|
||||||
|
draining_writers = draining_writers.len(),
|
||||||
|
me_pool_drain_threshold = drain_threshold,
|
||||||
|
removing_writers = overflow,
|
||||||
|
"ME draining writer threshold exceeded, force-closing oldest draining writers"
|
||||||
|
);
|
||||||
|
for writer in draining_writers.drain(..overflow) {
|
||||||
|
pool.stats.increment_pool_force_close_total();
|
||||||
|
pool.remove_writer_and_close_clients(writer.id).await;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
draining_writers = draining_writers.len(),
|
||||||
|
me_pool_drain_threshold = drain_threshold,
|
||||||
|
overflow,
|
||||||
|
"ME draining threshold exceeded, but replacement coverage is incomplete; keeping draining writers"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for writer in draining_writers {
|
for writer in draining_writers {
|
||||||
let drain_started_at_epoch_secs = writer
|
|
||||||
.draining_started_at_epoch_secs
|
|
||||||
.load(std::sync::atomic::Ordering::Relaxed);
|
|
||||||
if drain_ttl_secs > 0
|
if drain_ttl_secs > 0
|
||||||
&& drain_started_at_epoch_secs != 0
|
&& writer.draining_started_at_epoch_secs != 0
|
||||||
&& now_epoch_secs.saturating_sub(drain_started_at_epoch_secs) > drain_ttl_secs
|
&& now_epoch_secs.saturating_sub(writer.draining_started_at_epoch_secs) > drain_ttl_secs
|
||||||
&& should_emit_writer_warn(
|
&& should_emit_writer_warn(
|
||||||
warn_next_allowed,
|
warn_next_allowed,
|
||||||
writer.id,
|
writer.id,
|
||||||
@@ -182,21 +232,45 @@ async fn reap_draining_writers(
|
|||||||
generation = writer.generation,
|
generation = writer.generation,
|
||||||
drain_ttl_secs,
|
drain_ttl_secs,
|
||||||
force_close_secs = pool.me_pool_force_close_secs.load(std::sync::atomic::Ordering::Relaxed),
|
force_close_secs = pool.me_pool_force_close_secs.load(std::sync::atomic::Ordering::Relaxed),
|
||||||
allow_drain_fallback = writer.allow_drain_fallback.load(std::sync::atomic::Ordering::Relaxed),
|
allow_drain_fallback = writer.allow_drain_fallback,
|
||||||
"ME draining writer remains non-empty past drain TTL"
|
"ME draining writer remains non-empty past drain TTL"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
let deadline_epoch_secs = writer
|
if writer.drain_deadline_epoch_secs != 0 && now_epoch_secs >= writer.drain_deadline_epoch_secs
|
||||||
.drain_deadline_epoch_secs
|
{
|
||||||
.load(std::sync::atomic::Ordering::Relaxed);
|
if can_drop_with_replacement {
|
||||||
if deadline_epoch_secs != 0 && now_epoch_secs >= deadline_epoch_secs {
|
warn!(writer_id = writer.id, "Drain timeout, force-closing");
|
||||||
warn!(writer_id = writer.id, "Drain timeout, force-closing");
|
pool.stats.increment_pool_force_close_total();
|
||||||
pool.stats.increment_pool_force_close_total();
|
pool.remove_writer_and_close_clients(writer.id).await;
|
||||||
pool.remove_writer_and_close_clients(writer.id).await;
|
} else if should_emit_writer_warn(
|
||||||
|
warn_next_allowed,
|
||||||
|
writer.id,
|
||||||
|
now,
|
||||||
|
pool.warn_rate_limit_duration(),
|
||||||
|
) {
|
||||||
|
warn!(
|
||||||
|
writer_id = writer.id,
|
||||||
|
writer_dc = writer.writer_dc,
|
||||||
|
endpoint = %writer.addr,
|
||||||
|
"Drain timeout reached, but replacement coverage is incomplete; keeping draining writer"
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct DrainingWriterSnapshot {
|
||||||
|
id: u64,
|
||||||
|
writer_dc: i32,
|
||||||
|
addr: SocketAddr,
|
||||||
|
generation: u64,
|
||||||
|
created_at: Instant,
|
||||||
|
draining_started_at_epoch_secs: u64,
|
||||||
|
drain_deadline_epoch_secs: u64,
|
||||||
|
allow_drain_fallback: bool,
|
||||||
|
}
|
||||||
|
|
||||||
fn should_emit_writer_warn(
|
fn should_emit_writer_warn(
|
||||||
next_allowed: &mut HashMap<u64, Instant>,
|
next_allowed: &mut HashMap<u64, Instant>,
|
||||||
writer_id: u64,
|
writer_id: u64,
|
||||||
@@ -1330,6 +1404,15 @@ mod tests {
|
|||||||
me_pool_drain_threshold,
|
me_pool_drain_threshold,
|
||||||
..GeneralConfig::default()
|
..GeneralConfig::default()
|
||||||
};
|
};
|
||||||
|
let mut proxy_map_v4 = HashMap::new();
|
||||||
|
proxy_map_v4.insert(
|
||||||
|
2,
|
||||||
|
vec![(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 10)), 443)],
|
||||||
|
);
|
||||||
|
let decision = NetworkDecision {
|
||||||
|
ipv4_me: true,
|
||||||
|
..NetworkDecision::default()
|
||||||
|
};
|
||||||
MePool::new(
|
MePool::new(
|
||||||
None,
|
None,
|
||||||
vec![1u8; 32],
|
vec![1u8; 32],
|
||||||
@@ -1341,10 +1424,10 @@ mod tests {
|
|||||||
None,
|
None,
|
||||||
12,
|
12,
|
||||||
1200,
|
1200,
|
||||||
HashMap::new(),
|
proxy_map_v4,
|
||||||
HashMap::new(),
|
HashMap::new(),
|
||||||
None,
|
None,
|
||||||
NetworkDecision::default(),
|
decision,
|
||||||
None,
|
None,
|
||||||
Arc::new(SecureRandom::new()),
|
Arc::new(SecureRandom::new()),
|
||||||
Arc::new(Stats::default()),
|
Arc::new(Stats::default()),
|
||||||
@@ -1438,6 +1521,7 @@ mod tests {
|
|||||||
pool.writers.write().await.push(writer);
|
pool.writers.write().await.push(writer);
|
||||||
pool.registry.register_writer(writer_id, tx).await;
|
pool.registry.register_writer(writer_id, tx).await;
|
||||||
pool.conn_count.fetch_add(1, Ordering::Relaxed);
|
pool.conn_count.fetch_add(1, Ordering::Relaxed);
|
||||||
|
pool.increment_draining_active_runtime();
|
||||||
assert!(
|
assert!(
|
||||||
pool.registry
|
pool.registry
|
||||||
.bind_writer(
|
.bind_writer(
|
||||||
@@ -1455,8 +1539,56 @@ mod tests {
|
|||||||
conn_id
|
conn_id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn insert_live_writer(pool: &Arc<MePool>, writer_id: u64, writer_dc: i32) {
|
||||||
|
let (tx, _writer_rx) = mpsc::channel::<WriterCommand>(8);
|
||||||
|
let writer = MeWriter {
|
||||||
|
id: writer_id,
|
||||||
|
addr: SocketAddr::new(
|
||||||
|
IpAddr::V4(Ipv4Addr::new(203, 0, 113, (writer_id as u8).saturating_add(1))),
|
||||||
|
4000 + writer_id as u16,
|
||||||
|
),
|
||||||
|
source_ip: IpAddr::V4(Ipv4Addr::LOCALHOST),
|
||||||
|
writer_dc,
|
||||||
|
generation: 2,
|
||||||
|
contour: Arc::new(AtomicU8::new(WriterContour::Active.as_u8())),
|
||||||
|
created_at: Instant::now(),
|
||||||
|
tx: tx.clone(),
|
||||||
|
cancel: CancellationToken::new(),
|
||||||
|
degraded: Arc::new(AtomicBool::new(false)),
|
||||||
|
rtt_ema_ms_x10: Arc::new(AtomicU32::new(0)),
|
||||||
|
draining: Arc::new(AtomicBool::new(false)),
|
||||||
|
draining_started_at_epoch_secs: Arc::new(AtomicU64::new(0)),
|
||||||
|
drain_deadline_epoch_secs: Arc::new(AtomicU64::new(0)),
|
||||||
|
allow_drain_fallback: Arc::new(AtomicBool::new(false)),
|
||||||
|
};
|
||||||
|
pool.writers.write().await.push(writer);
|
||||||
|
pool.registry.register_writer(writer_id, tx).await;
|
||||||
|
pool.conn_count.fetch_add(1, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn reap_draining_writers_force_closes_oldest_over_threshold() {
|
async fn reap_draining_writers_force_closes_oldest_over_threshold() {
|
||||||
|
let pool = make_pool(2).await;
|
||||||
|
insert_live_writer(&pool, 1, 2).await;
|
||||||
|
assert!(pool.has_non_draining_writer_per_desired_dc_group().await);
|
||||||
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
|
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
|
||||||
|
let conn_b = insert_draining_writer(&pool, 20, now_epoch_secs.saturating_sub(20)).await;
|
||||||
|
let conn_c = insert_draining_writer(&pool, 30, now_epoch_secs.saturating_sub(10)).await;
|
||||||
|
let mut warn_next_allowed = HashMap::new();
|
||||||
|
|
||||||
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
|
let mut writer_ids: Vec<u64> = pool.writers.read().await.iter().map(|writer| writer.id).collect();
|
||||||
|
writer_ids.sort_unstable();
|
||||||
|
assert_eq!(writer_ids, vec![1, 20, 30]);
|
||||||
|
assert!(pool.registry.get_writer(conn_a).await.is_none());
|
||||||
|
assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20);
|
||||||
|
assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reap_draining_writers_does_not_force_close_overflow_without_replacement() {
|
||||||
let pool = make_pool(2).await;
|
let pool = make_pool(2).await;
|
||||||
let now_epoch_secs = MePool::now_epoch_secs();
|
let now_epoch_secs = MePool::now_epoch_secs();
|
||||||
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
|
let conn_a = insert_draining_writer(&pool, 10, now_epoch_secs.saturating_sub(30)).await;
|
||||||
@@ -1466,9 +1598,10 @@ mod tests {
|
|||||||
|
|
||||||
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
reap_draining_writers(&pool, &mut warn_next_allowed).await;
|
||||||
|
|
||||||
let writer_ids: Vec<u64> = pool.writers.read().await.iter().map(|writer| writer.id).collect();
|
let mut writer_ids: Vec<u64> = pool.writers.read().await.iter().map(|writer| writer.id).collect();
|
||||||
assert_eq!(writer_ids, vec![20, 30]);
|
writer_ids.sort_unstable();
|
||||||
assert!(pool.registry.get_writer(conn_a).await.is_none());
|
assert_eq!(writer_ids, vec![10, 20, 30]);
|
||||||
|
assert_eq!(pool.registry.get_writer(conn_a).await.unwrap().writer_id, 10);
|
||||||
assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20);
|
assert_eq!(pool.registry.get_writer(conn_b).await.unwrap().writer_id, 20);
|
||||||
assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30);
|
assert_eq!(pool.registry.get_writer(conn_c).await.unwrap().writer_id, 30);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -160,6 +160,7 @@ pub struct MePool {
|
|||||||
pub(super) refill_inflight: Arc<Mutex<HashSet<RefillEndpointKey>>>,
|
pub(super) refill_inflight: Arc<Mutex<HashSet<RefillEndpointKey>>>,
|
||||||
pub(super) refill_inflight_dc: Arc<Mutex<HashSet<RefillDcKey>>>,
|
pub(super) refill_inflight_dc: Arc<Mutex<HashSet<RefillDcKey>>>,
|
||||||
pub(super) conn_count: AtomicUsize,
|
pub(super) conn_count: AtomicUsize,
|
||||||
|
pub(super) draining_active_runtime: AtomicU64,
|
||||||
pub(super) stats: Arc<crate::stats::Stats>,
|
pub(super) stats: Arc<crate::stats::Stats>,
|
||||||
pub(super) generation: AtomicU64,
|
pub(super) generation: AtomicU64,
|
||||||
pub(super) active_generation: AtomicU64,
|
pub(super) active_generation: AtomicU64,
|
||||||
@@ -438,6 +439,7 @@ impl MePool {
|
|||||||
refill_inflight: Arc::new(Mutex::new(HashSet::new())),
|
refill_inflight: Arc::new(Mutex::new(HashSet::new())),
|
||||||
refill_inflight_dc: Arc::new(Mutex::new(HashSet::new())),
|
refill_inflight_dc: Arc::new(Mutex::new(HashSet::new())),
|
||||||
conn_count: AtomicUsize::new(0),
|
conn_count: AtomicUsize::new(0),
|
||||||
|
draining_active_runtime: AtomicU64::new(0),
|
||||||
generation: AtomicU64::new(1),
|
generation: AtomicU64::new(1),
|
||||||
active_generation: AtomicU64::new(1),
|
active_generation: AtomicU64::new(1),
|
||||||
warm_generation: AtomicU64::new(0),
|
warm_generation: AtomicU64::new(0),
|
||||||
@@ -690,6 +692,32 @@ impl MePool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(super) fn draining_active_runtime(&self) -> u64 {
|
||||||
|
self.draining_active_runtime.load(Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn increment_draining_active_runtime(&self) {
|
||||||
|
self.draining_active_runtime.fetch_add(1, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn decrement_draining_active_runtime(&self) {
|
||||||
|
let mut current = self.draining_active_runtime.load(Ordering::Relaxed);
|
||||||
|
loop {
|
||||||
|
if current == 0 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
match self.draining_active_runtime.compare_exchange_weak(
|
||||||
|
current,
|
||||||
|
current - 1,
|
||||||
|
Ordering::Relaxed,
|
||||||
|
Ordering::Relaxed,
|
||||||
|
) {
|
||||||
|
Ok(_) => break,
|
||||||
|
Err(actual) => current = actual,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(super) async fn key_selector(&self) -> u32 {
|
pub(super) async fn key_selector(&self) -> u32 {
|
||||||
self.proxy_secret.read().await.key_selector
|
self.proxy_secret.read().await.key_selector
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -141,6 +141,38 @@ impl MePool {
|
|||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(super) async fn has_non_draining_writer_per_desired_dc_group(&self) -> bool {
|
||||||
|
let desired_by_dc = self.desired_dc_endpoints().await;
|
||||||
|
let required_dcs: HashSet<i32> = desired_by_dc
|
||||||
|
.iter()
|
||||||
|
.filter_map(|(dc, endpoints)| {
|
||||||
|
if endpoints.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(*dc)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
if required_dcs.is_empty() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
let ws = self.writers.read().await;
|
||||||
|
let mut covered_dcs = HashSet::<i32>::with_capacity(required_dcs.len());
|
||||||
|
for writer in ws.iter() {
|
||||||
|
if writer.draining.load(Ordering::Relaxed) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if required_dcs.contains(&writer.writer_dc) {
|
||||||
|
covered_dcs.insert(writer.writer_dc);
|
||||||
|
if covered_dcs.len() == required_dcs.len() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
fn hardswap_warmup_connect_delay_ms(&self) -> u64 {
|
fn hardswap_warmup_connect_delay_ms(&self) -> u64 {
|
||||||
let min_ms = self.me_hardswap_warmup_delay_min_ms.load(Ordering::Relaxed);
|
let min_ms = self.me_hardswap_warmup_delay_min_ms.load(Ordering::Relaxed);
|
||||||
let max_ms = self.me_hardswap_warmup_delay_max_ms.load(Ordering::Relaxed);
|
let max_ms = self.me_hardswap_warmup_delay_max_ms.load(Ordering::Relaxed);
|
||||||
@@ -475,12 +507,30 @@ impl MePool {
|
|||||||
coverage_ratio = format_args!("{coverage_ratio:.3}"),
|
coverage_ratio = format_args!("{coverage_ratio:.3}"),
|
||||||
min_ratio = format_args!("{min_ratio:.3}"),
|
min_ratio = format_args!("{min_ratio:.3}"),
|
||||||
drain_timeout_secs,
|
drain_timeout_secs,
|
||||||
"ME reinit cycle covered; draining stale writers"
|
"ME reinit cycle covered; processing stale writers"
|
||||||
);
|
);
|
||||||
self.stats.increment_pool_swap_total();
|
self.stats.increment_pool_swap_total();
|
||||||
|
let can_drop_with_replacement = self
|
||||||
|
.has_non_draining_writer_per_desired_dc_group()
|
||||||
|
.await;
|
||||||
|
if can_drop_with_replacement {
|
||||||
|
info!(
|
||||||
|
stale_writers = stale_writer_ids.len(),
|
||||||
|
"ME reinit stale writers: replacement coverage ready, force-closing clients for fast rebind"
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
stale_writers = stale_writer_ids.len(),
|
||||||
|
"ME reinit stale writers: replacement coverage incomplete, keeping draining fallback"
|
||||||
|
);
|
||||||
|
}
|
||||||
for writer_id in stale_writer_ids {
|
for writer_id in stale_writer_ids {
|
||||||
self.mark_writer_draining_with_timeout(writer_id, drain_timeout, !hardswap)
|
self.mark_writer_draining_with_timeout(writer_id, drain_timeout, !hardswap)
|
||||||
.await;
|
.await;
|
||||||
|
if can_drop_with_replacement {
|
||||||
|
self.stats.increment_pool_force_close_total();
|
||||||
|
self.remove_writer_and_close_clients(writer_id).await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if hardswap {
|
if hardswap {
|
||||||
self.clear_pending_hardswap_state();
|
self.clear_pending_hardswap_state();
|
||||||
|
|||||||
@@ -514,6 +514,7 @@ impl MePool {
|
|||||||
let was_draining = w.draining.load(Ordering::Relaxed);
|
let was_draining = w.draining.load(Ordering::Relaxed);
|
||||||
if was_draining {
|
if was_draining {
|
||||||
self.stats.decrement_pool_drain_active();
|
self.stats.decrement_pool_drain_active();
|
||||||
|
self.decrement_draining_active_runtime();
|
||||||
}
|
}
|
||||||
self.stats.increment_me_writer_removed_total();
|
self.stats.increment_me_writer_removed_total();
|
||||||
w.cancel.cancel();
|
w.cancel.cancel();
|
||||||
@@ -572,6 +573,7 @@ impl MePool {
|
|||||||
.store(drain_deadline_epoch_secs, Ordering::Relaxed);
|
.store(drain_deadline_epoch_secs, Ordering::Relaxed);
|
||||||
if !already_draining {
|
if !already_draining {
|
||||||
self.stats.increment_pool_drain_active();
|
self.stats.increment_pool_drain_active();
|
||||||
|
self.increment_draining_active_runtime();
|
||||||
}
|
}
|
||||||
w.contour
|
w.contour
|
||||||
.store(WriterContour::Draining.as_u8(), Ordering::Relaxed);
|
.store(WriterContour::Draining.as_u8(), Ordering::Relaxed);
|
||||||
|
|||||||
@@ -436,6 +436,19 @@ impl ConnRegistry {
|
|||||||
.map(|s| s.is_empty())
|
.map(|s| s.is_empty())
|
||||||
.unwrap_or(true)
|
.unwrap_or(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(super) async fn non_empty_writer_ids(&self, writer_ids: &[u64]) -> HashSet<u64> {
|
||||||
|
let inner = self.inner.read().await;
|
||||||
|
let mut out = HashSet::<u64>::with_capacity(writer_ids.len());
|
||||||
|
for writer_id in writer_ids {
|
||||||
|
if let Some(conns) = inner.conns_for_writer.get(writer_id)
|
||||||
|
&& !conns.is_empty()
|
||||||
|
{
|
||||||
|
out.insert(*writer_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -634,4 +647,35 @@ mod tests {
|
|||||||
);
|
);
|
||||||
assert!(registry.get_writer(conn_id).await.is_none());
|
assert!(registry.get_writer(conn_id).await.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn non_empty_writer_ids_returns_only_writers_with_bound_clients() {
|
||||||
|
let registry = ConnRegistry::new();
|
||||||
|
let (conn_id, _rx) = registry.register().await;
|
||||||
|
let (writer_tx_a, _writer_rx_a) = tokio::sync::mpsc::channel(8);
|
||||||
|
let (writer_tx_b, _writer_rx_b) = tokio::sync::mpsc::channel(8);
|
||||||
|
registry.register_writer(10, writer_tx_a).await;
|
||||||
|
registry.register_writer(20, writer_tx_b).await;
|
||||||
|
|
||||||
|
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 443);
|
||||||
|
assert!(
|
||||||
|
registry
|
||||||
|
.bind_writer(
|
||||||
|
conn_id,
|
||||||
|
10,
|
||||||
|
ConnMeta {
|
||||||
|
target_dc: 2,
|
||||||
|
client_addr: addr,
|
||||||
|
our_addr: addr,
|
||||||
|
proto_flags: 0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
);
|
||||||
|
|
||||||
|
let non_empty = registry.non_empty_writer_ids(&[10, 20, 30]).await;
|
||||||
|
assert!(non_empty.contains(&10));
|
||||||
|
assert!(!non_empty.contains(&20));
|
||||||
|
assert!(!non_empty.contains(&30));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user