mirror of
https://github.com/telemt/telemt.git
synced 2026-05-01 09:24:10 +03:00
Compare commits
25 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1c947e8e3 | ||
|
|
cfe01dced2 | ||
|
|
8520955a5f | ||
|
|
065786b839 | ||
|
|
f0e1a6cf1c | ||
|
|
236bbb4970 | ||
|
|
8ef5263fce | ||
|
|
893cef22e3 | ||
|
|
bdfa641843 | ||
|
|
007fc86189 | ||
|
|
10c9bcd97d | ||
|
|
8ab9405dca | ||
|
|
9412f089c0 | ||
|
|
4e57cee9b9 | ||
|
|
e217371dc8 | ||
|
|
d567dfe40b | ||
|
|
37c916056a | ||
|
|
2f2fe9d5d3 | ||
|
|
1df668144c | ||
|
|
8494429690 | ||
|
|
f25bb17b86 | ||
|
|
27b5d576c0 | ||
|
|
e78592ef9b | ||
|
|
4ed87d1946 | ||
|
|
635bea4de4 |
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -2791,7 +2791,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417"
|
||||
|
||||
[[package]]
|
||||
name = "telemt"
|
||||
version = "3.4.6"
|
||||
version = "3.4.9"
|
||||
dependencies = [
|
||||
"aes",
|
||||
"anyhow",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "telemt"
|
||||
version = "3.4.6"
|
||||
version = "3.4.9"
|
||||
edition = "2024"
|
||||
|
||||
[features]
|
||||
|
||||
@@ -99,7 +99,7 @@ Monero (XMR) directly:
|
||||
8Bk4tZEYPQWSypeD2hrUXG2rKbAKF16GqEN942ZdAP5cFdSqW6h4DwkP5cJMAdszzuPeHeHZPTyjWWFwzeFdjuci3ktfMoB
|
||||
```
|
||||
|
||||
All donations go toward infrastructure, development, and research.
|
||||
All donations go toward infrastructure, development and research
|
||||
|
||||
|
||||

|
||||
|
||||
@@ -82,6 +82,7 @@ pub(super) async fn load_config_from_disk(config_path: &Path) -> Result<ProxyCon
|
||||
.map_err(|e| ApiFailure::internal(format!("failed to load config: {}", e)))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(super) async fn save_config_to_disk(
|
||||
config_path: &Path,
|
||||
cfg: &ProxyConfig,
|
||||
@@ -106,6 +107,12 @@ pub(super) async fn save_access_sections_to_disk(
|
||||
if applied.contains(section) {
|
||||
continue;
|
||||
}
|
||||
if find_toml_table_bounds(&content, section.table_name()).is_none()
|
||||
&& access_section_is_empty(cfg, *section)
|
||||
{
|
||||
applied.push(*section);
|
||||
continue;
|
||||
}
|
||||
let rendered = render_access_section(cfg, *section)?;
|
||||
content = upsert_toml_table(&content, section.table_name(), &rendered);
|
||||
applied.push(*section);
|
||||
@@ -183,6 +190,17 @@ fn render_access_section(cfg: &ProxyConfig, section: AccessSection) -> Result<St
|
||||
Ok(out)
|
||||
}
|
||||
|
||||
fn access_section_is_empty(cfg: &ProxyConfig, section: AccessSection) -> bool {
|
||||
match section {
|
||||
AccessSection::Users => cfg.access.users.is_empty(),
|
||||
AccessSection::UserAdTags => cfg.access.user_ad_tags.is_empty(),
|
||||
AccessSection::UserMaxTcpConns => cfg.access.user_max_tcp_conns.is_empty(),
|
||||
AccessSection::UserExpirations => cfg.access.user_expirations.is_empty(),
|
||||
AccessSection::UserDataQuota => cfg.access.user_data_quota.is_empty(),
|
||||
AccessSection::UserMaxUniqueIps => cfg.access.user_max_unique_ips.is_empty(),
|
||||
}
|
||||
}
|
||||
|
||||
fn serialize_table_body<T: Serialize>(value: &T) -> Result<String, ApiFailure> {
|
||||
toml::to_string(value)
|
||||
.map_err(|e| ApiFailure::internal(format!("failed to serialize access section: {}", e)))
|
||||
|
||||
@@ -28,6 +28,7 @@ mod config_store;
|
||||
mod events;
|
||||
mod http_utils;
|
||||
mod model;
|
||||
mod patch;
|
||||
mod runtime_edge;
|
||||
mod runtime_init;
|
||||
mod runtime_min;
|
||||
|
||||
@@ -5,6 +5,7 @@ use chrono::{DateTime, Utc};
|
||||
use hyper::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::patch::{Patch, patch_field};
|
||||
use crate::crypto::SecureRandom;
|
||||
|
||||
const MAX_USERNAME_LEN: usize = 64;
|
||||
@@ -455,6 +456,13 @@ pub(super) struct UserLinks {
|
||||
pub(super) classic: Vec<String>,
|
||||
pub(super) secure: Vec<String>,
|
||||
pub(super) tls: Vec<String>,
|
||||
pub(super) tls_domains: Vec<TlsDomainLink>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub(super) struct TlsDomainLink {
|
||||
pub(super) domain: String,
|
||||
pub(super) link: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
@@ -507,11 +515,16 @@ pub(super) struct CreateUserRequest {
|
||||
#[derive(Deserialize)]
|
||||
pub(super) struct PatchUserRequest {
|
||||
pub(super) secret: Option<String>,
|
||||
pub(super) user_ad_tag: Option<String>,
|
||||
pub(super) max_tcp_conns: Option<usize>,
|
||||
pub(super) expiration_rfc3339: Option<String>,
|
||||
pub(super) data_quota_bytes: Option<u64>,
|
||||
pub(super) max_unique_ips: Option<usize>,
|
||||
#[serde(default, deserialize_with = "patch_field")]
|
||||
pub(super) user_ad_tag: Patch<String>,
|
||||
#[serde(default, deserialize_with = "patch_field")]
|
||||
pub(super) max_tcp_conns: Patch<usize>,
|
||||
#[serde(default, deserialize_with = "patch_field")]
|
||||
pub(super) expiration_rfc3339: Patch<String>,
|
||||
#[serde(default, deserialize_with = "patch_field")]
|
||||
pub(super) data_quota_bytes: Patch<u64>,
|
||||
#[serde(default, deserialize_with = "patch_field")]
|
||||
pub(super) max_unique_ips: Patch<usize>,
|
||||
}
|
||||
|
||||
#[derive(Default, Deserialize)]
|
||||
@@ -530,6 +543,20 @@ pub(super) fn parse_optional_expiration(
|
||||
Ok(Some(parsed.with_timezone(&Utc)))
|
||||
}
|
||||
|
||||
pub(super) fn parse_patch_expiration(
|
||||
value: &Patch<String>,
|
||||
) -> Result<Patch<DateTime<Utc>>, ApiFailure> {
|
||||
match value {
|
||||
Patch::Unchanged => Ok(Patch::Unchanged),
|
||||
Patch::Remove => Ok(Patch::Remove),
|
||||
Patch::Set(raw) => {
|
||||
let parsed = DateTime::parse_from_rfc3339(raw)
|
||||
.map_err(|_| ApiFailure::bad_request("expiration_rfc3339 must be valid RFC3339"))?;
|
||||
Ok(Patch::Set(parsed.with_timezone(&Utc)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn is_valid_user_secret(secret: &str) -> bool {
|
||||
secret.len() == 32 && secret.chars().all(|c| c.is_ascii_hexdigit())
|
||||
}
|
||||
|
||||
130
src/api/patch.rs
Normal file
130
src/api/patch.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
use serde::Deserialize;
|
||||
|
||||
/// Three-state field for JSON Merge Patch semantics on the `PATCH /v1/users/{user}`
|
||||
/// endpoint.
|
||||
///
|
||||
/// `Unchanged` is produced when the JSON body omits the field entirely and tells the
|
||||
/// handler to leave the corresponding configuration entry untouched. `Remove` is
|
||||
/// produced when the JSON body sets the field to `null` and instructs the handler to
|
||||
/// drop the entry from the corresponding access HashMap. `Set` carries an explicit
|
||||
/// new value, including zero, which is preserved verbatim in the configuration.
|
||||
#[derive(Debug)]
|
||||
pub(super) enum Patch<T> {
|
||||
Unchanged,
|
||||
Remove,
|
||||
Set(T),
|
||||
}
|
||||
|
||||
impl<T> Default for Patch<T> {
|
||||
fn default() -> Self {
|
||||
Self::Unchanged
|
||||
}
|
||||
}
|
||||
|
||||
/// Serde deserializer adapter for fields that follow JSON Merge Patch semantics.
|
||||
///
|
||||
/// Pair this with `#[serde(default, deserialize_with = "patch_field")]` on a
|
||||
/// `Patch<T>` field. An omitted field falls back to `Patch::Unchanged` via
|
||||
/// `Default`; an explicit JSON `null` becomes `Patch::Remove`; any other value
|
||||
/// becomes `Patch::Set(v)`.
|
||||
pub(super) fn patch_field<'de, D, T>(deserializer: D) -> Result<Patch<T>, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
T: serde::Deserialize<'de>,
|
||||
{
|
||||
Option::<T>::deserialize(deserializer).map(|opt| match opt {
|
||||
Some(value) => Patch::Set(value),
|
||||
None => Patch::Remove,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::api::model::{PatchUserRequest, parse_patch_expiration};
|
||||
use chrono::{TimeZone, Utc};
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct Holder {
|
||||
#[serde(default, deserialize_with = "patch_field")]
|
||||
value: Patch<u64>,
|
||||
}
|
||||
|
||||
fn parse(json: &str) -> Holder {
|
||||
serde_json::from_str(json).expect("valid json")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn omitted_field_yields_unchanged() {
|
||||
let h = parse("{}");
|
||||
assert!(matches!(h.value, Patch::Unchanged));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn explicit_null_yields_remove() {
|
||||
let h = parse(r#"{"value": null}"#);
|
||||
assert!(matches!(h.value, Patch::Remove));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn explicit_value_yields_set() {
|
||||
let h = parse(r#"{"value": 42}"#);
|
||||
assert!(matches!(h.value, Patch::Set(42)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn explicit_zero_yields_set_zero() {
|
||||
let h = parse(r#"{"value": 0}"#);
|
||||
assert!(matches!(h.value, Patch::Set(0)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_patch_expiration_passes_unchanged_and_remove_through() {
|
||||
assert!(matches!(
|
||||
parse_patch_expiration(&Patch::Unchanged),
|
||||
Ok(Patch::Unchanged)
|
||||
));
|
||||
assert!(matches!(
|
||||
parse_patch_expiration(&Patch::Remove),
|
||||
Ok(Patch::Remove)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_patch_expiration_parses_set_value() {
|
||||
let parsed =
|
||||
parse_patch_expiration(&Patch::Set("2030-01-02T03:04:05Z".into())).expect("valid");
|
||||
match parsed {
|
||||
Patch::Set(dt) => {
|
||||
assert_eq!(dt, Utc.with_ymd_and_hms(2030, 1, 2, 3, 4, 5).unwrap());
|
||||
}
|
||||
other => panic!("expected Patch::Set, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_patch_expiration_rejects_invalid_set_value() {
|
||||
assert!(parse_patch_expiration(&Patch::Set("not-a-date".into())).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn patch_user_request_deserializes_mixed_states() {
|
||||
let raw = r#"{
|
||||
"secret": "00112233445566778899aabbccddeeff",
|
||||
"max_tcp_conns": 0,
|
||||
"max_unique_ips": null,
|
||||
"data_quota_bytes": 1024
|
||||
}"#;
|
||||
let req: PatchUserRequest = serde_json::from_str(raw).expect("valid json");
|
||||
assert_eq!(
|
||||
req.secret.as_deref(),
|
||||
Some("00112233445566778899aabbccddeeff")
|
||||
);
|
||||
assert!(matches!(req.max_tcp_conns, Patch::Set(0)));
|
||||
assert!(matches!(req.max_unique_ips, Patch::Remove));
|
||||
assert!(matches!(req.data_quota_bytes, Patch::Set(1024)));
|
||||
assert!(matches!(req.expiration_rfc3339, Patch::Unchanged));
|
||||
assert!(matches!(req.user_ad_tag, Patch::Unchanged));
|
||||
}
|
||||
}
|
||||
227
src/api/users.rs
227
src/api/users.rs
@@ -8,14 +8,15 @@ use crate::stats::Stats;
|
||||
|
||||
use super::ApiShared;
|
||||
use super::config_store::{
|
||||
AccessSection, ensure_expected_revision, load_config_from_disk, save_access_sections_to_disk,
|
||||
save_config_to_disk,
|
||||
AccessSection, current_revision, ensure_expected_revision, load_config_from_disk,
|
||||
save_access_sections_to_disk,
|
||||
};
|
||||
use super::model::{
|
||||
ApiFailure, CreateUserRequest, CreateUserResponse, PatchUserRequest, RotateSecretRequest,
|
||||
UserInfo, UserLinks, is_valid_ad_tag, is_valid_user_secret, is_valid_username,
|
||||
parse_optional_expiration, random_user_secret,
|
||||
TlsDomainLink, UserInfo, UserLinks, is_valid_ad_tag, is_valid_user_secret, is_valid_username,
|
||||
parse_optional_expiration, parse_patch_expiration, random_user_secret,
|
||||
};
|
||||
use super::patch::Patch;
|
||||
|
||||
pub(super) async fn create_user(
|
||||
body: CreateUserRequest,
|
||||
@@ -175,6 +176,13 @@ pub(super) async fn patch_user(
|
||||
expected_revision: Option<String>,
|
||||
shared: &ApiShared,
|
||||
) -> Result<(UserInfo, String), ApiFailure> {
|
||||
let touches_users = body.secret.is_some();
|
||||
let touches_user_ad_tags = !matches!(&body.user_ad_tag, Patch::Unchanged);
|
||||
let touches_user_max_tcp_conns = !matches!(&body.max_tcp_conns, Patch::Unchanged);
|
||||
let touches_user_expirations = !matches!(&body.expiration_rfc3339, Patch::Unchanged);
|
||||
let touches_user_data_quota = !matches!(&body.data_quota_bytes, Patch::Unchanged);
|
||||
let touches_user_max_unique_ips = !matches!(&body.max_unique_ips, Patch::Unchanged);
|
||||
|
||||
if let Some(secret) = body.secret.as_ref()
|
||||
&& !is_valid_user_secret(secret)
|
||||
{
|
||||
@@ -182,14 +190,14 @@ pub(super) async fn patch_user(
|
||||
"secret must be exactly 32 hex characters",
|
||||
));
|
||||
}
|
||||
if let Some(ad_tag) = body.user_ad_tag.as_ref()
|
||||
if let Patch::Set(ad_tag) = &body.user_ad_tag
|
||||
&& !is_valid_ad_tag(ad_tag)
|
||||
{
|
||||
return Err(ApiFailure::bad_request(
|
||||
"user_ad_tag must be exactly 32 hex characters",
|
||||
));
|
||||
}
|
||||
let expiration = parse_optional_expiration(body.expiration_rfc3339.as_deref())?;
|
||||
let expiration = parse_patch_expiration(&body.expiration_rfc3339)?;
|
||||
let _guard = shared.mutation_lock.lock().await;
|
||||
let mut cfg = load_config_from_disk(&shared.config_path).await?;
|
||||
ensure_expected_revision(&shared.config_path, expected_revision.as_deref()).await?;
|
||||
@@ -205,38 +213,95 @@ pub(super) async fn patch_user(
|
||||
if let Some(secret) = body.secret {
|
||||
cfg.access.users.insert(user.to_string(), secret);
|
||||
}
|
||||
if let Some(ad_tag) = body.user_ad_tag {
|
||||
cfg.access.user_ad_tags.insert(user.to_string(), ad_tag);
|
||||
match body.user_ad_tag {
|
||||
Patch::Unchanged => {}
|
||||
Patch::Remove => {
|
||||
cfg.access.user_ad_tags.remove(user);
|
||||
}
|
||||
Patch::Set(ad_tag) => {
|
||||
cfg.access.user_ad_tags.insert(user.to_string(), ad_tag);
|
||||
}
|
||||
}
|
||||
if let Some(limit) = body.max_tcp_conns {
|
||||
cfg.access
|
||||
.user_max_tcp_conns
|
||||
.insert(user.to_string(), limit);
|
||||
match body.max_tcp_conns {
|
||||
Patch::Unchanged => {}
|
||||
Patch::Remove => {
|
||||
cfg.access.user_max_tcp_conns.remove(user);
|
||||
}
|
||||
Patch::Set(limit) => {
|
||||
cfg.access
|
||||
.user_max_tcp_conns
|
||||
.insert(user.to_string(), limit);
|
||||
}
|
||||
}
|
||||
if let Some(expiration) = expiration {
|
||||
cfg.access
|
||||
.user_expirations
|
||||
.insert(user.to_string(), expiration);
|
||||
match expiration {
|
||||
Patch::Unchanged => {}
|
||||
Patch::Remove => {
|
||||
cfg.access.user_expirations.remove(user);
|
||||
}
|
||||
Patch::Set(expiration) => {
|
||||
cfg.access
|
||||
.user_expirations
|
||||
.insert(user.to_string(), expiration);
|
||||
}
|
||||
}
|
||||
if let Some(quota) = body.data_quota_bytes {
|
||||
cfg.access.user_data_quota.insert(user.to_string(), quota);
|
||||
}
|
||||
|
||||
let mut updated_limit = None;
|
||||
if let Some(limit) = body.max_unique_ips {
|
||||
cfg.access
|
||||
.user_max_unique_ips
|
||||
.insert(user.to_string(), limit);
|
||||
updated_limit = Some(limit);
|
||||
match body.data_quota_bytes {
|
||||
Patch::Unchanged => {}
|
||||
Patch::Remove => {
|
||||
cfg.access.user_data_quota.remove(user);
|
||||
}
|
||||
Patch::Set(quota) => {
|
||||
cfg.access.user_data_quota.insert(user.to_string(), quota);
|
||||
}
|
||||
}
|
||||
// Capture how the per-user IP limit changed, so the in-memory ip_tracker
|
||||
// can be synced (set or removed) after the config is persisted.
|
||||
let max_unique_ips_change = match body.max_unique_ips {
|
||||
Patch::Unchanged => None,
|
||||
Patch::Remove => {
|
||||
cfg.access.user_max_unique_ips.remove(user);
|
||||
Some(None)
|
||||
}
|
||||
Patch::Set(limit) => {
|
||||
cfg.access
|
||||
.user_max_unique_ips
|
||||
.insert(user.to_string(), limit);
|
||||
Some(Some(limit))
|
||||
}
|
||||
};
|
||||
|
||||
cfg.validate()
|
||||
.map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?;
|
||||
|
||||
let revision = save_config_to_disk(&shared.config_path, &cfg).await?;
|
||||
let mut touched_sections = Vec::new();
|
||||
if touches_users {
|
||||
touched_sections.push(AccessSection::Users);
|
||||
}
|
||||
if touches_user_ad_tags {
|
||||
touched_sections.push(AccessSection::UserAdTags);
|
||||
}
|
||||
if touches_user_max_tcp_conns {
|
||||
touched_sections.push(AccessSection::UserMaxTcpConns);
|
||||
}
|
||||
if touches_user_expirations {
|
||||
touched_sections.push(AccessSection::UserExpirations);
|
||||
}
|
||||
if touches_user_data_quota {
|
||||
touched_sections.push(AccessSection::UserDataQuota);
|
||||
}
|
||||
if touches_user_max_unique_ips {
|
||||
touched_sections.push(AccessSection::UserMaxUniqueIps);
|
||||
}
|
||||
|
||||
let revision = if touched_sections.is_empty() {
|
||||
current_revision(&shared.config_path).await?
|
||||
} else {
|
||||
save_access_sections_to_disk(&shared.config_path, &cfg, &touched_sections).await?
|
||||
};
|
||||
drop(_guard);
|
||||
if let Some(limit) = updated_limit {
|
||||
shared.ip_tracker.set_user_limit(user, limit).await;
|
||||
match max_unique_ips_change {
|
||||
Some(Some(limit)) => shared.ip_tracker.set_user_limit(user, limit).await,
|
||||
Some(None) => shared.ip_tracker.remove_user_limit(user).await,
|
||||
None => {}
|
||||
}
|
||||
let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips();
|
||||
let users = users_from_config(
|
||||
@@ -404,6 +469,7 @@ pub(super) async fn users_from_config(
|
||||
classic: Vec::new(),
|
||||
secure: Vec::new(),
|
||||
tls: Vec::new(),
|
||||
tls_domains: Vec::new(),
|
||||
});
|
||||
users.push(UserInfo {
|
||||
in_runtime: runtime_cfg
|
||||
@@ -458,10 +524,12 @@ fn build_user_links(
|
||||
.public_port
|
||||
.unwrap_or(resolve_default_link_port(cfg));
|
||||
let tls_domains = resolve_tls_domains(cfg);
|
||||
let extra_tls_domains = resolve_extra_tls_domains(cfg);
|
||||
|
||||
let mut classic = Vec::new();
|
||||
let mut secure = Vec::new();
|
||||
let mut tls = Vec::new();
|
||||
let mut tls_domain_links = Vec::new();
|
||||
|
||||
for host in &hosts {
|
||||
if cfg.general.modes.classic {
|
||||
@@ -484,6 +552,17 @@ fn build_user_links(
|
||||
host, port, secret, domain_hex
|
||||
));
|
||||
}
|
||||
for domain in &extra_tls_domains {
|
||||
let domain_hex = hex::encode(domain);
|
||||
let link = format!(
|
||||
"tg://proxy?server={}&port={}&secret=ee{}{}",
|
||||
host, port, secret, domain_hex
|
||||
);
|
||||
tls_domain_links.push(TlsDomainLink {
|
||||
domain: (*domain).to_string(),
|
||||
link,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -491,6 +570,7 @@ fn build_user_links(
|
||||
classic,
|
||||
secure,
|
||||
tls,
|
||||
tls_domains: tls_domain_links,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -607,6 +687,19 @@ fn resolve_tls_domains(cfg: &ProxyConfig) -> Vec<&str> {
|
||||
domains
|
||||
}
|
||||
|
||||
fn resolve_extra_tls_domains(cfg: &ProxyConfig) -> Vec<&str> {
|
||||
let mut domains = Vec::with_capacity(cfg.censorship.tls_domains.len());
|
||||
let primary = cfg.censorship.tls_domain.as_str();
|
||||
for domain in &cfg.censorship.tls_domains {
|
||||
let value = domain.as_str();
|
||||
if value.is_empty() || value == primary || domains.contains(&value) {
|
||||
continue;
|
||||
}
|
||||
domains.push(value);
|
||||
}
|
||||
domains
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -696,4 +789,80 @@ mod tests {
|
||||
assert!(alice.in_runtime);
|
||||
assert!(!bob.in_runtime);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn users_from_config_returns_tls_link_for_each_tls_domain() {
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.access.users.insert(
|
||||
"alice".to_string(),
|
||||
"0123456789abcdef0123456789abcdef".to_string(),
|
||||
);
|
||||
cfg.general.modes.classic = false;
|
||||
cfg.general.modes.secure = false;
|
||||
cfg.general.modes.tls = true;
|
||||
cfg.general.links.public_host = Some("proxy.example.net".to_string());
|
||||
cfg.general.links.public_port = Some(443);
|
||||
cfg.censorship.tls_domain = "front-a.example.com".to_string();
|
||||
cfg.censorship.tls_domains = vec![
|
||||
"front-b.example.com".to_string(),
|
||||
"front-c.example.com".to_string(),
|
||||
"front-b.example.com".to_string(),
|
||||
"front-a.example.com".to_string(),
|
||||
];
|
||||
|
||||
let stats = Stats::new();
|
||||
let tracker = UserIpTracker::new();
|
||||
let users = users_from_config(&cfg, &stats, &tracker, None, None, None).await;
|
||||
let alice = users
|
||||
.iter()
|
||||
.find(|entry| entry.username == "alice")
|
||||
.expect("alice must be present");
|
||||
|
||||
assert_eq!(alice.links.tls.len(), 3);
|
||||
assert!(
|
||||
alice
|
||||
.links
|
||||
.tls
|
||||
.iter()
|
||||
.any(|link| link.ends_with(&hex::encode("front-a.example.com")))
|
||||
);
|
||||
assert!(
|
||||
alice
|
||||
.links
|
||||
.tls
|
||||
.iter()
|
||||
.any(|link| link.ends_with(&hex::encode("front-b.example.com")))
|
||||
);
|
||||
assert!(
|
||||
alice
|
||||
.links
|
||||
.tls
|
||||
.iter()
|
||||
.any(|link| link.ends_with(&hex::encode("front-c.example.com")))
|
||||
);
|
||||
assert_eq!(alice.links.tls_domains.len(), 2);
|
||||
assert!(
|
||||
alice
|
||||
.links
|
||||
.tls_domains
|
||||
.iter()
|
||||
.any(|entry| entry.domain == "front-b.example.com"
|
||||
&& entry.link.ends_with(&hex::encode("front-b.example.com")))
|
||||
);
|
||||
assert!(
|
||||
alice
|
||||
.links
|
||||
.tls_domains
|
||||
.iter()
|
||||
.any(|entry| entry.domain == "front-c.example.com"
|
||||
&& entry.link.ends_with(&hex::encode("front-c.example.com")))
|
||||
);
|
||||
assert!(
|
||||
!alice
|
||||
.links
|
||||
.tls_domains
|
||||
.iter()
|
||||
.any(|entry| entry.domain == "front-a.example.com")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,7 +102,7 @@ pub(crate) fn default_fake_cert_len() -> usize {
|
||||
}
|
||||
|
||||
pub(crate) fn default_tls_front_dir() -> String {
|
||||
"/etc/telemt/tlsfront".to_string()
|
||||
"tlsfront".to_string()
|
||||
}
|
||||
|
||||
pub(crate) fn default_replay_check_len() -> usize {
|
||||
@@ -568,7 +568,7 @@ pub(crate) fn default_beobachten_flush_secs() -> u64 {
|
||||
}
|
||||
|
||||
pub(crate) fn default_beobachten_file() -> String {
|
||||
"/etc/telemt/beobachten.txt".to_string()
|
||||
"beobachten.txt".to_string()
|
||||
}
|
||||
|
||||
pub(crate) fn default_tls_new_session_tickets() -> u8 {
|
||||
|
||||
@@ -22,7 +22,7 @@ pub struct UserIpTracker {
|
||||
limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>,
|
||||
limit_window: Arc<RwLock<Duration>>,
|
||||
last_compact_epoch_secs: Arc<AtomicU64>,
|
||||
cleanup_queue: Arc<Mutex<Vec<(String, IpAddr)>>>,
|
||||
cleanup_queue: Arc<Mutex<HashMap<(String, IpAddr), usize>>>,
|
||||
cleanup_drain_lock: Arc<AsyncMutex<()>>,
|
||||
}
|
||||
|
||||
@@ -45,17 +45,21 @@ impl UserIpTracker {
|
||||
limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)),
|
||||
limit_window: Arc::new(RwLock::new(Duration::from_secs(30))),
|
||||
last_compact_epoch_secs: Arc::new(AtomicU64::new(0)),
|
||||
cleanup_queue: Arc::new(Mutex::new(Vec::new())),
|
||||
cleanup_queue: Arc::new(Mutex::new(HashMap::new())),
|
||||
cleanup_drain_lock: Arc::new(AsyncMutex::new(())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) {
|
||||
match self.cleanup_queue.lock() {
|
||||
Ok(mut queue) => queue.push((user, ip)),
|
||||
Ok(mut queue) => {
|
||||
let count = queue.entry((user, ip)).or_insert(0);
|
||||
*count = count.saturating_add(1);
|
||||
}
|
||||
Err(poisoned) => {
|
||||
let mut queue = poisoned.into_inner();
|
||||
queue.push((user.clone(), ip));
|
||||
let count = queue.entry((user.clone(), ip)).or_insert(0);
|
||||
*count = count.saturating_add(1);
|
||||
self.cleanup_queue.clear_poison();
|
||||
tracing::warn!(
|
||||
"UserIpTracker cleanup_queue lock poisoned; recovered and enqueued IP cleanup for {} ({})",
|
||||
@@ -75,7 +79,9 @@ impl UserIpTracker {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cleanup_queue_mutex_for_tests(&self) -> Arc<Mutex<Vec<(String, IpAddr)>>> {
|
||||
pub(crate) fn cleanup_queue_mutex_for_tests(
|
||||
&self,
|
||||
) -> Arc<Mutex<HashMap<(String, IpAddr), usize>>> {
|
||||
Arc::clone(&self.cleanup_queue)
|
||||
}
|
||||
|
||||
@@ -105,11 +111,14 @@ impl UserIpTracker {
|
||||
};
|
||||
|
||||
let mut active_ips = self.active_ips.write().await;
|
||||
for (user, ip) in to_remove {
|
||||
for ((user, ip), pending_count) in to_remove {
|
||||
if pending_count == 0 {
|
||||
continue;
|
||||
}
|
||||
if let Some(user_ips) = active_ips.get_mut(&user) {
|
||||
if let Some(count) = user_ips.get_mut(&ip) {
|
||||
if *count > 1 {
|
||||
*count -= 1;
|
||||
if *count > pending_count {
|
||||
*count -= pending_count;
|
||||
} else {
|
||||
user_ips.remove(&ip);
|
||||
}
|
||||
@@ -269,9 +278,11 @@ impl UserIpTracker {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let is_new_ip = !user_recent.contains_key(&ip);
|
||||
|
||||
if let Some(limit) = limit {
|
||||
let active_limit_reached = user_active.len() >= limit;
|
||||
let recent_limit_reached = user_recent.len() >= limit;
|
||||
let recent_limit_reached = user_recent.len() >= limit && is_new_ip;
|
||||
let deny = match mode {
|
||||
UserMaxUniqueIpsMode::ActiveWindow => active_limit_reached,
|
||||
UserMaxUniqueIpsMode::TimeWindow => recent_limit_reached,
|
||||
@@ -851,4 +862,19 @@ mod tests {
|
||||
.unwrap_or(false);
|
||||
assert!(!stale_exists);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_time_window_allows_same_ip_reconnect() {
|
||||
let tracker = UserIpTracker::new();
|
||||
tracker.set_user_limit("test_user", 1).await;
|
||||
tracker
|
||||
.set_limit_policy(UserMaxUniqueIpsMode::TimeWindow, 1)
|
||||
.await;
|
||||
|
||||
let ip1 = test_ipv4(10, 4, 0, 1);
|
||||
|
||||
assert!(tracker.check_and_add("test_user", ip1).await.is_ok());
|
||||
tracker.remove_ip("test_user", ip1).await;
|
||||
assert!(tracker.check_and_add("test_user", ip1).await.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#![allow(clippy::items_after_test_module)]
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::sync::watch;
|
||||
@@ -17,7 +17,7 @@ use crate::transport::middle_proxy::{
|
||||
|
||||
pub(crate) fn resolve_runtime_config_path(
|
||||
config_path_cli: &str,
|
||||
startup_cwd: &std::path::Path,
|
||||
startup_cwd: &Path,
|
||||
config_path_explicit: bool,
|
||||
) -> PathBuf {
|
||||
if config_path_explicit {
|
||||
@@ -46,6 +46,39 @@ pub(crate) fn resolve_runtime_config_path(
|
||||
startup_cwd.join("config.toml")
|
||||
}
|
||||
|
||||
pub(crate) fn resolve_runtime_base_dir(
|
||||
config_path: &Path,
|
||||
startup_cwd: &Path,
|
||||
config_path_explicit: bool,
|
||||
data_path: Option<&Path>,
|
||||
) -> PathBuf {
|
||||
if let Some(path) = data_path {
|
||||
return normalize_runtime_dir(path, startup_cwd);
|
||||
}
|
||||
|
||||
if startup_cwd != Path::new("/") {
|
||||
return normalize_runtime_dir(startup_cwd, startup_cwd);
|
||||
}
|
||||
|
||||
if config_path_explicit
|
||||
&& let Some(parent) = config_path.parent()
|
||||
&& !parent.as_os_str().is_empty()
|
||||
{
|
||||
return normalize_runtime_dir(parent, startup_cwd);
|
||||
}
|
||||
|
||||
PathBuf::from("/etc/telemt")
|
||||
}
|
||||
|
||||
fn normalize_runtime_dir(path: &Path, startup_cwd: &Path) -> PathBuf {
|
||||
let absolute = if path.is_absolute() {
|
||||
path.to_path_buf()
|
||||
} else {
|
||||
startup_cwd.join(path)
|
||||
};
|
||||
absolute.canonicalize().unwrap_or(absolute)
|
||||
}
|
||||
|
||||
/// Parsed CLI arguments.
|
||||
pub(crate) struct CliArgs {
|
||||
pub config_path: String,
|
||||
@@ -231,9 +264,11 @@ fn print_help() {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use super::{
|
||||
expected_handshake_close_description, is_expected_handshake_eof, peer_close_description,
|
||||
resolve_runtime_config_path,
|
||||
resolve_runtime_base_dir, resolve_runtime_config_path,
|
||||
};
|
||||
use crate::error::{ProxyError, StreamError};
|
||||
|
||||
@@ -304,6 +339,91 @@ mod tests {
|
||||
let _ = std::fs::remove_dir(&startup_cwd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_runtime_base_dir_prefers_cli_data_path() {
|
||||
let nonce = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
let startup_cwd = std::env::temp_dir().join(format!("telemt_runtime_base_cwd_{nonce}"));
|
||||
let data_path = std::env::temp_dir().join(format!("telemt_runtime_base_data_{nonce}"));
|
||||
std::fs::create_dir_all(&startup_cwd).unwrap();
|
||||
std::fs::create_dir_all(&data_path).unwrap();
|
||||
|
||||
let resolved = resolve_runtime_base_dir(
|
||||
&startup_cwd.join("config.toml"),
|
||||
&startup_cwd,
|
||||
true,
|
||||
Some(&data_path),
|
||||
);
|
||||
assert_eq!(resolved, data_path.canonicalize().unwrap());
|
||||
|
||||
let _ = std::fs::remove_dir(&data_path);
|
||||
let _ = std::fs::remove_dir(&startup_cwd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_runtime_base_dir_uses_working_directory_before_explicit_config_parent() {
|
||||
let nonce = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
let startup_cwd = std::env::temp_dir().join(format!("telemt_runtime_base_start_{nonce}"));
|
||||
let config_dir = std::env::temp_dir().join(format!("telemt_runtime_base_cfg_{nonce}"));
|
||||
std::fs::create_dir_all(&startup_cwd).unwrap();
|
||||
std::fs::create_dir_all(&config_dir).unwrap();
|
||||
|
||||
let resolved =
|
||||
resolve_runtime_base_dir(&config_dir.join("telemt.toml"), &startup_cwd, true, None);
|
||||
assert_eq!(resolved, startup_cwd.canonicalize().unwrap());
|
||||
|
||||
let _ = std::fs::remove_dir(&config_dir);
|
||||
let _ = std::fs::remove_dir(&startup_cwd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_runtime_base_dir_uses_explicit_config_parent_from_root() {
|
||||
let nonce = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
let config_dir = std::env::temp_dir().join(format!("telemt_runtime_base_root_cfg_{nonce}"));
|
||||
std::fs::create_dir_all(&config_dir).unwrap();
|
||||
|
||||
let resolved =
|
||||
resolve_runtime_base_dir(&config_dir.join("telemt.toml"), Path::new("/"), true, None);
|
||||
assert_eq!(resolved, config_dir.canonicalize().unwrap());
|
||||
|
||||
let _ = std::fs::remove_dir(&config_dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_runtime_base_dir_uses_systemd_working_directory_before_etc() {
|
||||
let nonce = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
let startup_cwd = std::env::temp_dir().join(format!("telemt_runtime_base_systemd_{nonce}"));
|
||||
std::fs::create_dir_all(&startup_cwd).unwrap();
|
||||
|
||||
let resolved =
|
||||
resolve_runtime_base_dir(&startup_cwd.join("config.toml"), &startup_cwd, false, None);
|
||||
assert_eq!(resolved, startup_cwd.canonicalize().unwrap());
|
||||
|
||||
let _ = std::fs::remove_dir(&startup_cwd);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_runtime_base_dir_falls_back_to_etc_from_root() {
|
||||
let resolved = resolve_runtime_base_dir(
|
||||
Path::new("/etc/telemt/config.toml"),
|
||||
Path::new("/"),
|
||||
false,
|
||||
None,
|
||||
);
|
||||
assert_eq!(resolved, PathBuf::from("/etc/telemt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expected_handshake_eof_matches_connection_reset() {
|
||||
let err = ProxyError::Io(std::io::Error::from(std::io::ErrorKind::ConnectionReset));
|
||||
|
||||
@@ -47,7 +47,7 @@ use crate::stats::{ReplayChecker, Stats};
|
||||
use crate::stream::BufferPool;
|
||||
use crate::transport::UpstreamManager;
|
||||
use crate::transport::middle_proxy::MePool;
|
||||
use helpers::{parse_cli, resolve_runtime_config_path};
|
||||
use helpers::{parse_cli, resolve_runtime_base_dir, resolve_runtime_config_path};
|
||||
|
||||
#[cfg(unix)]
|
||||
use crate::daemon::{DaemonOptions, PidFile, drop_privileges};
|
||||
@@ -112,8 +112,51 @@ async fn run_telemt_core(
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
if let Some(ref data_path) = data_path
|
||||
&& !data_path.is_absolute()
|
||||
{
|
||||
eprintln!(
|
||||
"[telemt] data_path must be absolute: {}",
|
||||
data_path.display()
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
let mut config_path =
|
||||
resolve_runtime_config_path(&config_path_cli, &startup_cwd, config_path_explicit);
|
||||
let runtime_base_dir = resolve_runtime_base_dir(
|
||||
&config_path,
|
||||
&startup_cwd,
|
||||
config_path_explicit,
|
||||
data_path.as_deref(),
|
||||
);
|
||||
|
||||
if !runtime_base_dir.exists()
|
||||
&& let Err(e) = std::fs::create_dir_all(&runtime_base_dir)
|
||||
{
|
||||
eprintln!(
|
||||
"[telemt] Can't create runtime directory {}: {}",
|
||||
runtime_base_dir.display(),
|
||||
e
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
if !runtime_base_dir.is_dir() {
|
||||
eprintln!(
|
||||
"[telemt] Runtime path exists but is not a directory: {}",
|
||||
runtime_base_dir.display()
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
if let Err(e) = std::env::set_current_dir(&runtime_base_dir) {
|
||||
eprintln!(
|
||||
"[telemt] Can't use runtime directory {}: {}",
|
||||
runtime_base_dir.display(),
|
||||
e
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let mut config = match ProxyConfig::load(&config_path) {
|
||||
Ok(c) => c,
|
||||
@@ -156,16 +199,15 @@ async fn run_telemt_core(
|
||||
);
|
||||
}
|
||||
} else {
|
||||
let system_dir = std::path::Path::new("/etc/telemt");
|
||||
let system_config_path = system_dir.join("telemt.toml");
|
||||
let startup_config_path = startup_cwd.join("config.toml");
|
||||
let runtime_config_path = runtime_base_dir.join("telemt.toml");
|
||||
let fallback_config_path = runtime_base_dir.join("config.toml");
|
||||
let mut persisted = false;
|
||||
|
||||
if let Some(serialized) = serialized.as_ref() {
|
||||
match std::fs::create_dir_all(system_dir) {
|
||||
Ok(()) => match std::fs::write(&system_config_path, serialized) {
|
||||
match std::fs::create_dir_all(&runtime_base_dir) {
|
||||
Ok(()) => match std::fs::write(&runtime_config_path, serialized) {
|
||||
Ok(()) => {
|
||||
config_path = system_config_path;
|
||||
config_path = runtime_config_path;
|
||||
eprintln!(
|
||||
"[telemt] Created default config at {}",
|
||||
config_path.display()
|
||||
@@ -175,7 +217,7 @@ async fn run_telemt_core(
|
||||
Err(write_error) => {
|
||||
eprintln!(
|
||||
"[telemt] Warning: failed to write default config at {}: {}",
|
||||
system_config_path.display(),
|
||||
runtime_config_path.display(),
|
||||
write_error
|
||||
);
|
||||
}
|
||||
@@ -183,16 +225,16 @@ async fn run_telemt_core(
|
||||
Err(create_error) => {
|
||||
eprintln!(
|
||||
"[telemt] Warning: failed to create {}: {}",
|
||||
system_dir.display(),
|
||||
runtime_base_dir.display(),
|
||||
create_error
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if !persisted {
|
||||
match std::fs::write(&startup_config_path, serialized) {
|
||||
match std::fs::write(&fallback_config_path, serialized) {
|
||||
Ok(()) => {
|
||||
config_path = startup_config_path;
|
||||
config_path = fallback_config_path;
|
||||
eprintln!(
|
||||
"[telemt] Created default config at {}",
|
||||
config_path.display()
|
||||
@@ -202,7 +244,7 @@ async fn run_telemt_core(
|
||||
Err(write_error) => {
|
||||
eprintln!(
|
||||
"[telemt] Warning: failed to write default config at {}: {}",
|
||||
startup_config_path.display(),
|
||||
fallback_config_path.display(),
|
||||
write_error
|
||||
);
|
||||
}
|
||||
|
||||
@@ -10,6 +10,14 @@ use crate::tls_front::TlsFrontCache;
|
||||
use crate::tls_front::fetcher::TlsFetchStrategy;
|
||||
use crate::transport::UpstreamManager;
|
||||
|
||||
fn tls_fetch_host_for_domain(mask_host: &str, primary_tls_domain: &str, domain: &str) -> String {
|
||||
if mask_host.eq_ignore_ascii_case(primary_tls_domain) {
|
||||
domain.to_string()
|
||||
} else {
|
||||
mask_host.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn bootstrap_tls_front(
|
||||
config: &ProxyConfig,
|
||||
tls_domains: &[String],
|
||||
@@ -56,6 +64,7 @@ pub(crate) async fn bootstrap_tls_front(
|
||||
let cache_initial = cache.clone();
|
||||
let domains_initial = tls_domains.to_vec();
|
||||
let host_initial = mask_host.clone();
|
||||
let primary_initial = config.censorship.tls_domain.clone();
|
||||
let unix_sock_initial = mask_unix_sock.clone();
|
||||
let scope_initial = tls_fetch_scope.clone();
|
||||
let upstream_initial = upstream_manager.clone();
|
||||
@@ -64,7 +73,8 @@ pub(crate) async fn bootstrap_tls_front(
|
||||
let mut join = tokio::task::JoinSet::new();
|
||||
for domain in domains_initial {
|
||||
let cache_domain = cache_initial.clone();
|
||||
let host_domain = host_initial.clone();
|
||||
let host_domain =
|
||||
tls_fetch_host_for_domain(&host_initial, &primary_initial, &domain);
|
||||
let unix_sock_domain = unix_sock_initial.clone();
|
||||
let scope_domain = scope_initial.clone();
|
||||
let upstream_domain = upstream_initial.clone();
|
||||
@@ -117,6 +127,7 @@ pub(crate) async fn bootstrap_tls_front(
|
||||
let cache_refresh = cache.clone();
|
||||
let domains_refresh = tls_domains.to_vec();
|
||||
let host_refresh = mask_host.clone();
|
||||
let primary_refresh = config.censorship.tls_domain.clone();
|
||||
let unix_sock_refresh = mask_unix_sock.clone();
|
||||
let scope_refresh = tls_fetch_scope.clone();
|
||||
let upstream_refresh = upstream_manager.clone();
|
||||
@@ -130,7 +141,8 @@ pub(crate) async fn bootstrap_tls_front(
|
||||
let mut join = tokio::task::JoinSet::new();
|
||||
for domain in domains_refresh.clone() {
|
||||
let cache_domain = cache_refresh.clone();
|
||||
let host_domain = host_refresh.clone();
|
||||
let host_domain =
|
||||
tls_fetch_host_for_domain(&host_refresh, &primary_refresh, &domain);
|
||||
let unix_sock_domain = unix_sock_refresh.clone();
|
||||
let scope_domain = scope_refresh.clone();
|
||||
let upstream_domain = upstream_refresh.clone();
|
||||
@@ -186,3 +198,24 @@ pub(crate) async fn bootstrap_tls_front(
|
||||
|
||||
tls_cache
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::tls_fetch_host_for_domain;
|
||||
|
||||
#[test]
|
||||
fn tls_fetch_host_uses_each_domain_when_mask_host_is_primary_default() {
|
||||
assert_eq!(
|
||||
tls_fetch_host_for_domain("a.com", "a.com", "b.com"),
|
||||
"b.com"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tls_fetch_host_preserves_explicit_non_primary_mask_host() {
|
||||
assert_eq!(
|
||||
tls_fetch_host_for_domain("origin.example", "a.com", "b.com"),
|
||||
"origin.example"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,16 +31,24 @@ struct UserConnectionReservation {
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
user: String,
|
||||
ip: IpAddr,
|
||||
tracks_ip: bool,
|
||||
active: bool,
|
||||
}
|
||||
|
||||
impl UserConnectionReservation {
|
||||
fn new(stats: Arc<Stats>, ip_tracker: Arc<UserIpTracker>, user: String, ip: IpAddr) -> Self {
|
||||
fn new(
|
||||
stats: Arc<Stats>,
|
||||
ip_tracker: Arc<UserIpTracker>,
|
||||
user: String,
|
||||
ip: IpAddr,
|
||||
tracks_ip: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
stats,
|
||||
ip_tracker,
|
||||
user,
|
||||
ip,
|
||||
tracks_ip,
|
||||
active: true,
|
||||
}
|
||||
}
|
||||
@@ -49,7 +57,9 @@ impl UserConnectionReservation {
|
||||
if !self.active {
|
||||
return;
|
||||
}
|
||||
self.ip_tracker.remove_ip(&self.user, self.ip).await;
|
||||
if self.tracks_ip {
|
||||
self.ip_tracker.remove_ip(&self.user, self.ip).await;
|
||||
}
|
||||
self.active = false;
|
||||
self.stats.decrement_user_curr_connects(&self.user);
|
||||
}
|
||||
@@ -62,7 +72,9 @@ impl Drop for UserConnectionReservation {
|
||||
}
|
||||
self.active = false;
|
||||
self.stats.decrement_user_curr_connects(&self.user);
|
||||
self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip);
|
||||
if self.tracks_ip {
|
||||
self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1621,6 +1633,7 @@ impl RunningClientHandler {
|
||||
ip_tracker,
|
||||
user.to_string(),
|
||||
peer_addr.ip(),
|
||||
true,
|
||||
))
|
||||
}
|
||||
|
||||
@@ -1666,7 +1679,6 @@ impl RunningClientHandler {
|
||||
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
|
||||
Ok(()) => {
|
||||
ip_tracker.remove_ip(user, peer_addr.ip()).await;
|
||||
stats.decrement_user_curr_connects(user);
|
||||
}
|
||||
Err(reason) => {
|
||||
stats.decrement_user_curr_connects(user);
|
||||
@@ -1682,6 +1694,7 @@ impl RunningClientHandler {
|
||||
}
|
||||
}
|
||||
|
||||
stats.decrement_user_curr_connects(user);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,6 +55,7 @@ const STICKY_HINT_MAX_ENTRIES: usize = 65_536;
|
||||
const CANDIDATE_HINT_TRACK_CAP: usize = 64;
|
||||
const OVERLOAD_CANDIDATE_BUDGET_HINTED: usize = 16;
|
||||
const OVERLOAD_CANDIDATE_BUDGET_UNHINTED: usize = 8;
|
||||
const EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD: usize = 64;
|
||||
const RECENT_USER_RING_SCAN_LIMIT: usize = 32;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
@@ -551,6 +552,19 @@ fn auth_probe_note_saturation_in(shared: &ProxySharedState, now: Instant) {
|
||||
}
|
||||
}
|
||||
|
||||
fn auth_probe_note_expensive_invalid_scan_in(
|
||||
shared: &ProxySharedState,
|
||||
now: Instant,
|
||||
validation_checks: usize,
|
||||
overload: bool,
|
||||
) {
|
||||
if overload || validation_checks < EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD {
|
||||
return;
|
||||
}
|
||||
|
||||
auth_probe_note_saturation_in(shared, now);
|
||||
}
|
||||
|
||||
fn auth_probe_record_failure_in(shared: &ProxySharedState, peer_ip: IpAddr, now: Instant) {
|
||||
let peer_ip = normalize_auth_probe_ip(peer_ip);
|
||||
let state = &shared.handshake.auth_probe;
|
||||
@@ -1378,7 +1392,14 @@ where
|
||||
}
|
||||
|
||||
if !matched {
|
||||
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
|
||||
let failure_now = Instant::now();
|
||||
auth_probe_note_expensive_invalid_scan_in(
|
||||
shared,
|
||||
failure_now,
|
||||
validation_checks,
|
||||
overload,
|
||||
);
|
||||
auth_probe_record_failure_in(shared, peer.ip(), failure_now);
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
debug!(
|
||||
peer = %peer,
|
||||
@@ -1753,7 +1774,14 @@ where
|
||||
}
|
||||
|
||||
if !matched {
|
||||
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
|
||||
let failure_now = Instant::now();
|
||||
auth_probe_note_expensive_invalid_scan_in(
|
||||
shared,
|
||||
failure_now,
|
||||
validation_checks,
|
||||
overload,
|
||||
);
|
||||
auth_probe_record_failure_in(shared, peer.ip(), failure_now);
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
debug!(
|
||||
peer = %peer,
|
||||
|
||||
@@ -12,7 +12,7 @@ use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::sync::{mpsc, oneshot, watch};
|
||||
use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot, watch};
|
||||
use tokio::time::timeout;
|
||||
use tracing::{debug, info, trace, warn};
|
||||
|
||||
@@ -36,7 +36,11 @@ use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
|
||||
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
|
||||
|
||||
enum C2MeCommand {
|
||||
Data { payload: PooledBuffer, flags: u32 },
|
||||
Data {
|
||||
payload: PooledBuffer,
|
||||
flags: u32,
|
||||
_permit: OwnedSemaphorePermit,
|
||||
},
|
||||
Close,
|
||||
}
|
||||
|
||||
@@ -47,6 +51,8 @@ const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
|
||||
const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
|
||||
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
|
||||
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
|
||||
const C2ME_QUEUED_BYTE_PERMIT_UNIT: usize = 16 * 1024;
|
||||
const C2ME_QUEUED_PERMITS_PER_SLOT: usize = 4;
|
||||
const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1);
|
||||
const TINY_FRAME_DEBT_PER_TINY: u32 = 8;
|
||||
const TINY_FRAME_DEBT_LIMIT: u32 = 512;
|
||||
@@ -571,6 +577,43 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
|
||||
has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
|
||||
}
|
||||
|
||||
fn c2me_payload_permits(payload_len: usize) -> u32 {
|
||||
payload_len
|
||||
.max(1)
|
||||
.div_ceil(C2ME_QUEUED_BYTE_PERMIT_UNIT)
|
||||
.min(u32::MAX as usize) as u32
|
||||
}
|
||||
|
||||
fn c2me_queued_permit_budget(channel_capacity: usize, frame_limit: usize) -> usize {
|
||||
channel_capacity
|
||||
.saturating_mul(C2ME_QUEUED_PERMITS_PER_SLOT)
|
||||
.max(c2me_payload_permits(frame_limit) as usize)
|
||||
.max(1)
|
||||
}
|
||||
|
||||
async fn acquire_c2me_payload_permit(
|
||||
semaphore: &Arc<Semaphore>,
|
||||
payload_len: usize,
|
||||
send_timeout: Option<Duration>,
|
||||
stats: &Stats,
|
||||
) -> Result<OwnedSemaphorePermit> {
|
||||
let permits = c2me_payload_permits(payload_len);
|
||||
let acquire = semaphore.clone().acquire_many_owned(permits);
|
||||
match send_timeout {
|
||||
Some(send_timeout) => match timeout(send_timeout, acquire).await {
|
||||
Ok(Ok(permit)) => Ok(permit),
|
||||
Ok(Err(_)) => Err(ProxyError::Proxy("ME sender byte budget closed".into())),
|
||||
Err(_) => {
|
||||
stats.increment_me_c2me_send_timeout_total();
|
||||
Err(ProxyError::Proxy("ME sender byte budget timeout".into()))
|
||||
}
|
||||
},
|
||||
None => acquire
|
||||
.await
|
||||
.map_err(|_| ProxyError::Proxy("ME sender byte budget closed".into())),
|
||||
}
|
||||
}
|
||||
|
||||
fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 {
|
||||
limit.saturating_add(overshoot)
|
||||
}
|
||||
@@ -1122,13 +1165,19 @@ where
|
||||
0 => None,
|
||||
timeout_ms => Some(Duration::from_millis(timeout_ms)),
|
||||
};
|
||||
let c2me_byte_budget = c2me_queued_permit_budget(c2me_channel_capacity, frame_limit);
|
||||
let c2me_byte_semaphore = Arc::new(Semaphore::new(c2me_byte_budget));
|
||||
let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(c2me_channel_capacity);
|
||||
let me_pool_c2me = me_pool.clone();
|
||||
let c2me_sender = tokio::spawn(async move {
|
||||
let mut sent_since_yield = 0usize;
|
||||
while let Some(cmd) = c2me_rx.recv().await {
|
||||
match cmd {
|
||||
C2MeCommand::Data { payload, flags } => {
|
||||
C2MeCommand::Data {
|
||||
payload,
|
||||
flags,
|
||||
_permit,
|
||||
} => {
|
||||
me_pool_c2me
|
||||
.send_proxy_req(
|
||||
conn_id,
|
||||
@@ -1624,11 +1673,29 @@ where
|
||||
if payload.len() >= 8 && payload[..8].iter().all(|b| *b == 0) {
|
||||
flags |= RPC_FLAG_NOT_ENCRYPTED;
|
||||
}
|
||||
let payload_permit = match acquire_c2me_payload_permit(
|
||||
&c2me_byte_semaphore,
|
||||
payload.len(),
|
||||
c2me_send_timeout,
|
||||
stats.as_ref(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(permit) => permit,
|
||||
Err(e) => {
|
||||
main_result = Err(e);
|
||||
break;
|
||||
}
|
||||
};
|
||||
// Keep client read loop lightweight: route heavy ME send path via a dedicated task.
|
||||
if enqueue_c2me_command_in(
|
||||
shared.as_ref(),
|
||||
&c2me_tx,
|
||||
C2MeCommand::Data { payload, flags },
|
||||
C2MeCommand::Data {
|
||||
payload,
|
||||
flags,
|
||||
_permit: payload_permit,
|
||||
},
|
||||
c2me_send_timeout,
|
||||
stats.as_ref(),
|
||||
)
|
||||
@@ -2262,7 +2329,7 @@ where
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
match response {
|
||||
MeResponse::Data { flags, data } => {
|
||||
MeResponse::Data { flags, data, .. } => {
|
||||
if batched {
|
||||
trace!(conn_id, bytes = data.len(), flags, "ME->C data (batched)");
|
||||
} else {
|
||||
|
||||
@@ -282,7 +282,7 @@ async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() {
|
||||
assert_eq!(stats.get_user_curr_connects(&user), 1);
|
||||
|
||||
let reservation =
|
||||
UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip);
|
||||
UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip, true);
|
||||
|
||||
// Drop the reservation synchronously without any tokio::spawn/await yielding!
|
||||
drop(reservation);
|
||||
@@ -320,6 +320,7 @@ async fn relay_task_abort_releases_user_gate_and_ip_reservation() {
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
ip_tracker.set_user_limit(user, 8).await;
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.access.user_max_tcp_conns.insert(user.to_string(), 8);
|
||||
@@ -437,6 +438,7 @@ async fn relay_cutover_releases_user_gate_and_ip_reservation() {
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
ip_tracker.set_user_limit(user, 8).await;
|
||||
|
||||
let mut cfg = ProxyConfig::default();
|
||||
cfg.access.user_max_tcp_conns.insert(user.to_string(), 8);
|
||||
@@ -958,6 +960,36 @@ async fn reservation_limit_failure_does_not_leak_curr_connects_counter() {
|
||||
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unlimited_unique_ip_user_is_still_visible_in_active_ip_tracker() {
|
||||
let user = "active-ip-observed-user";
|
||||
let config = crate::config::ProxyConfig::default();
|
||||
let stats = Arc::new(crate::stats::Stats::new());
|
||||
let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new());
|
||||
let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 200, 17)), 50017);
|
||||
|
||||
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
|
||||
user,
|
||||
&config,
|
||||
stats.clone(),
|
||||
peer,
|
||||
ip_tracker.clone(),
|
||||
)
|
||||
.await
|
||||
.expect("reservation without unique-IP limit must succeed");
|
||||
|
||||
assert_eq!(stats.get_user_curr_connects(user), 1);
|
||||
assert_eq!(
|
||||
ip_tracker.get_active_ip_count(user).await,
|
||||
1,
|
||||
"active IP observability must not depend on unique-IP limit enforcement"
|
||||
);
|
||||
|
||||
reservation.release().await;
|
||||
assert_eq!(stats.get_user_curr_connects(user), 0);
|
||||
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn short_tls_probe_is_masked_through_client_pipeline() {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
@@ -2879,6 +2911,7 @@ async fn explicit_reservation_release_cleans_user_and_ip_immediately() {
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
ip_tracker.set_user_limit(user, 4).await;
|
||||
|
||||
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
|
||||
user,
|
||||
@@ -2917,6 +2950,7 @@ async fn explicit_reservation_release_does_not_double_decrement_on_drop() {
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
ip_tracker.set_user_limit(user, 4).await;
|
||||
|
||||
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
|
||||
user,
|
||||
@@ -2947,6 +2981,7 @@ async fn drop_fallback_eventually_cleans_user_and_ip_reservation() {
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
ip_tracker.set_user_limit(user, 1).await;
|
||||
|
||||
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
|
||||
user,
|
||||
@@ -3029,6 +3064,7 @@ async fn release_abort_storm_does_not_leak_user_or_ip_reservations() {
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
ip_tracker.set_user_limit(user, ATTEMPTS + 16).await;
|
||||
|
||||
for idx in 0..ATTEMPTS {
|
||||
let peer = SocketAddr::new(
|
||||
@@ -3079,6 +3115,7 @@ async fn release_abort_loop_preserves_immediate_same_ip_reacquire() {
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
ip_tracker.set_user_limit(user, 1).await;
|
||||
|
||||
for _ in 0..ITERATIONS {
|
||||
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
|
||||
@@ -3137,6 +3174,7 @@ async fn adversarial_mixed_release_drop_abort_wave_converges_to_zero() {
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
ip_tracker.set_user_limit(user, RESERVATIONS + 8).await;
|
||||
|
||||
let mut reservations = Vec::with_capacity(RESERVATIONS);
|
||||
for idx in 0..RESERVATIONS {
|
||||
@@ -3217,6 +3255,8 @@ async fn parallel_users_abort_release_isolation_preserves_independent_cleanup()
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
ip_tracker.set_user_limit(user_a, 64).await;
|
||||
ip_tracker.set_user_limit(user_b, 64).await;
|
||||
|
||||
let mut tasks = tokio::task::JoinSet::new();
|
||||
for idx in 0..64usize {
|
||||
@@ -3278,6 +3318,7 @@ async fn concurrent_release_storm_leaves_zero_user_and_ip_footprint() {
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
ip_tracker.set_user_limit(user, RESERVATIONS + 8).await;
|
||||
|
||||
let mut reservations = Vec::with_capacity(RESERVATIONS);
|
||||
for idx in 0..RESERVATIONS {
|
||||
@@ -3332,6 +3373,7 @@ async fn relay_connect_error_releases_user_and_ip_before_return() {
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
ip_tracker.set_user_limit(user, 8).await;
|
||||
|
||||
let mut config = ProxyConfig::default();
|
||||
config.access.user_max_tcp_conns.insert(user.to_string(), 1);
|
||||
@@ -3427,6 +3469,7 @@ async fn mixed_release_and_drop_same_ip_preserves_counter_correctness() {
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
ip_tracker.set_user_limit(user, 1).await;
|
||||
|
||||
let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static(
|
||||
user,
|
||||
@@ -3487,6 +3530,7 @@ async fn drop_one_of_two_same_ip_reservations_keeps_ip_active() {
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
ip_tracker.set_user_limit(user, 1).await;
|
||||
|
||||
let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static(
|
||||
user,
|
||||
@@ -3696,6 +3740,7 @@ async fn cross_thread_drop_uses_captured_runtime_for_ip_cleanup() {
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
ip_tracker.set_user_limit(user, 8).await;
|
||||
|
||||
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
|
||||
user,
|
||||
@@ -3740,6 +3785,7 @@ async fn immediate_reacquire_after_cross_thread_drop_succeeds() {
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
ip_tracker.set_user_limit(user, 1).await;
|
||||
|
||||
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
|
||||
user,
|
||||
|
||||
@@ -1252,6 +1252,97 @@ async fn tls_overload_budget_limits_candidate_scan_depth() {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_expensive_invalid_scan_activates_saturation_budget() {
|
||||
let mut config = ProxyConfig::default();
|
||||
config.access.users.clear();
|
||||
config.access.ignore_time_skew = true;
|
||||
for idx in 0..80u8 {
|
||||
config.access.users.insert(
|
||||
format!("user-{idx}"),
|
||||
format!("{:032x}", u128::from(idx) + 1),
|
||||
);
|
||||
}
|
||||
config.rebuild_runtime_user_auth().unwrap();
|
||||
|
||||
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
|
||||
let rng = SecureRandom::new();
|
||||
let shared = ProxySharedState::new();
|
||||
let attacker_secret = [0xEFu8; 16];
|
||||
let handshake = make_valid_tls_handshake(&attacker_secret, 0);
|
||||
|
||||
let first_peer: SocketAddr = "198.51.100.214:44326".parse().unwrap();
|
||||
let first = handle_tls_handshake_with_shared(
|
||||
&handshake,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
first_peer,
|
||||
&config,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
shared.as_ref(),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(first, HandshakeResult::BadClient { .. }));
|
||||
assert!(
|
||||
auth_probe_saturation_state_for_testing_in_shared(shared.as_ref())
|
||||
.lock()
|
||||
.unwrap()
|
||||
.is_some(),
|
||||
"expensive invalid scan must activate global saturation"
|
||||
);
|
||||
assert_eq!(
|
||||
shared
|
||||
.handshake
|
||||
.auth_expensive_checks_total
|
||||
.load(Ordering::Relaxed),
|
||||
80,
|
||||
"first invalid probe preserves full first-hit compatibility before enabling saturation"
|
||||
);
|
||||
|
||||
{
|
||||
let mut saturation = auth_probe_saturation_state_for_testing_in_shared(shared.as_ref())
|
||||
.lock()
|
||||
.unwrap();
|
||||
let state = saturation.as_mut().expect("saturation must be present");
|
||||
state.blocked_until = Instant::now() + Duration::from_millis(200);
|
||||
}
|
||||
|
||||
let second_peer: SocketAddr = "198.51.100.215:44326".parse().unwrap();
|
||||
let second = handle_tls_handshake_with_shared(
|
||||
&handshake,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
second_peer,
|
||||
&config,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
shared.as_ref(),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(second, HandshakeResult::BadClient { .. }));
|
||||
assert_eq!(
|
||||
shared
|
||||
.handshake
|
||||
.auth_budget_exhausted_total
|
||||
.load(Ordering::Relaxed),
|
||||
1,
|
||||
"second invalid probe must be capped by overload budget"
|
||||
);
|
||||
assert_eq!(
|
||||
shared
|
||||
.handshake
|
||||
.auth_expensive_checks_total
|
||||
.load(Ordering::Relaxed),
|
||||
80 + OVERLOAD_CANDIDATE_BUDGET_UNHINTED as u64,
|
||||
"saturation budget must bound follow-up invalid scans"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mtproto_runtime_snapshot_prefers_preferred_user_hint() {
|
||||
let mut config = ProxyConfig::default();
|
||||
|
||||
@@ -70,6 +70,7 @@ async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() {
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: payload.clone(),
|
||||
route_permit: None,
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
@@ -139,6 +140,7 @@ async fn me_writer_pre_write_quota_reject_happens_before_writer_poll() {
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xAA, 0xBB, 0xCC]),
|
||||
route_permit: None,
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
|
||||
@@ -12,6 +12,12 @@ fn make_pooled_payload(data: &[u8]) -> PooledBuffer {
|
||||
payload
|
||||
}
|
||||
|
||||
fn make_c2me_permit() -> tokio::sync::OwnedSemaphorePermit {
|
||||
Arc::new(tokio::sync::Semaphore::new(1))
|
||||
.try_acquire_many_owned(1)
|
||||
.expect("test permit must be available")
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "Tracking for M-04: Verify should_emit_full_desync returns true on first occurrence and false on duplicate within window"]
|
||||
fn should_emit_full_desync_filters_duplicates() {
|
||||
@@ -107,6 +113,7 @@ async fn c2me_channel_full_path_yields_then_sends() {
|
||||
tx.send(C2MeCommand::Data {
|
||||
payload: make_pooled_payload(&[0xAA]),
|
||||
flags: 1,
|
||||
_permit: make_c2me_permit(),
|
||||
})
|
||||
.await
|
||||
.expect("priming queue with one frame must succeed");
|
||||
@@ -119,6 +126,7 @@ async fn c2me_channel_full_path_yields_then_sends() {
|
||||
C2MeCommand::Data {
|
||||
payload: make_pooled_payload(&[0xBB, 0xCC]),
|
||||
flags: 2,
|
||||
_permit: make_c2me_permit(),
|
||||
},
|
||||
None,
|
||||
&stats,
|
||||
@@ -138,7 +146,7 @@ async fn c2me_channel_full_path_yields_then_sends() {
|
||||
.expect("receiver should observe primed frame")
|
||||
.expect("first queued command must exist");
|
||||
match first {
|
||||
C2MeCommand::Data { payload, flags } => {
|
||||
C2MeCommand::Data { payload, flags, .. } => {
|
||||
assert_eq!(payload.as_ref(), &[0xAA]);
|
||||
assert_eq!(flags, 1);
|
||||
}
|
||||
@@ -155,7 +163,7 @@ async fn c2me_channel_full_path_yields_then_sends() {
|
||||
.expect("receiver should observe backpressure-resumed frame")
|
||||
.expect("second queued command must exist");
|
||||
match second {
|
||||
C2MeCommand::Data { payload, flags } => {
|
||||
C2MeCommand::Data { payload, flags, .. } => {
|
||||
assert_eq!(payload.as_ref(), &[0xBB, 0xCC]);
|
||||
assert_eq!(flags, 2);
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ use std::time::{Duration, Instant};
|
||||
use parking_lot::Mutex;
|
||||
|
||||
const CLEANUP_INTERVAL: Duration = Duration::from_secs(30);
|
||||
const MAX_BEOBACHTEN_ENTRIES: usize = 65_536;
|
||||
|
||||
#[derive(Default)]
|
||||
struct BeobachtenInner {
|
||||
@@ -48,12 +49,23 @@ impl BeobachtenStore {
|
||||
Self::cleanup_if_needed(&mut guard, now, ttl);
|
||||
|
||||
let key = (class.to_string(), ip);
|
||||
let entry = guard.entries.entry(key).or_insert(BeobachtenEntry {
|
||||
tries: 0,
|
||||
last_seen: now,
|
||||
});
|
||||
entry.tries = entry.tries.saturating_add(1);
|
||||
entry.last_seen = now;
|
||||
if let Some(entry) = guard.entries.get_mut(&key) {
|
||||
entry.tries = entry.tries.saturating_add(1);
|
||||
entry.last_seen = now;
|
||||
return;
|
||||
}
|
||||
|
||||
if guard.entries.len() >= MAX_BEOBACHTEN_ENTRIES {
|
||||
return;
|
||||
}
|
||||
|
||||
guard.entries.insert(
|
||||
key,
|
||||
BeobachtenEntry {
|
||||
tries: 1,
|
||||
last_seen: now,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
pub fn snapshot_text(&self, ttl: Duration) -> String {
|
||||
|
||||
@@ -649,6 +649,25 @@ async fn duplicate_cleanup_entries_do_not_break_future_admission() {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn duplicate_cleanup_entries_are_coalesced_until_drain() {
|
||||
let tracker = UserIpTracker::new();
|
||||
let ip = ip_from_idx(7150);
|
||||
|
||||
tracker.enqueue_cleanup("coalesced-cleanup".to_string(), ip);
|
||||
tracker.enqueue_cleanup("coalesced-cleanup".to_string(), ip);
|
||||
tracker.enqueue_cleanup("coalesced-cleanup".to_string(), ip);
|
||||
|
||||
assert_eq!(
|
||||
tracker.cleanup_queue_len_for_tests(),
|
||||
1,
|
||||
"duplicate queued cleanup entries must retain one allocation slot"
|
||||
);
|
||||
|
||||
tracker.drain_cleanup_queue().await;
|
||||
assert_eq!(tracker.cleanup_queue_len_for_tests(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn stress_repeated_queue_poison_recovery_preserves_admission_progress() {
|
||||
let tracker = UserIpTracker::new();
|
||||
|
||||
@@ -130,6 +130,14 @@ impl TlsFrontCache {
|
||||
warn!(file = %name, "Skipping TLS cache entry with invalid domain");
|
||||
continue;
|
||||
}
|
||||
if !cert_info_matches_domain(&cached) {
|
||||
warn!(
|
||||
file = %name,
|
||||
domain = %cached.domain,
|
||||
"Skipping TLS cache entry with mismatched certificate metadata"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
// fetched_at is skipped during deserialization; approximate with file mtime if available.
|
||||
if let Ok(meta) = entry.metadata().await
|
||||
&& let Ok(modified) = meta.modified()
|
||||
@@ -209,10 +217,100 @@ impl TlsFrontCache {
|
||||
}
|
||||
}
|
||||
|
||||
fn cert_info_matches_domain(cached: &CachedTlsData) -> bool {
|
||||
let Some(cert_info) = cached.cert_info.as_ref() else {
|
||||
return true;
|
||||
};
|
||||
if !cert_info.san_names.is_empty() {
|
||||
return cert_info
|
||||
.san_names
|
||||
.iter()
|
||||
.any(|name| dns_name_matches_domain(name, &cached.domain));
|
||||
}
|
||||
cert_info
|
||||
.subject_cn
|
||||
.as_deref()
|
||||
.map_or(true, |name| dns_name_matches_domain(name, &cached.domain))
|
||||
}
|
||||
|
||||
fn dns_name_matches_domain(pattern: &str, domain: &str) -> bool {
|
||||
let pattern = normalize_dns_name(pattern);
|
||||
let domain = normalize_dns_name(domain);
|
||||
if pattern == domain {
|
||||
return true;
|
||||
}
|
||||
|
||||
let Some(suffix) = pattern.strip_prefix("*.") else {
|
||||
return false;
|
||||
};
|
||||
let Some(prefix) = domain.strip_suffix(suffix) else {
|
||||
return false;
|
||||
};
|
||||
prefix.ends_with('.') && !prefix[..prefix.len() - 1].contains('.')
|
||||
}
|
||||
|
||||
fn normalize_dns_name(value: &str) -> String {
|
||||
value.trim().trim_end_matches('.').to_ascii_lowercase()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn cached_with_cert_info(
|
||||
domain: &str,
|
||||
subject_cn: Option<&str>,
|
||||
san_names: Vec<&str>,
|
||||
) -> CachedTlsData {
|
||||
CachedTlsData {
|
||||
server_hello_template: ParsedServerHello {
|
||||
version: [0x03, 0x03],
|
||||
random: [0u8; 32],
|
||||
session_id: Vec::new(),
|
||||
cipher_suite: [0x13, 0x01],
|
||||
compression: 0,
|
||||
extensions: Vec::new(),
|
||||
},
|
||||
cert_info: Some(crate::tls_front::types::ParsedCertificateInfo {
|
||||
not_after_unix: None,
|
||||
not_before_unix: None,
|
||||
issuer_cn: None,
|
||||
subject_cn: subject_cn.map(str::to_string),
|
||||
san_names: san_names.into_iter().map(str::to_string).collect(),
|
||||
}),
|
||||
cert_payload: None,
|
||||
app_data_records_sizes: vec![1024],
|
||||
total_app_data_len: 1024,
|
||||
behavior_profile: TlsBehaviorProfile::default(),
|
||||
fetched_at: SystemTime::now(),
|
||||
domain: domain.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cert_info_domain_match_accepts_exact_san() {
|
||||
let cached = cached_with_cert_info("b.com", Some("a.com"), vec!["b.com"]);
|
||||
assert!(cert_info_matches_domain(&cached));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cert_info_domain_match_rejects_wrong_san() {
|
||||
let cached = cached_with_cert_info("b.com", Some("b.com"), vec!["a.com"]);
|
||||
assert!(!cert_info_matches_domain(&cached));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cert_info_domain_match_accepts_single_label_wildcard_san() {
|
||||
let cached = cached_with_cert_info("api.b.com", None, vec!["*.b.com"]);
|
||||
assert!(cert_info_matches_domain(&cached));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cert_info_domain_match_rejects_multi_label_wildcard_san() {
|
||||
let cached = cached_with_cert_info("deep.api.b.com", None, vec!["*.b.com"]);
|
||||
assert!(!cert_info_matches_domain(&cached));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_take_full_cert_budget_for_ip_uses_ttl() {
|
||||
let cache = TlsFrontCache::new(&["example.com".to_string()], 1024, "tlsfront-test-cache");
|
||||
|
||||
@@ -46,6 +46,7 @@ mod send_adversarial_tests;
|
||||
mod wire;
|
||||
|
||||
use bytes::Bytes;
|
||||
use tokio::sync::OwnedSemaphorePermit;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
pub use config_updater::{
|
||||
@@ -68,9 +69,32 @@ pub use secret::{fetch_proxy_secret, fetch_proxy_secret_with_upstream};
|
||||
pub(crate) use selftest::{bnd_snapshot, timeskew_snapshot, upstream_bnd_snapshots};
|
||||
pub use wire::proto_flags_for_tag;
|
||||
|
||||
/// Holds D2C queued-byte capacity until a routed payload is consumed or dropped.
|
||||
pub struct RouteBytePermit {
|
||||
_permit: OwnedSemaphorePermit,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for RouteBytePermit {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RouteBytePermit").finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl RouteBytePermit {
|
||||
pub(crate) fn new(permit: OwnedSemaphorePermit) -> Self {
|
||||
Self { _permit: permit }
|
||||
}
|
||||
}
|
||||
|
||||
/// Response routed from middle proxy readers to client relay tasks.
|
||||
#[derive(Debug)]
|
||||
pub enum MeResponse {
|
||||
Data { flags: u32, data: Bytes },
|
||||
/// Downstream payload with its queued-byte reservation.
|
||||
Data {
|
||||
flags: u32,
|
||||
data: Bytes,
|
||||
route_permit: Option<RouteBytePermit>,
|
||||
},
|
||||
Ack(u32),
|
||||
Close,
|
||||
}
|
||||
|
||||
@@ -84,6 +84,7 @@ async fn route_data_with_retry(
|
||||
MeResponse::Data {
|
||||
flags,
|
||||
data: data.clone(),
|
||||
route_permit: None,
|
||||
},
|
||||
timeout_ms,
|
||||
)
|
||||
@@ -639,7 +640,7 @@ mod tests {
|
||||
let routed = route_data_with_retry(®, conn_id, 0, Bytes::from_static(b"a"), 20).await;
|
||||
assert!(matches!(routed, RouteResult::Routed));
|
||||
match rx.recv().await {
|
||||
Some(MeResponse::Data { flags, data }) => {
|
||||
Some(MeResponse::Data { flags, data, .. }) => {
|
||||
assert_eq!(flags, 0);
|
||||
assert_eq!(data, Bytes::from_static(b"a"));
|
||||
}
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use dashmap::DashMap;
|
||||
use tokio::sync::mpsc::error::TrySendError;
|
||||
use tokio::sync::{Mutex, mpsc};
|
||||
use tokio::sync::{Mutex, Semaphore, mpsc};
|
||||
|
||||
use super::MeResponse;
|
||||
use super::codec::WriterCommand;
|
||||
use super::{MeResponse, RouteBytePermit};
|
||||
|
||||
const ROUTE_BACKPRESSURE_BASE_TIMEOUT_MS: u64 = 25;
|
||||
const ROUTE_BACKPRESSURE_HIGH_TIMEOUT_MS: u64 = 120;
|
||||
const ROUTE_BACKPRESSURE_HIGH_WATERMARK_PCT: u8 = 80;
|
||||
const ROUTE_QUEUED_BYTE_PERMIT_UNIT: usize = 16 * 1024;
|
||||
const ROUTE_QUEUED_PERMITS_PER_SLOT: usize = 4;
|
||||
const ROUTE_QUEUED_MAX_FRAME_PERMITS: usize = 1024;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum RouteResult {
|
||||
@@ -53,6 +57,7 @@ pub(super) struct WriterActivitySnapshot {
|
||||
|
||||
struct RoutingTable {
|
||||
map: DashMap<u64, mpsc::Sender<MeResponse>>,
|
||||
byte_budget: DashMap<u64, Arc<Semaphore>>,
|
||||
}
|
||||
|
||||
struct WriterTable {
|
||||
@@ -105,6 +110,7 @@ pub struct ConnRegistry {
|
||||
route_backpressure_base_timeout_ms: AtomicU64,
|
||||
route_backpressure_high_timeout_ms: AtomicU64,
|
||||
route_backpressure_high_watermark_pct: AtomicU8,
|
||||
route_byte_permits_per_conn: usize,
|
||||
}
|
||||
|
||||
impl ConnRegistry {
|
||||
@@ -116,10 +122,23 @@ impl ConnRegistry {
|
||||
}
|
||||
|
||||
pub fn with_route_channel_capacity(route_channel_capacity: usize) -> Self {
|
||||
let route_channel_capacity = route_channel_capacity.max(1);
|
||||
Self::with_route_limits(
|
||||
route_channel_capacity,
|
||||
Self::route_byte_permit_budget(route_channel_capacity),
|
||||
)
|
||||
}
|
||||
|
||||
fn with_route_limits(
|
||||
route_channel_capacity: usize,
|
||||
route_byte_permits_per_conn: usize,
|
||||
) -> Self {
|
||||
let start = rand::random::<u64>() | 1;
|
||||
let route_channel_capacity = route_channel_capacity.max(1);
|
||||
Self {
|
||||
routing: RoutingTable {
|
||||
map: DashMap::new(),
|
||||
byte_budget: DashMap::new(),
|
||||
},
|
||||
writers: WriterTable {
|
||||
map: DashMap::new(),
|
||||
@@ -131,15 +150,30 @@ impl ConnRegistry {
|
||||
inner: Mutex::new(BindingInner::new()),
|
||||
},
|
||||
next_id: AtomicU64::new(start),
|
||||
route_channel_capacity: route_channel_capacity.max(1),
|
||||
route_channel_capacity,
|
||||
route_backpressure_base_timeout_ms: AtomicU64::new(ROUTE_BACKPRESSURE_BASE_TIMEOUT_MS),
|
||||
route_backpressure_high_timeout_ms: AtomicU64::new(ROUTE_BACKPRESSURE_HIGH_TIMEOUT_MS),
|
||||
route_backpressure_high_watermark_pct: AtomicU8::new(
|
||||
ROUTE_BACKPRESSURE_HIGH_WATERMARK_PCT,
|
||||
),
|
||||
route_byte_permits_per_conn: route_byte_permits_per_conn.max(1),
|
||||
}
|
||||
}
|
||||
|
||||
fn route_data_permits(data_len: usize) -> u32 {
|
||||
data_len
|
||||
.max(1)
|
||||
.div_ceil(ROUTE_QUEUED_BYTE_PERMIT_UNIT)
|
||||
.min(u32::MAX as usize) as u32
|
||||
}
|
||||
|
||||
fn route_byte_permit_budget(route_channel_capacity: usize) -> usize {
|
||||
route_channel_capacity
|
||||
.saturating_mul(ROUTE_QUEUED_PERMITS_PER_SLOT)
|
||||
.max(ROUTE_QUEUED_MAX_FRAME_PERMITS)
|
||||
.max(1)
|
||||
}
|
||||
|
||||
pub fn route_channel_capacity(&self) -> usize {
|
||||
self.route_channel_capacity
|
||||
}
|
||||
@@ -149,6 +183,14 @@ impl ConnRegistry {
|
||||
Self::with_route_channel_capacity(4096)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn with_route_byte_permits_for_tests(
|
||||
route_channel_capacity: usize,
|
||||
route_byte_permits_per_conn: usize,
|
||||
) -> Self {
|
||||
Self::with_route_limits(route_channel_capacity, route_byte_permits_per_conn)
|
||||
}
|
||||
|
||||
pub fn update_route_backpressure_policy(
|
||||
&self,
|
||||
base_timeout_ms: u64,
|
||||
@@ -170,6 +212,10 @@ impl ConnRegistry {
|
||||
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
|
||||
let (tx, rx) = mpsc::channel(self.route_channel_capacity);
|
||||
self.routing.map.insert(id, tx);
|
||||
self.routing.byte_budget.insert(
|
||||
id,
|
||||
Arc::new(Semaphore::new(self.route_byte_permits_per_conn)),
|
||||
);
|
||||
(id, rx)
|
||||
}
|
||||
|
||||
@@ -186,6 +232,7 @@ impl ConnRegistry {
|
||||
/// Unregister connection, returning associated writer_id if any.
|
||||
pub async fn unregister(&self, id: u64) -> Option<u64> {
|
||||
self.routing.map.remove(&id);
|
||||
self.routing.byte_budget.remove(&id);
|
||||
self.hot_binding.map.remove(&id);
|
||||
let mut binding = self.binding.inner.lock().await;
|
||||
binding.meta.remove(&id);
|
||||
@@ -206,6 +253,64 @@ impl ConnRegistry {
|
||||
None
|
||||
}
|
||||
|
||||
async fn attach_route_byte_permit(
|
||||
&self,
|
||||
id: u64,
|
||||
resp: MeResponse,
|
||||
timeout_ms: Option<u64>,
|
||||
) -> std::result::Result<MeResponse, RouteResult> {
|
||||
let MeResponse::Data {
|
||||
flags,
|
||||
data,
|
||||
route_permit,
|
||||
} = resp
|
||||
else {
|
||||
return Ok(resp);
|
||||
};
|
||||
|
||||
if route_permit.is_some() {
|
||||
return Ok(MeResponse::Data {
|
||||
flags,
|
||||
data,
|
||||
route_permit,
|
||||
});
|
||||
}
|
||||
|
||||
let Some(semaphore) = self
|
||||
.routing
|
||||
.byte_budget
|
||||
.get(&id)
|
||||
.map(|entry| entry.value().clone())
|
||||
else {
|
||||
return Err(RouteResult::NoConn);
|
||||
};
|
||||
let permits = Self::route_data_permits(data.len());
|
||||
let permit = match timeout_ms {
|
||||
Some(0) => semaphore
|
||||
.try_acquire_many_owned(permits)
|
||||
.map_err(|_| RouteResult::QueueFullHigh)?,
|
||||
Some(timeout_ms) => {
|
||||
let acquire = semaphore.acquire_many_owned(permits);
|
||||
match tokio::time::timeout(Duration::from_millis(timeout_ms.max(1)), acquire).await
|
||||
{
|
||||
Ok(Ok(permit)) => permit,
|
||||
Ok(Err(_)) => return Err(RouteResult::ChannelClosed),
|
||||
Err(_) => return Err(RouteResult::QueueFullHigh),
|
||||
}
|
||||
}
|
||||
None => semaphore
|
||||
.acquire_many_owned(permits)
|
||||
.await
|
||||
.map_err(|_| RouteResult::ChannelClosed)?,
|
||||
};
|
||||
|
||||
Ok(MeResponse::Data {
|
||||
flags,
|
||||
data,
|
||||
route_permit: Some(RouteBytePermit::new(permit)),
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult {
|
||||
let tx = self.routing.map.get(&id).map(|entry| entry.value().clone());
|
||||
@@ -214,15 +319,23 @@ impl ConnRegistry {
|
||||
return RouteResult::NoConn;
|
||||
};
|
||||
|
||||
let base_timeout_ms = self
|
||||
.route_backpressure_base_timeout_ms
|
||||
.load(Ordering::Relaxed)
|
||||
.max(1);
|
||||
let resp = match self
|
||||
.attach_route_byte_permit(id, resp, Some(base_timeout_ms))
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(result) => return result,
|
||||
};
|
||||
|
||||
match tx.try_send(resp) {
|
||||
Ok(()) => RouteResult::Routed,
|
||||
Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed,
|
||||
Err(TrySendError::Full(resp)) => {
|
||||
// Absorb short bursts without dropping/closing the session immediately.
|
||||
let base_timeout_ms = self
|
||||
.route_backpressure_base_timeout_ms
|
||||
.load(Ordering::Relaxed)
|
||||
.max(1);
|
||||
let high_timeout_ms = self
|
||||
.route_backpressure_high_timeout_ms
|
||||
.load(Ordering::Relaxed)
|
||||
@@ -266,6 +379,10 @@ impl ConnRegistry {
|
||||
let Some(tx) = tx else {
|
||||
return RouteResult::NoConn;
|
||||
};
|
||||
let resp = match self.attach_route_byte_permit(id, resp, Some(0)).await {
|
||||
Ok(resp) => resp,
|
||||
Err(result) => return result,
|
||||
};
|
||||
|
||||
match tx.try_send(resp) {
|
||||
Ok(()) => RouteResult::Routed,
|
||||
@@ -289,6 +406,13 @@ impl ConnRegistry {
|
||||
let Some(tx) = tx else {
|
||||
return RouteResult::NoConn;
|
||||
};
|
||||
let resp = match self
|
||||
.attach_route_byte_permit(id, resp, Some(timeout_ms))
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp,
|
||||
Err(result) => return result,
|
||||
};
|
||||
|
||||
match tx.try_send(resp) {
|
||||
Ok(()) => RouteResult::Routed,
|
||||
@@ -541,8 +665,10 @@ impl ConnRegistry {
|
||||
mod tests {
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
|
||||
use super::ConnMeta;
|
||||
use super::ConnRegistry;
|
||||
use bytes::Bytes;
|
||||
|
||||
use super::{ConnMeta, ConnRegistry, RouteResult};
|
||||
use crate::transport::middle_proxy::MeResponse;
|
||||
|
||||
#[tokio::test]
|
||||
async fn writer_activity_snapshot_tracks_writer_and_dc_load() {
|
||||
@@ -608,6 +734,55 @@ mod tests {
|
||||
assert_eq!(snapshot.active_sessions_by_target_dc.get(&4), Some(&1));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn route_data_is_bounded_by_byte_permits_before_channel_capacity() {
|
||||
let registry = ConnRegistry::with_route_byte_permits_for_tests(4, 1);
|
||||
let (conn_id, mut rx) = registry.register().await;
|
||||
let routed = registry
|
||||
.route_nowait(
|
||||
conn_id,
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xAA]),
|
||||
route_permit: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
assert!(matches!(routed, RouteResult::Routed));
|
||||
|
||||
let blocked = registry
|
||||
.route_nowait(
|
||||
conn_id,
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xBB]),
|
||||
route_permit: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
assert!(
|
||||
matches!(blocked, RouteResult::QueueFullHigh),
|
||||
"byte budget must reject data before count capacity is exhausted"
|
||||
);
|
||||
|
||||
drop(rx.recv().await);
|
||||
|
||||
let routed_after_drain = registry
|
||||
.route_nowait(
|
||||
conn_id,
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xCC]),
|
||||
route_permit: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
assert!(
|
||||
matches!(routed_after_drain, RouteResult::Routed),
|
||||
"receiving queued data must release byte permits"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn bind_writer_rebinds_conn_atomically() {
|
||||
let registry = ConnRegistry::new();
|
||||
|
||||
Reference in New Issue
Block a user