Traffic Control + Fairness + Evaluating hard-idle timeout + Improve FakeTLS server-flight fidelity + PROXY Protocol V2 UNKNOWN/LOCAL misuse fixes: merge pull request #714 from telemt/flow

This commit is contained in:
Alexey
2026-04-17 11:54:18 +03:00
committed by GitHub
29 changed files with 3785 additions and 191 deletions

Cargo.lock generated
View File

@@ -2780,7 +2780,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417"
[[package]]
name = "telemt"
version = "3.4.0"
version = "3.4.1"
dependencies = [
"aes",
"anyhow",

View File

@@ -1,6 +1,6 @@
[package]
name = "telemt"
version = "3.4.0"
version = "3.4.1"
edition = "2024"
[features]
@@ -98,4 +98,3 @@ harness = false
[profile.release]
lto = "fat"
codegen-units = 1

View File

@@ -0,0 +1,225 @@
# TLS Front Profile Fidelity
## Overview
This document describes how Telemt reuses captured TLS behavior in the FakeTLS server flight and how to validate the result on a real deployment.
When TLS front emulation is enabled, Telemt can capture useful server-side TLS behavior from the selected origin and reuse that behavior in the emulated success path. The goal is not to reproduce the origin byte-for-byte, but to reduce stable synthetic traits and make the emitted server flight structurally closer to the captured profile.
## Why this change exists
The project already captures useful server-side TLS behavior in the TLS front fetch path:
- `change_cipher_spec_count`
- `app_data_record_sizes`
- `ticket_record_sizes`
Before this change, the emulator used only part of that information, leaving a gap between the captured origin behavior and the emitted FakeTLS server flight.
## What is implemented
- The emulator now replays the observed `ChangeCipherSpec` count from the fetched behavior profile.
- The emulator now replays observed ticket-like tail ApplicationData record sizes when raw or merged TLS profile data is available.
- The emulator now preserves more of the profiled encrypted-flight structure instead of collapsing it into a smaller synthetic shape.
- The emulator still falls back to the previous synthetic behavior when the cached profile does not contain raw TLS behavior information.
- Operator-configured `tls_new_session_tickets` still works as an additive fallback when the profile does not provide enough tail records.
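The replay-with-fallback behavior above can be sketched roughly as follows. This is a minimal illustration, not Telemt's actual API: the `TlsBehaviorProfile` and `FlightPlan` types, the function name, and the synthetic record size are all hypothetical.

```rust
// Illustrative sketch of "replay profiled values, else synthetic fallback".
// All names and the fallback record size (250) are hypothetical.

struct TlsBehaviorProfile {
    change_cipher_spec_count: usize,
    // Observed ticket-like tail ApplicationData record sizes.
    ticket_record_sizes: Vec<usize>,
}

struct FlightPlan {
    ccs_records: usize,
    tail_record_sizes: Vec<usize>,
}

/// Replay the profiled CCS count and tail record sizes when raw behavior
/// data is present; otherwise fall back to the previous synthetic shape
/// driven by the operator-configured ticket count.
fn plan_server_flight(profile: Option<&TlsBehaviorProfile>, operator_tickets: usize) -> FlightPlan {
    match profile {
        Some(p) if !p.ticket_record_sizes.is_empty() => FlightPlan {
            ccs_records: p.change_cipher_spec_count,
            tail_record_sizes: p.ticket_record_sizes.clone(),
        },
        _ => FlightPlan {
            // Synthetic fallback: one CCS record, fixed-size tail records.
            ccs_records: 1,
            tail_record_sizes: vec![250; operator_tickets],
        },
    }
}

fn main() {
    let profile = TlsBehaviorProfile {
        change_cipher_spec_count: 1,
        ticket_record_sizes: vec![1029, 1029],
    };
    let replayed = plan_server_flight(Some(&profile), 0);
    assert_eq!(replayed.tail_record_sizes, vec![1029, 1029]);

    let synthetic = plan_server_flight(None, 2);
    assert_eq!(synthetic.ccs_records, 1);
    assert_eq!(synthetic.tail_record_sizes.len(), 2);
    println!("ok");
}
```

The point of the sketch is the decision structure only: profiled data wins when available, and the operator-configured ticket count matters only on the fallback path.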
## Practical benefit
- Reduced distinguishability between profiled origin TLS behavior and emulated TLS behavior.
- Lower chance of stable server-flight fingerprints caused by fixed CCS count or synthetic-only tail record sizes.
- Better reuse of already captured TLS profile data without changing MTProto logic, KDF routing, or transport architecture.
## Limitations
This mechanism does not aim to make Telemt byte-identical to the origin server.
It also does not change:
- MTProto business logic;
- KDF routing behavior;
- the overall transport architecture.
The practical goal is narrower:
- reuse more captured profile data;
- reduce fixed synthetic behavior in the server flight;
- preserve a valid FakeTLS success path while changing the emitted shape on the wire.
## Validation targets
- Correct count of emulated `ChangeCipherSpec` records.
- Correct replay of observed ticket-tail record sizes.
- No regression in existing ALPN and payload-placement behavior.
## How to validate the result
Recommended validation consists of two layers:
- focused unit and security tests for CCS-count replay and ticket-tail replay;
- real packet-capture comparison for a selected origin and a successful FakeTLS session.
When testing on the network, the expected result is:
- a valid FakeTLS and MTProto success path is preserved;
- the early encrypted server flight changes shape when richer profile data is available;
- the change is visible on the wire without changing MTProto logic or transport architecture.
This validation is intended to show better reuse of captured TLS profile data.
It is not intended to prove byte-level equivalence with the real origin server.
## How to test on a real deployment
The strongest practical validation is a side-by-side trace comparison between:
- a real TLS origin server used as `mask_host`;
- a Telemt FakeTLS success-path connection for the same SNI;
- optional captures from different Telemt builds or configurations.
The purpose of the comparison is to inspect the shape of the server flight:
- record order;
- count of `ChangeCipherSpec` records;
- count and grouping of early encrypted `ApplicationData` records;
- lengths of tail or continuation `ApplicationData` records.
## Recommended environment
Use a Linux host or Docker container for the cleanest reproduction.
Recommended setup:
1. One Telemt instance.
2. One real HTTPS origin as `mask_host`.
3. One Telegram client configured with an `ee` proxy link for the Telemt instance.
4. `tcpdump` or Wireshark available for capture analysis.
## Step-by-step test procedure
### 1. Prepare the origin
1. Choose a real HTTPS origin.
2. Set both `censorship.tls_domain` and `censorship.mask_host` to that hostname.
3. Confirm that a direct TLS request works:
```bash
openssl s_client -connect ORIGIN_IP:443 -servername YOUR_DOMAIN </dev/null
```
### 2. Configure Telemt
Use a configuration that enables:
- `censorship.mask = true`
- `censorship.tls_emulation = true`
- `censorship.mask_host`
- `censorship.mask_port`
Recommended for cleaner testing:
- keep `censorship.tls_new_session_tickets = 0`, so the result depends primarily on fetched profile data rather than operator-forced synthetic tail records;
- keep `censorship.tls_fetch.strict_route = true`, if cleaner provenance for captured profile data is important.
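Put together, the settings above might look like the following fragment. The key names are the ones listed in this document; the exact section layout and the example hostname are assumptions, so adapt them to your actual configuration file.

```toml
# Illustrative fragment -- section layout and hostname are assumed,
# key names are taken from this document.
[censorship]
mask = true
tls_emulation = true
tls_domain = "example-origin.net"
mask_host = "example-origin.net"
mask_port = 443
tls_new_session_tickets = 0   # rely on fetched profile data, not forced tickets

[censorship.tls_fetch]
strict_route = true           # cleaner provenance for captured profile data
```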
### 3. Refresh TLS profile data
1. Start Telemt.
2. Let it fetch TLS front profile data for the configured domain.
3. If `tls_front_dir` is persisted, confirm that the TLS front cache is populated.
Persisted cache artifacts are useful, but they are not required if packet captures already demonstrate the runtime result.
### 4. Capture a direct-origin trace
From a separate client host, connect directly to the origin:
```bash
openssl s_client -connect ORIGIN_IP:443 -servername YOUR_DOMAIN </dev/null
```
Capture with:
```bash
sudo tcpdump -i any -w origin-direct.pcap host ORIGIN_IP and port 443
```
### 5. Capture a Telemt FakeTLS success-path trace
Now connect to Telemt with a real Telegram client through an `ee` proxy link that targets the Telemt instance.
`openssl s_client` is useful for direct-origin capture and fallback sanity checks, but it does not exercise the successful FakeTLS and MTProto path.
Capture with:
```bash
sudo tcpdump -i any -w telemt-emulated.pcap host TELEMT_IP and port 443
```
### 6. Decode TLS record structure
Use `tshark` to print record-level structure:
```bash
tshark -r origin-direct.pcap -Y "tls.record" -T fields \
-e frame.number \
-e ip.src \
-e ip.dst \
-e tls.record.content_type \
-e tls.record.length
```
```bash
tshark -r telemt-emulated.pcap -Y "tls.record" -T fields \
-e frame.number \
-e ip.src \
-e ip.dst \
-e tls.record.content_type \
-e tls.record.length
```
Focus on the server flight after ClientHello:
- `22` = Handshake
- `20` = ChangeCipherSpec
- `23` = ApplicationData
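Turning the `tshark` field output into table rows is mechanical. The sketch below assumes you have already extracted the server-to-client `(content_type, length)` pairs after the ClientHello; the tail heuristic (everything after the first ApplicationData record) is an assumption, not how Telemt itself classifies records.

```rust
// Summarize server-flight records into the three comparison-table columns:
// CCS count, ApplicationData count, and tail ApplicationData lengths.
// Input: (tls.record.content_type, tls.record.length) pairs, server-to-client,
// after the ClientHello. The "tail = everything after the first AppData
// record" split is an illustrative assumption.

fn summarize(records: &[(u8, u16)]) -> (usize, usize, Vec<u16>) {
    let ccs = records.iter().filter(|(t, _)| *t == 20).count();
    let app: Vec<u16> = records
        .iter()
        .filter(|(t, _)| *t == 23)
        .map(|(_, len)| *len)
        .collect();
    let tail = app.get(1..).unwrap_or(&[]).to_vec();
    (ccs, app.len(), tail)
}

fn main() {
    // Example shape: Handshake(22), CCS(20), then three AppData(23) records.
    let flight = [(22, 1294), (20, 1), (23, 4056), (23, 1029), (23, 1029)];
    let (ccs, app_count, tail) = summarize(&flight);
    assert_eq!(ccs, 1);
    assert_eq!(app_count, 3);
    assert_eq!(tail, vec![1029, 1029]);
    println!("CCS={ccs} AppData={app_count} tail={tail:?}");
}
```

Run it once per capture and copy the three values into the corresponding table row.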
### 7. Build a comparison table
A compact table like the following is usually enough:
| Path | CCS count | AppData count in first encrypted flight | Tail AppData lengths |
| --- | --- | --- | --- |
| Origin | `N` | `M` | `[a, b, ...]` |
| Telemt build A | `...` | `...` | `...` |
| Telemt build B | `...` | `...` | `...` |
The comparison should make it easy to see that:
- the FakeTLS success path remains valid;
- the early encrypted server flight changes when richer profile data is reused;
- the result is backed by packet evidence.
## Example capture set
One practical example of this workflow uses:
- `origin-direct-nginx.pcap`
- `telemt-ee-before-nginx.pcap`
- `telemt-ee-after-nginx.pcap`
Practical notes:
- `origin` was captured as a direct TLS 1.2 connection to `nginx.org`;
- `before` and `after` were captured on the Telemt FakeTLS success path with a real Telegram client;
- the first server-side FakeTLS response remains valid in both cases;
- the early encrypted server-flight segmentation differs between `before` and `after`, which is consistent with better reuse of captured profile data;
- this kind of result shows a wire-visible effect without breaking the success path, but it does not claim full indistinguishability from the origin.
## Stronger validation
For broader confidence, repeat the same comparison on:
1. one CDN-backed origin;
2. one regular nginx origin;
3. one origin with a multi-record encrypted flight and visible ticket-like tails.
If the same directional improvement appears across all three, confidence in the result will be much higher than for a single-origin example.

View File

@@ -0,0 +1,225 @@
# TLS Front Profile Fidelity
## Overview
This document describes how Telemt reuses captured TLS behavior in the FakeTLS server flight and how to validate the result on a real deployment.
When TLS front emulation is enabled, Telemt can collect useful server-side TLS behavior from the selected origin and use it in the emulated success path. The goal is not a byte-for-byte copy of the origin, but reducing stable synthetic traits so that the emitted server flight is structurally closer to the captured profile.
## Why this change exists
The project already collects useful server-side TLS behavior in the TLS front fetch path:
- `change_cipher_spec_count`
- `app_data_record_sizes`
- `ticket_record_sizes`
Before this change, the emulator used only part of that information, leaving a gap between the captured origin behavior and the FakeTLS server flight that actually went on the wire.
## What is implemented
- The emulator now replays the observed `ChangeCipherSpec` count from the fetched `behavior_profile`.
- The emulator now replays the observed ticket-like tail ApplicationData record sizes when raw or merged TLS profile data is available.
- The emulator now preserves more of the profiled encrypted-flight structure instead of collapsing it into a smaller synthetic shape.
- For profiles without raw TLS behavior, the previous synthetic fallback is preserved.
- The operator-configured `tls_new_session_tickets` still works as an additive fallback when the profile does not provide enough tail records.
## Practical benefit
- Reduced distinguishability between profiled origin TLS behavior and emulated TLS behavior.
- A lower chance of stable server-flight fingerprints caused by a fixed CCS count or fully synthetic tail record sizes.
- Better use of already collected TLS profile data without changing MTProto logic, KDF routing, or the transport architecture.
## Limitations
This mechanism does not aim to make Telemt byte-identical to the origin server.
It also does not change:
- MTProto business logic;
- KDF routing behavior;
- the overall transport architecture.
The practical goal is narrower:
- use more of the already collected profile data;
- reduce fixed synthetic behavior in the server flight;
- preserve a valid FakeTLS success path while changing the shape of the emitted traffic on the wire.
## Validation targets
- A correct count of emulated `ChangeCipherSpec` records.
- Correct replay of the observed ticket-tail record sizes.
- No regression in the existing ALPN and payload-placement behavior.
## How to validate the result
The recommended validation consists of two layers:
- focused unit and security tests for CCS-count replay and ticket-tail replay;
- comparison of real packet captures for a selected origin and a successful FakeTLS session.
When testing on the network, the expected result is:
- the valid FakeTLS and MTProto success path is preserved;
- the shape of the early encrypted server flight changes when richer profile data is available;
- the change is visible on the wire without changing MTProto logic or the transport architecture.
This validation is meant to confirm that already collected TLS profile data is used better.
It is not intended to prove byte-level equivalence with the real origin server.
## How to test on a real deployment
The strongest practical validation is a side-by-side trace comparison between:
- a real TLS origin server used as `mask_host`;
- a Telemt FakeTLS success-path connection for the same SNI;
- optionally, captures from different Telemt builds or configurations.
The point of the comparison is to look at the shape of the server flight:
- record order;
- the count of `ChangeCipherSpec` records;
- the count and grouping of early encrypted `ApplicationData` records;
- the sizes of tail or continuation `ApplicationData` records.
## Recommended environment
For the cleanest reproduction, use a Linux host or a Docker container.
Recommended setup:
1. One Telemt instance.
2. One real HTTPS origin as `mask_host`.
3. One Telegram client configured with an `ee` proxy link for the Telemt instance.
4. `tcpdump` or Wireshark for capture analysis.
## Step-by-step test procedure
### 1. Prepare the origin
1. Choose a real HTTPS origin.
2. Set both `censorship.tls_domain` and `censorship.mask_host` to that origin's hostname.
3. Confirm that a direct TLS request works:
```bash
openssl s_client -connect ORIGIN_IP:443 -servername YOUR_DOMAIN </dev/null
```
### 2. Configure Telemt
Use a config that enables:
- `censorship.mask = true`
- `censorship.tls_emulation = true`
- `censorship.mask_host`
- `censorship.mask_port`
For cleaner testing it is recommended to:
- keep `censorship.tls_new_session_tickets = 0`, so the result depends primarily on fetched profile data rather than operator-forced synthetic tail records;
- keep `censorship.tls_fetch.strict_route = true` if cleaner provenance for the captured profile data is important.
### 3. Refresh TLS profile data
1. Start Telemt.
2. Let it fetch TLS front profile data for the chosen domain.
3. If `tls_front_dir` is persisted, confirm that the TLS front cache is populated.
Persisted cache artifacts are useful, but they are not required if packet captures already demonstrate the runtime result.
### 4. Capture a direct-origin trace
From a separate client host, connect directly to the origin:
```bash
openssl s_client -connect ORIGIN_IP:443 -servername YOUR_DOMAIN </dev/null
```
Capture with:
```bash
sudo tcpdump -i any -w origin-direct.pcap host ORIGIN_IP and port 443
```
### 5. Capture a Telemt FakeTLS success-path trace
Now connect to Telemt with a real Telegram client through an `ee` proxy link that points at the Telemt instance.
`openssl s_client` is useful for direct-origin capture and for fallback sanity checks, but it does not exercise the successful FakeTLS and MTProto path.
Capture with:
```bash
sudo tcpdump -i any -w telemt-emulated.pcap host TELEMT_IP and port 443
```
### 6. Decode the TLS record structure
Use `tshark` to print the record-level structure:
```bash
tshark -r origin-direct.pcap -Y "tls.record" -T fields \
-e frame.number \
-e ip.src \
-e ip.dst \
-e tls.record.content_type \
-e tls.record.length
```
```bash
tshark -r telemt-emulated.pcap -Y "tls.record" -T fields \
-e frame.number \
-e ip.src \
-e ip.dst \
-e tls.record.content_type \
-e tls.record.length
```
Focus on the server flight after the ClientHello:
- `22` = Handshake
- `20` = ChangeCipherSpec
- `23` = ApplicationData
### 7. Build a comparison table
A compact table of the following form is usually enough:
| Path | CCS count | AppData count in first encrypted flight | Tail AppData lengths |
| --- | --- | --- | --- |
| Origin | `N` | `M` | `[a, b, ...]` |
| Telemt build A | `...` | `...` | `...` |
| Telemt build B | `...` | `...` | `...` |
The table should make it easy to see that:
- the FakeTLS success path remains valid;
- the early encrypted server flight changes when richer profile data is reused;
- the result is backed by packet evidence.
## Example capture set
One practical example of this workflow uses:
- `origin-direct-nginx.pcap`
- `telemt-ee-before-nginx.pcap`
- `telemt-ee-after-nginx.pcap`
Practical notes:
- `origin` was captured as a direct TLS 1.2 connection to `nginx.org`;
- `before` and `after` were captured on the Telemt FakeTLS success path with a real Telegram client;
- the first server-side FakeTLS response remains valid in both cases;
- the segmentation of the early encrypted server flight differs between `before` and `after`, which is consistent with better use of the captured profile data;
- this kind of result shows a visible on-wire effect without breaking the success path, but it does not claim full indistinguishability from the origin.
## Stronger validation
For a broader check, repeat the same procedure on:
1. one CDN-backed origin;
2. one regular nginx origin;
3. one origin with a multi-record encrypted flight and visible ticket-like tails.
If the same directional improvement repeats across all three, confidence in the result will be much higher than for a single-origin example.

View File

@@ -121,6 +121,9 @@ pub struct HotFields {
pub user_max_tcp_conns_global_each: usize,
pub user_expirations: std::collections::HashMap<String, chrono::DateTime<chrono::Utc>>,
pub user_data_quota: std::collections::HashMap<String, u64>,
pub user_rate_limits: std::collections::HashMap<String, crate::config::RateLimitBps>,
pub cidr_rate_limits:
std::collections::HashMap<ipnetwork::IpNetwork, crate::config::RateLimitBps>,
pub user_max_unique_ips: std::collections::HashMap<String, usize>,
pub user_max_unique_ips_global_each: usize,
pub user_max_unique_ips_mode: crate::config::UserMaxUniqueIpsMode,
@@ -245,6 +248,8 @@ impl HotFields {
user_max_tcp_conns_global_each: cfg.access.user_max_tcp_conns_global_each,
user_expirations: cfg.access.user_expirations.clone(),
user_data_quota: cfg.access.user_data_quota.clone(),
user_rate_limits: cfg.access.user_rate_limits.clone(),
cidr_rate_limits: cfg.access.cidr_rate_limits.clone(),
user_max_unique_ips: cfg.access.user_max_unique_ips.clone(),
user_max_unique_ips_global_each: cfg.access.user_max_unique_ips_global_each,
user_max_unique_ips_mode: cfg.access.user_max_unique_ips_mode,
@@ -545,6 +550,8 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig {
cfg.access.user_max_tcp_conns_global_each = new.access.user_max_tcp_conns_global_each;
cfg.access.user_expirations = new.access.user_expirations.clone();
cfg.access.user_data_quota = new.access.user_data_quota.clone();
cfg.access.user_rate_limits = new.access.user_rate_limits.clone();
cfg.access.cidr_rate_limits = new.access.cidr_rate_limits.clone();
cfg.access.user_max_unique_ips = new.access.user_max_unique_ips.clone();
cfg.access.user_max_unique_ips_global_each = new.access.user_max_unique_ips_global_each;
cfg.access.user_max_unique_ips_mode = new.access.user_max_unique_ips_mode;
@@ -1183,6 +1190,18 @@ fn log_changes(
new_hot.user_data_quota.len()
);
}
if old_hot.user_rate_limits != new_hot.user_rate_limits {
info!(
"config reload: user_rate_limits updated ({} entries)",
new_hot.user_rate_limits.len()
);
}
if old_hot.cidr_rate_limits != new_hot.cidr_rate_limits {
info!(
"config reload: cidr_rate_limits updated ({} entries)",
new_hot.cidr_rate_limits.len()
);
}
if old_hot.user_max_unique_ips != new_hot.user_max_unique_ips {
info!(
"config reload: user_max_unique_ips updated ({} entries)",

View File

@@ -861,6 +861,22 @@ impl ProxyConfig {
));
}
for (user, limit) in &config.access.user_rate_limits {
if limit.up_bps == 0 && limit.down_bps == 0 {
return Err(ProxyError::Config(format!(
"access.user_rate_limits.{user} must set at least one non-zero direction"
)));
}
}
for (cidr, limit) in &config.access.cidr_rate_limits {
if limit.up_bps == 0 && limit.down_bps == 0 {
return Err(ProxyError::Config(format!(
"access.cidr_rate_limits.{cidr} must set at least one non-zero direction"
)));
}
}
if config.general.me_reinit_every_secs == 0 {
return Err(ProxyError::Config(
"general.me_reinit_every_secs must be > 0".to_string(),

View File

@@ -1826,6 +1826,21 @@ pub struct AccessConfig {
#[serde(default)]
pub user_data_quota: HashMap<String, u64>,
/// Per-user transport rate limits in bits-per-second.
///
/// Each entry supports independent upload (`up_bps`) and download
/// (`down_bps`) ceilings. A value of `0` in one direction means
/// "unlimited" for that direction.
#[serde(default)]
pub user_rate_limits: HashMap<String, RateLimitBps>,
/// Per-CIDR aggregate transport rate limits in bits-per-second.
///
/// Matching uses longest-prefix-wins semantics. A value of `0` in one
/// direction means "unlimited" for that direction.
#[serde(default)]
pub cidr_rate_limits: HashMap<IpNetwork, RateLimitBps>,
#[serde(default)]
pub user_max_unique_ips: HashMap<String, usize>,
@@ -1859,6 +1874,8 @@ impl Default for AccessConfig {
user_max_tcp_conns_global_each: default_user_max_tcp_conns_global_each(),
user_expirations: HashMap::new(),
user_data_quota: HashMap::new(),
user_rate_limits: HashMap::new(),
cidr_rate_limits: HashMap::new(),
user_max_unique_ips: HashMap::new(),
user_max_unique_ips_global_each: default_user_max_unique_ips_global_each(),
user_max_unique_ips_mode: UserMaxUniqueIpsMode::default(),
@@ -1870,6 +1887,14 @@ impl Default for AccessConfig {
}
}
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct RateLimitBps {
#[serde(default)]
pub up_bps: u64,
#[serde(default)]
pub down_bps: u64,
}
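A config fragment matching the fields above might look like the following. The `[access.*]` table names are inferred from the `AccessConfig` field names in this diff and are not confirmed by the source; the semantics (bits per second, `0` = unlimited in that direction) follow the doc comments.

```toml
# Illustrative fragment -- table names inferred from AccessConfig field names;
# values are bits per second, 0 = unlimited in that direction.
[access.user_rate_limits.alice]
up_bps = 10_000_000      # 10 Mbit/s upload ceiling
down_bps = 50_000_000    # 50 Mbit/s download ceiling

[access.cidr_rate_limits."203.0.113.0/24"]
up_bps = 0               # unlimited upload for this prefix
down_bps = 100_000_000   # aggregate 100 Mbit/s download ceiling
```

Note that the validation added in this change rejects entries where both directions are `0`, so every entry must set at least one non-zero ceiling.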
// ============= Aux Structures =============
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]

View File

@@ -8,6 +8,7 @@ use std::io::{self, Read, Write};
use std::os::unix::fs::OpenOptionsExt;
use std::path::{Path, PathBuf};
use nix::errno::Errno;
use nix::fcntl::{Flock, FlockArg};
use nix::unistd::{self, ForkResult, Gid, Pid, Uid, chdir, close, fork, getpid, setsid};
use tracing::{debug, info, warn};
@@ -157,15 +158,15 @@ fn redirect_stdio_to_devnull() -> Result<(), DaemonError> {
unsafe {
// Redirect stdin (fd 0)
if libc::dup2(devnull_fd, 0) < 0 {
return Err(DaemonError::RedirectFailed(nix::errno::Errno::last()));
return Err(DaemonError::RedirectFailed(Errno::last()));
}
// Redirect stdout (fd 1)
if libc::dup2(devnull_fd, 1) < 0 {
return Err(DaemonError::RedirectFailed(nix::errno::Errno::last()));
return Err(DaemonError::RedirectFailed(Errno::last()));
}
// Redirect stderr (fd 2)
if libc::dup2(devnull_fd, 2) < 0 {
return Err(DaemonError::RedirectFailed(nix::errno::Errno::last()));
return Err(DaemonError::RedirectFailed(Errno::last()));
}
}
@@ -337,6 +338,27 @@ fn is_process_running(pid: i32) -> bool {
nix::sys::signal::kill(Pid::from_raw(pid), None).is_ok()
}
// macOS gates nix::unistd::setgroups differently in the current dependency set,
// so call libc directly there while preserving the original nix path elsewhere.
fn set_supplementary_groups(gid: Gid) -> Result<(), nix::Error> {
#[cfg(target_os = "macos")]
{
let groups = [gid.as_raw()];
let rc = unsafe {
libc::setgroups(
i32::try_from(groups.len()).expect("single supplementary group must fit in c_int"),
groups.as_ptr(),
)
};
if rc == 0 { Ok(()) } else { Err(Errno::last()) }
}
#[cfg(not(target_os = "macos"))]
{
unistd::setgroups(&[gid])
}
}
/// Drops privileges to the specified user and group.
///
/// This should be called after binding privileged ports but before entering
@@ -368,7 +390,7 @@ pub fn drop_privileges(
if let Some(gid) = target_gid {
unistd::setgid(gid).map_err(DaemonError::PrivilegeDrop)?;
unistd::setgroups(&[gid]).map_err(DaemonError::PrivilegeDrop)?;
set_supplementary_groups(gid).map_err(DaemonError::PrivilegeDrop)?;
info!(gid = gid.as_raw(), "Dropped group privileges");
}

View File

@@ -664,6 +664,11 @@ async fn run_telemt_core(
));
let buffer_pool = Arc::new(BufferPool::with_config(64 * 1024, 4096));
let shared_state = ProxySharedState::new();
shared_state.traffic_limiter.apply_policy(
config.access.user_rate_limits.clone(),
config.access.cidr_rate_limits.clone(),
);
connectivity::run_startup_connectivity(
&config,
@@ -695,6 +700,7 @@ async fn run_telemt_core(
beobachten.clone(),
api_config_tx.clone(),
me_pool.clone(),
shared_state.clone(),
)
.await;
let config_rx = runtime_watches.config_rx;
@@ -711,7 +717,6 @@ async fn run_telemt_core(
)
.await;
let _admission_tx_hold = admission_tx;
let shared_state = ProxySharedState::new();
conntrack_control::spawn_conntrack_controller(
config_rx.clone(),
stats.clone(),

View File

@@ -51,6 +51,7 @@ pub(crate) async fn spawn_runtime_tasks(
beobachten: Arc<BeobachtenStore>,
api_config_tx: watch::Sender<Arc<ProxyConfig>>,
me_pool_for_policy: Option<Arc<MePool>>,
shared_state: Arc<ProxySharedState>,
) -> RuntimeWatches {
let um_clone = upstream_manager.clone();
let dc_overrides_for_health = config.dc_overrides.clone();
@@ -182,6 +183,41 @@ pub(crate) async fn spawn_runtime_tasks(
}
});
let limiter = shared_state.traffic_limiter.clone();
limiter.apply_policy(
config.access.user_rate_limits.clone(),
config.access.cidr_rate_limits.clone(),
);
let mut config_rx_rate_limits = config_rx.clone();
tokio::spawn(async move {
let mut prev_user_limits = config_rx_rate_limits
.borrow()
.access
.user_rate_limits
.clone();
let mut prev_cidr_limits = config_rx_rate_limits
.borrow()
.access
.cidr_rate_limits
.clone();
loop {
if config_rx_rate_limits.changed().await.is_err() {
break;
}
let cfg = config_rx_rate_limits.borrow_and_update().clone();
if prev_user_limits != cfg.access.user_rate_limits
|| prev_cidr_limits != cfg.access.cidr_rate_limits
{
limiter.apply_policy(
cfg.access.user_rate_limits.clone(),
cfg.access.cidr_rate_limits.clone(),
);
prev_user_limits = cfg.access.user_rate_limits.clone();
prev_cidr_limits = cfg.access.cidr_rate_limits.clone();
}
}
});
let beobachten_writer = beobachten.clone();
let config_rx_beobachten = config_rx.clone();
tokio::spawn(async move {

View File

@@ -575,6 +575,139 @@ async fn render_metrics(
}
);
let limiter_metrics = shared_state.traffic_limiter.metrics_snapshot();
let _ = writeln!(
out,
"# HELP telemt_rate_limiter_throttle_total Traffic limiter throttle events by scope and direction"
);
let _ = writeln!(out, "# TYPE telemt_rate_limiter_throttle_total counter");
let _ = writeln!(
out,
"telemt_rate_limiter_throttle_total{{scope=\"user\",direction=\"up\"}} {}",
if core_enabled {
limiter_metrics.user_throttle_up_total
} else {
0
}
);
let _ = writeln!(
out,
"telemt_rate_limiter_throttle_total{{scope=\"user\",direction=\"down\"}} {}",
if core_enabled {
limiter_metrics.user_throttle_down_total
} else {
0
}
);
let _ = writeln!(
out,
"telemt_rate_limiter_throttle_total{{scope=\"cidr\",direction=\"up\"}} {}",
if core_enabled {
limiter_metrics.cidr_throttle_up_total
} else {
0
}
);
let _ = writeln!(
out,
"telemt_rate_limiter_throttle_total{{scope=\"cidr\",direction=\"down\"}} {}",
if core_enabled {
limiter_metrics.cidr_throttle_down_total
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_rate_limiter_wait_ms_total Traffic limiter accumulated wait time in milliseconds by scope and direction"
);
let _ = writeln!(out, "# TYPE telemt_rate_limiter_wait_ms_total counter");
let _ = writeln!(
out,
"telemt_rate_limiter_wait_ms_total{{scope=\"user\",direction=\"up\"}} {}",
if core_enabled {
limiter_metrics.user_wait_up_ms_total
} else {
0
}
);
let _ = writeln!(
out,
"telemt_rate_limiter_wait_ms_total{{scope=\"user\",direction=\"down\"}} {}",
if core_enabled {
limiter_metrics.user_wait_down_ms_total
} else {
0
}
);
let _ = writeln!(
out,
"telemt_rate_limiter_wait_ms_total{{scope=\"cidr\",direction=\"up\"}} {}",
if core_enabled {
limiter_metrics.cidr_wait_up_ms_total
} else {
0
}
);
let _ = writeln!(
out,
"telemt_rate_limiter_wait_ms_total{{scope=\"cidr\",direction=\"down\"}} {}",
if core_enabled {
limiter_metrics.cidr_wait_down_ms_total
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_rate_limiter_active_leases Active relay leases under rate limiting by scope"
);
let _ = writeln!(out, "# TYPE telemt_rate_limiter_active_leases gauge");
let _ = writeln!(
out,
"telemt_rate_limiter_active_leases{{scope=\"user\"}} {}",
if core_enabled {
limiter_metrics.user_active_leases
} else {
0
}
);
let _ = writeln!(
out,
"telemt_rate_limiter_active_leases{{scope=\"cidr\"}} {}",
if core_enabled {
limiter_metrics.cidr_active_leases
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_rate_limiter_policy_entries Active rate-limit policy entries by scope"
);
let _ = writeln!(out, "# TYPE telemt_rate_limiter_policy_entries gauge");
let _ = writeln!(
out,
"telemt_rate_limiter_policy_entries{{scope=\"user\"}} {}",
if core_enabled {
limiter_metrics.user_policy_entries
} else {
0
}
);
let _ = writeln!(
out,
"telemt_rate_limiter_policy_entries{{scope=\"cidr\"}} {}",
if core_enabled {
limiter_metrics.cidr_policy_entries
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_upstream_connect_attempt_total Upstream connect attempts across all requests"
@@ -1177,6 +1310,143 @@ async fn render_metrics(
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_fair_pressure_state Worker-local fairness pressure state"
);
let _ = writeln!(out, "# TYPE telemt_me_fair_pressure_state gauge");
let _ = writeln!(
out,
"telemt_me_fair_pressure_state {}",
if me_allows_normal {
stats.get_me_fair_pressure_state_gauge()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_fair_active_flows Fair-scheduler active flow count"
);
let _ = writeln!(out, "# TYPE telemt_me_fair_active_flows gauge");
let _ = writeln!(
out,
"telemt_me_fair_active_flows {}",
if me_allows_normal {
stats.get_me_fair_active_flows_gauge()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_fair_queued_bytes Fair-scheduler queued bytes"
);
let _ = writeln!(out, "# TYPE telemt_me_fair_queued_bytes gauge");
let _ = writeln!(
out,
"telemt_me_fair_queued_bytes {}",
if me_allows_normal {
stats.get_me_fair_queued_bytes_gauge()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_fair_flow_state_gauge Fair-scheduler flow health classes"
);
let _ = writeln!(out, "# TYPE telemt_me_fair_flow_state_gauge gauge");
let _ = writeln!(
out,
"telemt_me_fair_flow_state_gauge{{class=\"standing\"}} {}",
if me_allows_normal {
stats.get_me_fair_standing_flows_gauge()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_fair_flow_state_gauge{{class=\"backpressured\"}} {}",
if me_allows_normal {
stats.get_me_fair_backpressured_flows_gauge()
} else {
0
}
);
let _ = writeln!(
out,
"# HELP telemt_me_fair_events_total Fair-scheduler event counters"
);
let _ = writeln!(out, "# TYPE telemt_me_fair_events_total counter");
let _ = writeln!(
out,
"telemt_me_fair_events_total{{event=\"scheduler_round\"}} {}",
if me_allows_normal {
stats.get_me_fair_scheduler_rounds_total()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_fair_events_total{{event=\"deficit_grant\"}} {}",
if me_allows_normal {
stats.get_me_fair_deficit_grants_total()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_fair_events_total{{event=\"deficit_skip\"}} {}",
if me_allows_normal {
stats.get_me_fair_deficit_skips_total()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_fair_events_total{{event=\"enqueue_reject\"}} {}",
if me_allows_normal {
stats.get_me_fair_enqueue_rejects_total()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_fair_events_total{{event=\"shed_drop\"}} {}",
if me_allows_normal {
stats.get_me_fair_shed_drops_total()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_fair_events_total{{event=\"penalty\"}} {}",
if me_allows_normal {
stats.get_me_fair_penalties_total()
} else {
0
}
);
let _ = writeln!(
out,
"telemt_me_fair_events_total{{event=\"downstream_stall\"}} {}",
if me_allows_normal {
stats.get_me_fair_downstream_stalls_total()
} else {
0
}
);
let _ = writeln!(
out,

View File

@@ -316,6 +316,9 @@ where
stats.increment_user_connects(user);
let _direct_connection_lease = stats.acquire_direct_connection_lease();
let traffic_lease = shared
.traffic_limiter
.acquire_lease(user, success.peer.ip());
let buffer_pool_trim = Arc::clone(&buffer_pool);
let relay_activity_timeout = if shared.conntrack_pressure_active() {
@@ -329,7 +332,7 @@ where
} else {
Duration::from_secs(1800)
};
let relay_result = crate::proxy::relay::relay_bidirectional_with_activity_timeout(
let relay_result = crate::proxy::relay::relay_bidirectional_with_activity_timeout_and_lease(
client_reader,
client_writer,
tg_reader,
@@ -340,6 +343,7 @@ where
Arc::clone(&stats),
config.access.user_data_quota.get(user).copied(),
buffer_pool,
traffic_lease,
relay_activity_timeout,
);
tokio::pin!(relay_result);

View File

@@ -28,6 +28,7 @@ use crate::proxy::route_mode::{
use crate::proxy::shared_state::{
ConntrackCloseEvent, ConntrackClosePublishResult, ConntrackCloseReason, ProxySharedState,
};
use crate::proxy::traffic_limiter::{RateDirection, TrafficLease, next_refill_delay};
use crate::stats::{
MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, QuotaReserveError, Stats, UserStats,
};
@@ -286,6 +287,10 @@ impl RelayClientIdleState {
self.last_client_frame_at = now;
self.soft_idle_marked = false;
}
fn on_client_tiny_frame(&mut self, now: Instant) {
self.last_client_frame_at = now;
}
}
impl MeD2cFlushPolicy {
@@ -595,6 +600,41 @@ async fn reserve_user_quota_with_yield(
}
}
async fn wait_for_traffic_budget(
lease: Option<&Arc<TrafficLease>>,
direction: RateDirection,
bytes: u64,
) {
if bytes == 0 {
return;
}
let Some(lease) = lease else {
return;
};
let mut remaining = bytes;
while remaining > 0 {
let consume = lease.try_consume(direction, remaining);
if consume.granted > 0 {
remaining = remaining.saturating_sub(consume.granted);
continue;
}
let wait_started_at = Instant::now();
tokio::time::sleep(next_refill_delay()).await;
let wait_ms = wait_started_at
.elapsed()
.as_millis()
.min(u128::from(u64::MAX)) as u64;
lease.observe_wait_ms(
direction,
consume.blocked_user,
consume.blocked_cidr,
wait_ms,
);
}
}
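The consume-until-granted loop in `wait_for_traffic_budget` above can be sketched synchronously; here `EpochBucket` is a stand-in for `TrafficLease`, and `refill()` stands in for `sleep(next_refill_delay())`, so treat this as an illustrative model of the control flow, not the real types:

```rust
// Minimal synchronous sketch of the consume-until-granted loop used by
// `wait_for_traffic_budget`: keep asking the bucket for the remaining
// bytes and only "sleep" (here: refill) when nothing was granted.
struct EpochBucket {
    cap: u64,  // bytes allowed per refill epoch
    used: u64, // bytes already granted this epoch
}

impl EpochBucket {
    fn try_consume(&mut self, requested: u64) -> u64 {
        let grant = requested.min(self.cap.saturating_sub(self.used));
        self.used += grant;
        grant
    }

    fn refill(&mut self) {
        self.used = 0; // a new epoch starts with a fresh budget
    }
}

/// Returns how many refill epochs the caller had to wait for `bytes`.
fn drain(bucket: &mut EpochBucket, bytes: u64) -> u64 {
    let mut remaining = bytes;
    let mut waits = 0;
    while remaining > 0 {
        let granted = bucket.try_consume(remaining);
        if granted > 0 {
            remaining -= granted;
            continue;
        }
        bucket.refill(); // stands in for `sleep(next_refill_delay()).await`
        waits += 1;
    }
    waits
}

fn main() {
    let mut bucket = EpochBucket { cap: 100, used: 0 };
    // 250 bytes against a 100-byte epoch: two refills needed.
    assert_eq!(drain(&mut bucket, 250), 2);
    println!("ok");
}
```

The key property the real loop shares with this sketch: partial grants make progress immediately, and the task only yields when a `try_consume` round returns zero.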
fn classify_me_d2c_flush_reason(
flush_immediately: bool,
batch_frames: usize,
@@ -985,6 +1025,7 @@ where
let quota_limit = config.access.user_data_quota.get(&user).copied();
let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user));
let peer = success.peer;
let traffic_lease = shared.traffic_limiter.acquire_lease(&user, peer.ip());
let proto_tag = success.proto_tag;
let pool_generation = me_pool.current_generation();
@@ -1120,6 +1161,7 @@ where
let rng_clone = rng.clone();
let user_clone = user.clone();
let quota_user_stats_me_writer = quota_user_stats.clone();
let traffic_lease_me_writer = traffic_lease.clone();
let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone();
let bytes_me2c_clone = bytes_me2c.clone();
let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config);
@@ -1153,7 +1195,7 @@ where
let first_is_downstream_activity =
matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_));
match process_me_writer_response(
match process_me_writer_response_with_traffic_lease(
first,
&mut writer,
proto_tag,
@@ -1164,6 +1206,7 @@ where
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
traffic_lease_me_writer.as_ref(),
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@@ -1213,7 +1256,7 @@ where
let next_is_downstream_activity =
matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
match process_me_writer_response(
match process_me_writer_response_with_traffic_lease(
next,
&mut writer,
proto_tag,
@@ -1224,6 +1267,7 @@ where
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
traffic_lease_me_writer.as_ref(),
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@@ -1276,7 +1320,7 @@ where
Ok(Some(next)) => {
let next_is_downstream_activity =
matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
match process_me_writer_response(
match process_me_writer_response_with_traffic_lease(
next,
&mut writer,
proto_tag,
@@ -1287,6 +1331,7 @@ where
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
traffic_lease_me_writer.as_ref(),
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@@ -1341,7 +1386,7 @@ where
let extra_is_downstream_activity =
matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_));
match process_me_writer_response(
match process_me_writer_response_with_traffic_lease(
extra,
&mut writer,
proto_tag,
@@ -1352,6 +1397,7 @@ where
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
traffic_lease_me_writer.as_ref(),
bytes_me2c_clone.as_ref(),
conn_id,
d2c_flush_policy.ack_flush_immediate,
@@ -1542,6 +1588,12 @@ where
match payload_result {
Ok(Some((payload, quickack))) => {
trace!(conn_id, bytes = payload.len(), "C->ME frame");
wait_for_traffic_budget(
traffic_lease.as_ref(),
RateDirection::Up,
payload.len() as u64,
)
.await;
forensics.bytes_c2me = forensics
.bytes_c2me
.saturating_add(payload.len() as u64);
@@ -1762,40 +1814,6 @@ where
let downstream_ms = last_downstream_activity_ms.load(Ordering::Relaxed);
let hard_deadline =
hard_deadline(idle_policy, idle_state, session_started_at, downstream_ms);
if now >= hard_deadline {
clear_relay_idle_candidate_in(shared, forensics.conn_id);
stats.increment_relay_idle_hard_close_total();
let client_idle_secs = now
.saturating_duration_since(idle_state.last_client_frame_at)
.as_secs();
let downstream_idle_secs = now
.saturating_duration_since(
session_started_at + Duration::from_millis(downstream_ms),
)
.as_secs();
warn!(
trace_id = format_args!("0x{:016x}", forensics.trace_id),
conn_id = forensics.conn_id,
user = %forensics.user,
read_label,
client_idle_secs,
downstream_idle_secs,
soft_idle_secs = idle_policy.soft_idle.as_secs(),
hard_idle_secs = idle_policy.hard_idle.as_secs(),
grace_secs = idle_policy.grace_after_downstream_activity.as_secs(),
"Middle-relay hard idle close"
);
return Err(ProxyError::Io(std::io::Error::new(
std::io::ErrorKind::TimedOut,
format!(
"middle-relay hard idle timeout while reading {read_label}: client_idle_secs={client_idle_secs}, downstream_idle_secs={downstream_idle_secs}, soft_idle_secs={}, hard_idle_secs={}, grace_secs={}",
idle_policy.soft_idle.as_secs(),
idle_policy.hard_idle.as_secs(),
idle_policy.grace_after_downstream_activity.as_secs(),
),
)));
}
if !idle_state.soft_idle_marked
&& now.saturating_duration_since(idle_state.last_client_frame_at)
>= idle_policy.soft_idle
@@ -1850,7 +1868,45 @@ where
),
)));
}
Err(_) => {}
Err(_) => {
let now = Instant::now();
let downstream_ms = last_downstream_activity_ms.load(Ordering::Relaxed);
let hard_deadline =
hard_deadline(idle_policy, idle_state, session_started_at, downstream_ms);
if now >= hard_deadline {
clear_relay_idle_candidate_in(shared, forensics.conn_id);
stats.increment_relay_idle_hard_close_total();
let client_idle_secs = now
.saturating_duration_since(idle_state.last_client_frame_at)
.as_secs();
let downstream_idle_secs = now
.saturating_duration_since(
session_started_at + Duration::from_millis(downstream_ms),
)
.as_secs();
warn!(
trace_id = format_args!("0x{:016x}", forensics.trace_id),
conn_id = forensics.conn_id,
user = %forensics.user,
read_label,
client_idle_secs,
downstream_idle_secs,
soft_idle_secs = idle_policy.soft_idle.as_secs(),
hard_idle_secs = idle_policy.hard_idle.as_secs(),
grace_secs = idle_policy.grace_after_downstream_activity.as_secs(),
"Middle-relay hard idle close"
);
return Err(ProxyError::Io(std::io::Error::new(
std::io::ErrorKind::TimedOut,
format!(
"middle-relay hard idle timeout while reading {read_label}: client_idle_secs={client_idle_secs}, downstream_idle_secs={downstream_idle_secs}, soft_idle_secs={}, hard_idle_secs={}, grace_secs={}",
idle_policy.soft_idle.as_secs(),
idle_policy.hard_idle.as_secs(),
idle_policy.grace_after_downstream_activity.as_secs(),
),
)));
}
}
}
}
@@ -1941,6 +1997,7 @@ where
};
if len == 0 {
idle_state.on_client_tiny_frame(Instant::now());
idle_state.tiny_frame_debt = idle_state
.tiny_frame_debt
.saturating_add(TINY_FRAME_DEBT_PER_TINY);
@@ -2160,6 +2217,46 @@ async fn process_me_writer_response<W>(
ack_flush_immediate: bool,
batched: bool,
) -> Result<MeWriterResponseOutcome>
where
W: AsyncWrite + Unpin + Send + 'static,
{
process_me_writer_response_with_traffic_lease(
response,
client_writer,
proto_tag,
rng,
frame_buf,
stats,
user,
quota_user_stats,
quota_limit,
quota_soft_overshoot_bytes,
None,
bytes_me2c,
conn_id,
ack_flush_immediate,
batched,
)
.await
}
async fn process_me_writer_response_with_traffic_lease<W>(
response: MeResponse,
client_writer: &mut CryptoWriter<W>,
proto_tag: ProtoTag,
rng: &SecureRandom,
frame_buf: &mut Vec<u8>,
stats: &Stats,
user: &str,
quota_user_stats: Option<&UserStats>,
quota_limit: Option<u64>,
quota_soft_overshoot_bytes: u64,
traffic_lease: Option<&Arc<TrafficLease>>,
bytes_me2c: &AtomicU64,
conn_id: u64,
ack_flush_immediate: bool,
batched: bool,
) -> Result<MeWriterResponseOutcome>
where
W: AsyncWrite + Unpin + Send + 'static,
{
@@ -2183,6 +2280,7 @@ where
});
}
}
wait_for_traffic_budget(traffic_lease, RateDirection::Down, data_len).await;
let write_mode =
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
@@ -2220,6 +2318,7 @@ where
} else {
trace!(conn_id, confirm, "ME->C quickack");
}
wait_for_traffic_budget(traffic_lease, RateDirection::Down, 4).await;
write_client_ack(client_writer, proto_tag, confirm).await?;
stats.increment_me_d2c_ack_frames_total();

View File

@@ -68,6 +68,7 @@ pub mod relay;
pub mod route_mode;
pub mod session_eviction;
pub mod shared_state;
pub mod traffic_limiter;
pub use client::ClientHandler;
#[allow(unused_imports)]

View File

@@ -52,6 +52,7 @@
//! - `SharedCounters` (atomics) let the watchdog read stats without locking
use crate::error::{ProxyError, Result};
use crate::proxy::traffic_limiter::{RateDirection, TrafficLease, next_refill_delay};
use crate::stats::{Stats, UserStats};
use crate::stream::BufferPool;
use std::io;
@@ -61,7 +62,7 @@ use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes};
use tokio::time::Instant;
use tokio::time::{Instant, Sleep};
use tracing::{debug, trace, warn};
// ============= Constants =============
@@ -210,12 +211,24 @@ struct StatsIo<S> {
stats: Arc<Stats>,
user: String,
user_stats: Arc<UserStats>,
traffic_lease: Option<Arc<TrafficLease>>,
c2s_rate_debt_bytes: u64,
c2s_wait: RateWaitState,
s2c_wait: RateWaitState,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
quota_bytes_since_check: u64,
epoch: Instant,
}
#[derive(Default)]
struct RateWaitState {
sleep: Option<Pin<Box<Sleep>>>,
started_at: Option<Instant>,
blocked_user: bool,
blocked_cidr: bool,
}
impl<S> StatsIo<S> {
fn new(
inner: S,
@@ -225,6 +238,28 @@ impl<S> StatsIo<S> {
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
epoch: Instant,
) -> Self {
Self::new_with_traffic_lease(
inner,
counters,
stats,
user,
None,
quota_limit,
quota_exceeded,
epoch,
)
}
fn new_with_traffic_lease(
inner: S,
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
traffic_lease: Option<Arc<TrafficLease>>,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
epoch: Instant,
) -> Self {
// Mark initial activity so the watchdog doesn't fire before data flows
counters.touch(Instant::now(), epoch);
@@ -235,12 +270,88 @@ impl<S> StatsIo<S> {
stats,
user,
user_stats,
traffic_lease,
c2s_rate_debt_bytes: 0,
c2s_wait: RateWaitState::default(),
s2c_wait: RateWaitState::default(),
quota_limit,
quota_exceeded,
quota_bytes_since_check: 0,
epoch,
}
}
fn record_wait(
wait: &mut RateWaitState,
lease: Option<&Arc<TrafficLease>>,
direction: RateDirection,
) {
let Some(started_at) = wait.started_at.take() else {
return;
};
let wait_ms = started_at.elapsed().as_millis().min(u128::from(u64::MAX)) as u64;
if let Some(lease) = lease {
lease.observe_wait_ms(direction, wait.blocked_user, wait.blocked_cidr, wait_ms);
}
wait.blocked_user = false;
wait.blocked_cidr = false;
}
fn arm_wait(wait: &mut RateWaitState, blocked_user: bool, blocked_cidr: bool) {
if wait.sleep.is_none() {
wait.sleep = Some(Box::pin(tokio::time::sleep(next_refill_delay())));
wait.started_at = Some(Instant::now());
}
wait.blocked_user |= blocked_user;
wait.blocked_cidr |= blocked_cidr;
}
fn poll_wait(
wait: &mut RateWaitState,
cx: &mut Context<'_>,
lease: Option<&Arc<TrafficLease>>,
direction: RateDirection,
) -> Poll<()> {
let Some(sleep) = wait.sleep.as_mut() else {
return Poll::Ready(());
};
if sleep.as_mut().poll(cx).is_pending() {
return Poll::Pending;
}
wait.sleep = None;
Self::record_wait(wait, lease, direction);
Poll::Ready(())
}
fn settle_c2s_rate_debt(&mut self, cx: &mut Context<'_>) -> Poll<()> {
let Some(lease) = self.traffic_lease.as_ref() else {
self.c2s_rate_debt_bytes = 0;
return Poll::Ready(());
};
while self.c2s_rate_debt_bytes > 0 {
let consume = lease.try_consume(RateDirection::Up, self.c2s_rate_debt_bytes);
if consume.granted > 0 {
self.c2s_rate_debt_bytes = self.c2s_rate_debt_bytes.saturating_sub(consume.granted);
continue;
}
Self::arm_wait(
&mut self.c2s_wait,
consume.blocked_user,
consume.blocked_cidr,
);
if Self::poll_wait(&mut self.c2s_wait, cx, Some(lease), RateDirection::Up).is_pending()
{
return Poll::Pending;
}
}
if Self::poll_wait(&mut self.c2s_wait, cx, Some(lease), RateDirection::Up).is_pending() {
return Poll::Pending;
}
Poll::Ready(())
}
}
#[derive(Debug)]
@@ -286,6 +397,25 @@ fn should_immediate_quota_check(remaining_before: u64, charge_bytes: u64) -> boo
remaining_before <= QUOTA_NEAR_LIMIT_BYTES || charge_bytes >= QUOTA_LARGE_CHARGE_BYTES
}
fn refund_reserved_quota_bytes(user_stats: &UserStats, reserved_bytes: u64) {
if reserved_bytes == 0 {
return;
}
let mut current = user_stats.quota_used.load(Ordering::Relaxed);
loop {
let next = current.saturating_sub(reserved_bytes);
match user_stats.quota_used.compare_exchange_weak(
current,
next,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => return,
Err(observed) => current = observed,
}
}
}
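The CAS loop that `refund_reserved_quota_bytes` factors out can be exercised in isolation; the atomic below stands in for `UserStats::quota_used`:

```rust
use std::sync::atomic::{AtomicU64, Ordering};

// Standalone sketch of the CAS loop behind `refund_reserved_quota_bytes`:
// subtract with saturation so a concurrent reset toward zero can never
// underflow the counter.
fn refund_saturating(counter: &AtomicU64, refund: u64) {
    if refund == 0 {
        return;
    }
    let mut current = counter.load(Ordering::Relaxed);
    loop {
        let next = current.saturating_sub(refund);
        match counter.compare_exchange_weak(current, next, Ordering::Relaxed, Ordering::Relaxed) {
            Ok(_) => return,
            Err(observed) => current = observed,
        }
    }
}

fn main() {
    let used = AtomicU64::new(300);
    refund_saturating(&used, 100);
    assert_eq!(used.load(Ordering::Relaxed), 200);
    // Refunding more than is outstanding clamps at zero instead of wrapping.
    refund_saturating(&used, 1_000);
    assert_eq!(used.load(Ordering::Relaxed), 0);
    println!("ok");
}
```

Using `compare_exchange_weak` in a retry loop is the idiomatic choice here: spurious failures just re-read the observed value and try again.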
impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
fn poll_read(
self: Pin<&mut Self>,
@@ -296,6 +426,9 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
if this.quota_exceeded.load(Ordering::Acquire) {
return Poll::Ready(Err(quota_io_error()));
}
if this.settle_c2s_rate_debt(cx).is_pending() {
return Poll::Pending;
}
let mut remaining_before = None;
if let Some(limit) = this.quota_limit {
@@ -377,6 +510,11 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
.add_user_octets_from_handle(this.user_stats.as_ref(), n_to_charge);
this.stats
.increment_user_msgs_from_handle(this.user_stats.as_ref());
if this.traffic_lease.is_some() {
this.c2s_rate_debt_bytes =
this.c2s_rate_debt_bytes.saturating_add(n_to_charge);
let _ = this.settle_c2s_rate_debt(cx);
}
trace!(user = %this.user, bytes = n, "C->S");
}
@@ -398,28 +536,66 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
return Poll::Ready(Err(quota_io_error()));
}
let mut shaper_reserved_bytes = 0u64;
let mut write_buf = buf;
if let Some(lease) = this.traffic_lease.as_ref() {
if !buf.is_empty() {
loop {
let consume = lease.try_consume(RateDirection::Down, buf.len() as u64);
if consume.granted > 0 {
shaper_reserved_bytes = consume.granted;
if consume.granted < buf.len() as u64 {
write_buf = &buf[..consume.granted as usize];
}
let _ = Self::poll_wait(
&mut this.s2c_wait,
cx,
Some(lease),
RateDirection::Down,
);
break;
}
Self::arm_wait(
&mut this.s2c_wait,
consume.blocked_user,
consume.blocked_cidr,
);
if Self::poll_wait(&mut this.s2c_wait, cx, Some(lease), RateDirection::Down)
.is_pending()
{
return Poll::Pending;
}
}
} else {
let _ = Self::poll_wait(&mut this.s2c_wait, cx, Some(lease), RateDirection::Down);
}
}
let mut remaining_before = None;
let mut reserved_bytes = 0u64;
let mut write_buf = buf;
if let Some(limit) = this.quota_limit {
if !buf.is_empty() {
if !write_buf.is_empty() {
let mut reserve_rounds = 0usize;
while reserved_bytes == 0 {
let used_before = this.user_stats.quota_used();
let remaining = limit.saturating_sub(used_before);
if remaining == 0 {
if let Some(lease) = this.traffic_lease.as_ref() {
lease.refund(RateDirection::Down, shaper_reserved_bytes);
}
this.quota_exceeded.store(true, Ordering::Release);
return Poll::Ready(Err(quota_io_error()));
}
remaining_before = Some(remaining);
let desired = remaining.min(buf.len() as u64);
let desired = remaining.min(write_buf.len() as u64);
let mut saw_contention = false;
for _ in 0..QUOTA_RESERVE_SPIN_RETRIES {
match this.user_stats.quota_try_reserve(desired, limit) {
Ok(_) => {
reserved_bytes = desired;
write_buf = &buf[..desired as usize];
write_buf = &write_buf[..desired as usize];
break;
}
Err(crate::stats::QuotaReserveError::LimitExceeded) => {
@@ -434,6 +610,9 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
if reserved_bytes == 0 {
reserve_rounds = reserve_rounds.saturating_add(1);
if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS {
if let Some(lease) = this.traffic_lease.as_ref() {
lease.refund(RateDirection::Down, shaper_reserved_bytes);
}
this.quota_exceeded.store(true, Ordering::Release);
return Poll::Ready(Err(quota_io_error()));
}
@@ -446,6 +625,9 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
let used_before = this.user_stats.quota_used();
let remaining = limit.saturating_sub(used_before);
if remaining == 0 {
if let Some(lease) = this.traffic_lease.as_ref() {
lease.refund(RateDirection::Down, shaper_reserved_bytes);
}
this.quota_exceeded.store(true, Ordering::Release);
return Poll::Ready(Err(quota_io_error()));
}
@@ -456,23 +638,20 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
match Pin::new(&mut this.inner).poll_write(cx, write_buf) {
Poll::Ready(Ok(n)) => {
if reserved_bytes > n as u64 {
let refund = reserved_bytes - n as u64;
let mut current = this.user_stats.quota_used.load(Ordering::Relaxed);
loop {
let next = current.saturating_sub(refund);
match this.user_stats.quota_used.compare_exchange_weak(
current,
next,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(observed) => current = observed,
refund_reserved_quota_bytes(
this.user_stats.as_ref(),
reserved_bytes - n as u64,
);
}
if shaper_reserved_bytes > n as u64
&& let Some(lease) = this.traffic_lease.as_ref()
{
lease.refund(RateDirection::Down, shaper_reserved_bytes - n as u64);
}
}
if n > 0 {
if let Some(lease) = this.traffic_lease.as_ref() {
Self::record_wait(&mut this.s2c_wait, Some(lease), RateDirection::Down);
}
let n_to_charge = n as u64;
// S→C: data written to client
@@ -512,37 +691,23 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
}
Poll::Ready(Err(err)) => {
if reserved_bytes > 0 {
let mut current = this.user_stats.quota_used.load(Ordering::Relaxed);
loop {
let next = current.saturating_sub(reserved_bytes);
match this.user_stats.quota_used.compare_exchange_weak(
current,
next,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(observed) => current = observed,
}
refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes);
}
if shaper_reserved_bytes > 0
&& let Some(lease) = this.traffic_lease.as_ref()
{
lease.refund(RateDirection::Down, shaper_reserved_bytes);
}
Poll::Ready(Err(err))
}
Poll::Pending => {
if reserved_bytes > 0 {
let mut current = this.user_stats.quota_used.load(Ordering::Relaxed);
loop {
let next = current.saturating_sub(reserved_bytes);
match this.user_stats.quota_used.compare_exchange_weak(
current,
next,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(observed) => current = observed,
}
refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes);
}
if shaper_reserved_bytes > 0
&& let Some(lease) = this.traffic_lease.as_ref()
{
lease.refund(RateDirection::Down, shaper_reserved_bytes);
}
Poll::Pending
}
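Across the `poll_write` arms above, the invariant is the same: both the quota reservation and the shaper grant are sized before the write, so whatever the socket did not accept must be handed back. A hypothetical `refunds` helper (not present in the source) makes the arithmetic explicit:

```rust
// Sketch of the short-write refund rule used around `poll_write`: returns
// (quota_bytes_to_refund, shaper_bytes_to_refund) for a write that accepted
// only `written` bytes out of the pre-reserved amounts.
fn refunds(reserved_quota: u64, shaper_granted: u64, written: u64) -> (u64, u64) {
    (
        reserved_quota.saturating_sub(written),
        shaper_granted.saturating_sub(written),
    )
}

fn main() {
    // Socket accepted only 1 KiB of a 4 KiB reservation.
    assert_eq!(refunds(4_096, 4_096, 1_024), (3_072, 3_072));
    // Full write: nothing to refund.
    assert_eq!(refunds(4_096, 4_096, 4_096), (0, 0));
    // Err/Pending arms behave like written == 0: everything goes back.
    assert_eq!(refunds(4_096, 4_096, 0), (4_096, 4_096));
    println!("ok");
}
```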
@@ -627,6 +792,43 @@ pub async fn relay_bidirectional_with_activity_timeout<CR, CW, SR, SW>(
_buffer_pool: Arc<BufferPool>,
activity_timeout: Duration,
) -> Result<()>
where
CR: AsyncRead + Unpin + Send + 'static,
CW: AsyncWrite + Unpin + Send + 'static,
SR: AsyncRead + Unpin + Send + 'static,
SW: AsyncWrite + Unpin + Send + 'static,
{
relay_bidirectional_with_activity_timeout_and_lease(
client_reader,
client_writer,
server_reader,
server_writer,
c2s_buf_size,
s2c_buf_size,
user,
stats,
quota_limit,
_buffer_pool,
None,
activity_timeout,
)
.await
}
pub async fn relay_bidirectional_with_activity_timeout_and_lease<CR, CW, SR, SW>(
client_reader: CR,
client_writer: CW,
server_reader: SR,
server_writer: SW,
c2s_buf_size: usize,
s2c_buf_size: usize,
user: &str,
stats: Arc<Stats>,
quota_limit: Option<u64>,
_buffer_pool: Arc<BufferPool>,
traffic_lease: Option<Arc<TrafficLease>>,
activity_timeout: Duration,
) -> Result<()>
where
CR: AsyncRead + Unpin + Send + 'static,
CW: AsyncWrite + Unpin + Send + 'static,
@@ -644,11 +846,12 @@ where
let mut server = CombinedStream::new(server_reader, server_writer);
// Wrap client with stats/activity tracking
let mut client = StatsIo::new(
let mut client = StatsIo::new_with_traffic_lease(
client_combined,
Arc::clone(&counters),
Arc::clone(&stats),
user_owned.clone(),
traffic_lease,
quota_limit,
Arc::clone(&quota_exceeded),
epoch,

View File

@@ -10,6 +10,7 @@ use tokio::sync::mpsc;
use crate::proxy::handshake::{AuthProbeSaturationState, AuthProbeState};
use crate::proxy::middle_relay::{DesyncDedupRotationState, RelayIdleCandidateRegistry};
use crate::proxy::traffic_limiter::TrafficLimiter;
const HANDSHAKE_RECENT_USER_RING_LEN: usize = 64;
@@ -65,6 +66,7 @@ pub(crate) struct MiddleRelaySharedState {
pub(crate) struct ProxySharedState {
pub(crate) handshake: HandshakeSharedState,
pub(crate) middle_relay: MiddleRelaySharedState,
pub(crate) traffic_limiter: Arc<TrafficLimiter>,
pub(crate) conntrack_pressure_active: AtomicBool,
pub(crate) conntrack_close_tx: Mutex<Option<mpsc::Sender<ConntrackCloseEvent>>>,
}
@@ -98,6 +100,7 @@ impl ProxySharedState {
relay_idle_registry: Mutex::new(RelayIdleCandidateRegistry::default()),
relay_idle_mark_seq: AtomicU64::new(0),
},
traffic_limiter: TrafficLimiter::new(),
conntrack_pressure_active: AtomicBool::new(false),
conntrack_close_tx: Mutex::new(None),
})

View File

@@ -0,0 +1,853 @@
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::net::IpAddr;
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use arc_swap::ArcSwap;
use dashmap::DashMap;
use ipnetwork::IpNetwork;
use crate::config::RateLimitBps;
const REGISTRY_SHARDS: usize = 64;
const FAIR_EPOCH_MS: u64 = 20;
const MAX_BORROW_CHUNK_BYTES: u64 = 32 * 1024;
const CLEANUP_INTERVAL_SECS: u64 = 60;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RateDirection {
Up,
Down,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TrafficConsumeResult {
pub granted: u64,
pub blocked_user: bool,
pub blocked_cidr: bool,
}
#[derive(Debug, Clone, Copy)]
pub struct TrafficLimiterMetricsSnapshot {
pub user_throttle_up_total: u64,
pub user_throttle_down_total: u64,
pub cidr_throttle_up_total: u64,
pub cidr_throttle_down_total: u64,
pub user_wait_up_ms_total: u64,
pub user_wait_down_ms_total: u64,
pub cidr_wait_up_ms_total: u64,
pub cidr_wait_down_ms_total: u64,
pub user_active_leases: u64,
pub cidr_active_leases: u64,
pub user_policy_entries: u64,
pub cidr_policy_entries: u64,
}
#[derive(Default)]
struct ScopeMetrics {
throttle_up_total: AtomicU64,
throttle_down_total: AtomicU64,
wait_up_ms_total: AtomicU64,
wait_down_ms_total: AtomicU64,
active_leases: AtomicU64,
policy_entries: AtomicU64,
}
impl ScopeMetrics {
fn throttle(&self, direction: RateDirection) {
match direction {
RateDirection::Up => {
self.throttle_up_total.fetch_add(1, Ordering::Relaxed);
}
RateDirection::Down => {
self.throttle_down_total.fetch_add(1, Ordering::Relaxed);
}
}
}
fn wait_ms(&self, direction: RateDirection, wait_ms: u64) {
match direction {
RateDirection::Up => {
self.wait_up_ms_total.fetch_add(wait_ms, Ordering::Relaxed);
}
RateDirection::Down => {
self.wait_down_ms_total
.fetch_add(wait_ms, Ordering::Relaxed);
}
}
}
}
#[derive(Default)]
struct AtomicRatePair {
up_bps: AtomicU64,
down_bps: AtomicU64,
}
impl AtomicRatePair {
fn set(&self, limits: RateLimitBps) {
self.up_bps.store(limits.up_bps, Ordering::Relaxed);
self.down_bps.store(limits.down_bps, Ordering::Relaxed);
}
fn get(&self, direction: RateDirection) -> u64 {
match direction {
RateDirection::Up => self.up_bps.load(Ordering::Relaxed),
RateDirection::Down => self.down_bps.load(Ordering::Relaxed),
}
}
}
#[derive(Default)]
struct DirectionBucket {
epoch: AtomicU64,
used: AtomicU64,
}
impl DirectionBucket {
fn sync_epoch(&self, epoch: u64) {
let current = self.epoch.load(Ordering::Relaxed);
if current == epoch {
return;
}
if current < epoch
&& self
.epoch
.compare_exchange(current, epoch, Ordering::Relaxed, Ordering::Relaxed)
.is_ok()
{
self.used.store(0, Ordering::Relaxed);
}
}
fn try_consume(&self, cap_bps: u64, requested: u64) -> u64 {
if requested == 0 {
return 0;
}
if cap_bps == 0 {
return requested;
}
let epoch = current_epoch();
self.sync_epoch(epoch);
let cap_epoch = bytes_per_epoch(cap_bps);
loop {
let used = self.used.load(Ordering::Relaxed);
if used >= cap_epoch {
return 0;
}
let remaining = cap_epoch.saturating_sub(used);
let grant = requested.min(remaining);
if grant == 0 {
return 0;
}
let next = used.saturating_add(grant);
if self
.used
.compare_exchange_weak(used, next, Ordering::Relaxed, Ordering::Relaxed)
.is_ok()
{
return grant;
}
}
}
fn refund(&self, bytes: u64) {
if bytes == 0 {
return;
}
decrement_atomic_saturating(&self.used, bytes);
}
}
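`DirectionBucket::try_consume` caps each epoch at `bytes_per_epoch(cap_bps)`, a helper whose body is not shown in this hunk. A hypothetical reconstruction from `FAIR_EPOCH_MS = 20`, assuming `cap_bps` is bytes per second, looks like this (treat the rounding as illustrative only):

```rust
// Hypothetical reconstruction of `bytes_per_epoch`: 20 ms epochs mean 50
// refills per second; grant at least 1 byte so a tiny cap still progresses.
const FAIR_EPOCH_MS: u64 = 20;

fn bytes_per_epoch(cap_bps: u64) -> u64 {
    (cap_bps * FAIR_EPOCH_MS / 1000).max(1)
}

fn main() {
    // 1 MiB/s cap -> ~20 KiB usable per 20 ms epoch.
    assert_eq!(bytes_per_epoch(1024 * 1024), 20_971);
    // A cap below 50 B/s would round to zero without the floor.
    assert_eq!(bytes_per_epoch(10), 1);
    println!("ok");
}
```

Short epochs keep latency under throttling low (a blocked flow waits at most one epoch for a refill) at the cost of slightly coarser rate accuracy for very small caps.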
struct UserBucket {
rates: AtomicRatePair,
up: DirectionBucket,
down: DirectionBucket,
active_leases: AtomicU64,
}
impl UserBucket {
fn new(limits: RateLimitBps) -> Self {
let rates = AtomicRatePair::default();
rates.set(limits);
Self {
rates,
up: DirectionBucket::default(),
down: DirectionBucket::default(),
active_leases: AtomicU64::new(0),
}
}
fn set_rates(&self, limits: RateLimitBps) {
self.rates.set(limits);
}
fn try_consume(&self, direction: RateDirection, requested: u64) -> u64 {
let cap_bps = self.rates.get(direction);
match direction {
RateDirection::Up => self.up.try_consume(cap_bps, requested),
RateDirection::Down => self.down.try_consume(cap_bps, requested),
}
}
fn refund(&self, direction: RateDirection, bytes: u64) {
match direction {
RateDirection::Up => self.up.refund(bytes),
RateDirection::Down => self.down.refund(bytes),
}
}
}
#[derive(Default)]
struct CidrDirectionBucket {
epoch: AtomicU64,
used: AtomicU64,
active_users: AtomicU64,
}
impl CidrDirectionBucket {
fn sync_epoch(&self, epoch: u64) {
let current = self.epoch.load(Ordering::Relaxed);
if current == epoch {
return;
}
if current < epoch
&& self
.epoch
.compare_exchange(current, epoch, Ordering::Relaxed, Ordering::Relaxed)
.is_ok()
{
self.used.store(0, Ordering::Relaxed);
self.active_users.store(0, Ordering::Relaxed);
}
}
fn try_consume(
&self,
user_state: &CidrUserDirectionState,
cap_epoch: u64,
requested: u64,
) -> u64 {
if requested == 0 || cap_epoch == 0 {
return 0;
}
let epoch = current_epoch();
self.sync_epoch(epoch);
user_state.sync_epoch_and_mark_active(epoch, &self.active_users);
let active_users = self.active_users.load(Ordering::Relaxed).max(1);
let fair_share = cap_epoch.saturating_div(active_users).max(1);
loop {
let total_used = self.used.load(Ordering::Relaxed);
if total_used >= cap_epoch {
return 0;
}
let total_remaining = cap_epoch.saturating_sub(total_used);
let user_used = user_state.used.load(Ordering::Relaxed);
let guaranteed_remaining = fair_share.saturating_sub(user_used);
let grant = if guaranteed_remaining > 0 {
requested.min(guaranteed_remaining).min(total_remaining)
} else {
requested.min(total_remaining).min(MAX_BORROW_CHUNK_BYTES)
};
if grant == 0 {
return 0;
}
let next_total = total_used.saturating_add(grant);
if self
.used
.compare_exchange_weak(total_used, next_total, Ordering::Relaxed, Ordering::Relaxed)
.is_ok()
{
user_state.used.fetch_add(grant, Ordering::Relaxed);
return grant;
}
}
}
fn refund(&self, bytes: u64) {
if bytes == 0 {
return;
}
decrement_atomic_saturating(&self.used, bytes);
}
}
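The grant-sizing rule inside `CidrDirectionBucket::try_consume` can be pulled out as a pure function for inspection; `grant_size` below is an extracted sketch, not a function in the source:

```rust
// Sketch of the grant-sizing rule in `CidrDirectionBucket::try_consume`:
// each active user gets a guaranteed fair share of the epoch budget, and
// may borrow beyond it only in MAX_BORROW_CHUNK_BYTES-sized chunks.
const MAX_BORROW_CHUNK_BYTES: u64 = 32 * 1024;

fn grant_size(requested: u64, total_remaining: u64, fair_share: u64, user_used: u64) -> u64 {
    let guaranteed_remaining = fair_share.saturating_sub(user_used);
    if guaranteed_remaining > 0 {
        requested.min(guaranteed_remaining).min(total_remaining)
    } else {
        requested.min(total_remaining).min(MAX_BORROW_CHUNK_BYTES)
    }
}

fn main() {
    // Inside the guaranteed share: the full request is granted.
    assert_eq!(grant_size(4_096, 100_000, 50_000, 0), 4_096);
    // Share exhausted: borrowing is capped at one 32 KiB chunk.
    assert_eq!(grant_size(1_000_000, 500_000, 50_000, 50_000), 32 * 1024);
    // Epoch budget nearly gone: total remaining wins.
    assert_eq!(grant_size(1_000_000, 10_000, 50_000, 50_000), 10_000);
    println!("ok");
}
```

Chunked borrowing is what keeps the scheme fair under contention: an over-share user can still use idle CIDR capacity, but only a bounded amount per round, so a newly active user's fair share is replenished quickly.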
#[derive(Default)]
struct CidrUserDirectionState {
epoch: AtomicU64,
used: AtomicU64,
}
impl CidrUserDirectionState {
fn sync_epoch_and_mark_active(&self, epoch: u64, active_users: &AtomicU64) {
let current = self.epoch.load(Ordering::Relaxed);
if current == epoch {
return;
}
if current < epoch
&& self
.epoch
.compare_exchange(current, epoch, Ordering::Relaxed, Ordering::Relaxed)
.is_ok()
{
self.used.store(0, Ordering::Relaxed);
active_users.fetch_add(1, Ordering::Relaxed);
}
}
fn refund(&self, bytes: u64) {
if bytes == 0 {
return;
}
decrement_atomic_saturating(&self.used, bytes);
}
}
struct CidrUserShare {
active_conns: AtomicU64,
up: CidrUserDirectionState,
down: CidrUserDirectionState,
}
impl CidrUserShare {
fn new() -> Self {
Self {
active_conns: AtomicU64::new(0),
up: CidrUserDirectionState::default(),
down: CidrUserDirectionState::default(),
}
}
}
struct CidrBucket {
rates: AtomicRatePair,
up: CidrDirectionBucket,
down: CidrDirectionBucket,
users: ShardedRegistry<CidrUserShare>,
active_leases: AtomicU64,
}
impl CidrBucket {
fn new(limits: RateLimitBps) -> Self {
let rates = AtomicRatePair::default();
rates.set(limits);
Self {
rates,
up: CidrDirectionBucket::default(),
down: CidrDirectionBucket::default(),
users: ShardedRegistry::new(REGISTRY_SHARDS),
active_leases: AtomicU64::new(0),
}
}
fn set_rates(&self, limits: RateLimitBps) {
self.rates.set(limits);
}
fn acquire_user_share(&self, user: &str) -> Arc<CidrUserShare> {
let share = self.users.get_or_insert_with(user, CidrUserShare::new);
share.active_conns.fetch_add(1, Ordering::Relaxed);
share
}
fn release_user_share(&self, user: &str, share: &Arc<CidrUserShare>) {
decrement_atomic_saturating(&share.active_conns, 1);
let share_for_remove = Arc::clone(share);
let _ = self.users.remove_if(user, |candidate| {
Arc::ptr_eq(candidate, &share_for_remove)
&& candidate.active_conns.load(Ordering::Relaxed) == 0
});
}
fn try_consume_for_user(
&self,
direction: RateDirection,
share: &CidrUserShare,
requested: u64,
) -> u64 {
let cap_bps = self.rates.get(direction);
if cap_bps == 0 {
return requested;
}
let cap_epoch = bytes_per_epoch(cap_bps);
match direction {
RateDirection::Up => self.up.try_consume(&share.up, cap_epoch, requested),
RateDirection::Down => self.down.try_consume(&share.down, cap_epoch, requested),
}
}
fn refund_for_user(&self, direction: RateDirection, share: &CidrUserShare, bytes: u64) {
match direction {
RateDirection::Up => {
self.up.refund(bytes);
share.up.refund(bytes);
}
RateDirection::Down => {
self.down.refund(bytes);
share.down.refund(bytes);
}
}
}
fn cleanup_idle_users(&self) {
self.users
.retain(|_, share| share.active_conns.load(Ordering::Relaxed) > 0);
}
}
#[derive(Clone)]
struct CidrRule {
key: String,
cidr: IpNetwork,
limits: RateLimitBps,
prefix_len: u8,
}
#[derive(Default)]
struct PolicySnapshot {
user_limits: HashMap<String, RateLimitBps>,
cidr_rules_v4: Vec<CidrRule>,
cidr_rules_v6: Vec<CidrRule>,
cidr_rule_keys: HashSet<String>,
}
impl PolicySnapshot {
fn match_cidr(&self, ip: IpAddr) -> Option<&CidrRule> {
match ip {
IpAddr::V4(_) => self
.cidr_rules_v4
.iter()
.find(|rule| rule.cidr.contains(ip)),
IpAddr::V6(_) => self
.cidr_rules_v6
.iter()
.find(|rule| rule.cidr.contains(ip)),
}
}
}
struct ShardedRegistry<T> {
shards: Box<[DashMap<String, Arc<T>>]>,
mask: usize,
}
impl<T> ShardedRegistry<T> {
fn new(shards: usize) -> Self {
let shard_count = shards.max(1).next_power_of_two();
let mut items = Vec::with_capacity(shard_count);
for _ in 0..shard_count {
items.push(DashMap::<String, Arc<T>>::new());
}
Self {
shards: items.into_boxed_slice(),
mask: shard_count.saturating_sub(1),
}
}
fn shard_index(&self, key: &str) -> usize {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
key.hash(&mut hasher);
(hasher.finish() as usize) & self.mask
}
fn get_or_insert_with<F>(&self, key: &str, make: F) -> Arc<T>
where
F: FnOnce() -> T,
{
let shard = &self.shards[self.shard_index(key)];
match shard.entry(key.to_string()) {
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
dashmap::mapref::entry::Entry::Vacant(slot) => {
let value = Arc::new(make());
slot.insert(Arc::clone(&value));
value
}
}
}
fn retain<F>(&self, predicate: F)
where
F: Fn(&String, &Arc<T>) -> bool + Copy,
{
for shard in &*self.shards {
shard.retain(|key, value| predicate(key, value));
}
}
fn remove_if<F>(&self, key: &str, predicate: F) -> bool
where
F: Fn(&Arc<T>) -> bool,
{
let shard = &self.shards[self.shard_index(key)];
let should_remove = match shard.get(key) {
Some(entry) => predicate(entry.value()),
None => false,
};
if !should_remove {
return false;
}
shard.remove(key).is_some()
}
}
pub struct TrafficLease {
limiter: Arc<TrafficLimiter>,
user_bucket: Option<Arc<UserBucket>>,
cidr_bucket: Option<Arc<CidrBucket>>,
cidr_user_key: Option<String>,
cidr_user_share: Option<Arc<CidrUserShare>>,
}
impl TrafficLease {
pub fn try_consume(&self, direction: RateDirection, requested: u64) -> TrafficConsumeResult {
if requested == 0 {
return TrafficConsumeResult {
granted: 0,
blocked_user: false,
blocked_cidr: false,
};
}
let mut granted = requested;
if let Some(user_bucket) = self.user_bucket.as_ref() {
let user_granted = user_bucket.try_consume(direction, granted);
if user_granted == 0 {
self.limiter.observe_throttle(direction, true, false);
return TrafficConsumeResult {
granted: 0,
blocked_user: true,
blocked_cidr: false,
};
}
granted = user_granted;
}
if let (Some(cidr_bucket), Some(cidr_user_share)) =
(self.cidr_bucket.as_ref(), self.cidr_user_share.as_ref())
{
let cidr_granted =
cidr_bucket.try_consume_for_user(direction, cidr_user_share, granted);
if cidr_granted < granted
&& let Some(user_bucket) = self.user_bucket.as_ref()
{
user_bucket.refund(direction, granted.saturating_sub(cidr_granted));
}
if cidr_granted == 0 {
self.limiter.observe_throttle(direction, false, true);
return TrafficConsumeResult {
granted: 0,
blocked_user: false,
blocked_cidr: true,
};
}
granted = cidr_granted;
}
TrafficConsumeResult {
granted,
blocked_user: false,
blocked_cidr: false,
}
}
pub fn refund(&self, direction: RateDirection, bytes: u64) {
if bytes == 0 {
return;
}
if let Some(user_bucket) = self.user_bucket.as_ref() {
user_bucket.refund(direction, bytes);
}
if let (Some(cidr_bucket), Some(cidr_user_share)) =
(self.cidr_bucket.as_ref(), self.cidr_user_share.as_ref())
{
cidr_bucket.refund_for_user(direction, cidr_user_share, bytes);
}
}
pub fn observe_wait_ms(
&self,
direction: RateDirection,
blocked_user: bool,
blocked_cidr: bool,
wait_ms: u64,
) {
if wait_ms == 0 {
return;
}
self.limiter
.observe_wait(direction, blocked_user, blocked_cidr, wait_ms);
}
}
impl Drop for TrafficLease {
fn drop(&mut self) {
if let Some(bucket) = self.user_bucket.as_ref() {
decrement_atomic_saturating(&bucket.active_leases, 1);
decrement_atomic_saturating(&self.limiter.user_scope.active_leases, 1);
}
if let Some(bucket) = self.cidr_bucket.as_ref() {
if let (Some(user_key), Some(share)) =
(self.cidr_user_key.as_ref(), self.cidr_user_share.as_ref())
{
bucket.release_user_share(user_key, share);
}
decrement_atomic_saturating(&bucket.active_leases, 1);
decrement_atomic_saturating(&self.limiter.cidr_scope.active_leases, 1);
}
}
}
pub struct TrafficLimiter {
policy: ArcSwap<PolicySnapshot>,
user_buckets: ShardedRegistry<UserBucket>,
cidr_buckets: ShardedRegistry<CidrBucket>,
user_scope: ScopeMetrics,
cidr_scope: ScopeMetrics,
last_cleanup_epoch_secs: AtomicU64,
}
impl TrafficLimiter {
pub fn new() -> Arc<Self> {
Arc::new(Self {
policy: ArcSwap::from_pointee(PolicySnapshot::default()),
user_buckets: ShardedRegistry::new(REGISTRY_SHARDS),
cidr_buckets: ShardedRegistry::new(REGISTRY_SHARDS),
user_scope: ScopeMetrics::default(),
cidr_scope: ScopeMetrics::default(),
last_cleanup_epoch_secs: AtomicU64::new(0),
})
}
pub fn apply_policy(
&self,
user_limits: HashMap<String, RateLimitBps>,
cidr_limits: HashMap<IpNetwork, RateLimitBps>,
) {
let filtered_users = user_limits
.into_iter()
.filter(|(_, limit)| limit.up_bps > 0 || limit.down_bps > 0)
.collect::<HashMap<_, _>>();
let mut cidr_rules_v4 = Vec::new();
let mut cidr_rules_v6 = Vec::new();
let mut cidr_rule_keys = HashSet::new();
for (cidr, limits) in cidr_limits {
if limits.up_bps == 0 && limits.down_bps == 0 {
continue;
}
let key = cidr.to_string();
let rule = CidrRule {
key: key.clone(),
cidr,
limits,
prefix_len: cidr.prefix(),
};
cidr_rule_keys.insert(key);
match rule.cidr {
IpNetwork::V4(_) => cidr_rules_v4.push(rule),
IpNetwork::V6(_) => cidr_rules_v6.push(rule),
}
}
cidr_rules_v4.sort_by(|a, b| b.prefix_len.cmp(&a.prefix_len));
cidr_rules_v6.sort_by(|a, b| b.prefix_len.cmp(&a.prefix_len));
self.user_scope
.policy_entries
.store(filtered_users.len() as u64, Ordering::Relaxed);
self.cidr_scope
.policy_entries
.store(cidr_rule_keys.len() as u64, Ordering::Relaxed);
self.policy.store(Arc::new(PolicySnapshot {
user_limits: filtered_users,
cidr_rules_v4,
cidr_rules_v6,
cidr_rule_keys,
}));
self.maybe_cleanup();
}
pub fn acquire_lease(
self: &Arc<Self>,
user: &str,
client_ip: IpAddr,
) -> Option<Arc<TrafficLease>> {
let policy = self.policy.load_full();
let mut user_bucket = None;
if let Some(limit) = policy.user_limits.get(user).copied() {
let bucket = self
.user_buckets
.get_or_insert_with(user, || UserBucket::new(limit));
bucket.set_rates(limit);
bucket.active_leases.fetch_add(1, Ordering::Relaxed);
self.user_scope
.active_leases
.fetch_add(1, Ordering::Relaxed);
user_bucket = Some(bucket);
}
let mut cidr_bucket = None;
let mut cidr_user_key = None;
let mut cidr_user_share = None;
if let Some(rule) = policy.match_cidr(client_ip) {
let bucket = self
.cidr_buckets
.get_or_insert_with(rule.key.as_str(), || CidrBucket::new(rule.limits));
bucket.set_rates(rule.limits);
bucket.active_leases.fetch_add(1, Ordering::Relaxed);
self.cidr_scope
.active_leases
.fetch_add(1, Ordering::Relaxed);
let share = bucket.acquire_user_share(user);
cidr_user_key = Some(user.to_string());
cidr_user_share = Some(share);
cidr_bucket = Some(bucket);
}
if user_bucket.is_none() && cidr_bucket.is_none() {
return None;
}
self.maybe_cleanup();
Some(Arc::new(TrafficLease {
limiter: Arc::clone(self),
user_bucket,
cidr_bucket,
cidr_user_key,
cidr_user_share,
}))
}
pub fn metrics_snapshot(&self) -> TrafficLimiterMetricsSnapshot {
TrafficLimiterMetricsSnapshot {
user_throttle_up_total: self.user_scope.throttle_up_total.load(Ordering::Relaxed),
user_throttle_down_total: self.user_scope.throttle_down_total.load(Ordering::Relaxed),
cidr_throttle_up_total: self.cidr_scope.throttle_up_total.load(Ordering::Relaxed),
cidr_throttle_down_total: self.cidr_scope.throttle_down_total.load(Ordering::Relaxed),
user_wait_up_ms_total: self.user_scope.wait_up_ms_total.load(Ordering::Relaxed),
user_wait_down_ms_total: self.user_scope.wait_down_ms_total.load(Ordering::Relaxed),
cidr_wait_up_ms_total: self.cidr_scope.wait_up_ms_total.load(Ordering::Relaxed),
cidr_wait_down_ms_total: self.cidr_scope.wait_down_ms_total.load(Ordering::Relaxed),
user_active_leases: self.user_scope.active_leases.load(Ordering::Relaxed),
cidr_active_leases: self.cidr_scope.active_leases.load(Ordering::Relaxed),
user_policy_entries: self.user_scope.policy_entries.load(Ordering::Relaxed),
cidr_policy_entries: self.cidr_scope.policy_entries.load(Ordering::Relaxed),
}
}
fn observe_throttle(&self, direction: RateDirection, blocked_user: bool, blocked_cidr: bool) {
if blocked_user {
self.user_scope.throttle(direction);
}
if blocked_cidr {
self.cidr_scope.throttle(direction);
}
}
fn observe_wait(
&self,
direction: RateDirection,
blocked_user: bool,
blocked_cidr: bool,
wait_ms: u64,
) {
if blocked_user {
self.user_scope.wait_ms(direction, wait_ms);
}
if blocked_cidr {
self.cidr_scope.wait_ms(direction, wait_ms);
}
}
/// Opportunistically prune idle buckets at most once per `CLEANUP_INTERVAL_SECS`;
/// the compare_exchange ensures only one caller runs the sweep per interval.
fn maybe_cleanup(&self) {
let now_epoch_secs = now_epoch_secs();
let last = self.last_cleanup_epoch_secs.load(Ordering::Relaxed);
if now_epoch_secs.saturating_sub(last) < CLEANUP_INTERVAL_SECS {
return;
}
if self
.last_cleanup_epoch_secs
.compare_exchange(last, now_epoch_secs, Ordering::Relaxed, Ordering::Relaxed)
.is_err()
{
return;
}
let policy = self.policy.load_full();
self.user_buckets.retain(|user, bucket| {
bucket.active_leases.load(Ordering::Relaxed) > 0
|| policy.user_limits.contains_key(user)
});
self.cidr_buckets.retain(|cidr_key, bucket| {
bucket.cleanup_idle_users();
bucket.active_leases.load(Ordering::Relaxed) > 0
|| policy.cidr_rule_keys.contains(cidr_key)
});
}
}
/// Duration until the next fair-epoch boundary, when buckets refill.
pub fn next_refill_delay() -> Duration {
let start = limiter_epoch_start();
let elapsed_ms = start.elapsed().as_millis() as u64;
let epoch_pos = elapsed_ms % FAIR_EPOCH_MS;
let wait_ms = FAIR_EPOCH_MS.saturating_sub(epoch_pos).max(1);
Duration::from_millis(wait_ms)
}
/// Decrement `counter` by `by` without wrapping below zero under concurrent updates.
fn decrement_atomic_saturating(counter: &AtomicU64, by: u64) {
if by == 0 {
return;
}
let mut current = counter.load(Ordering::Relaxed);
loop {
if current == 0 {
return;
}
let next = current.saturating_sub(by);
match counter.compare_exchange_weak(current, next, Ordering::Relaxed, Ordering::Relaxed) {
Ok(_) => return,
Err(actual) => current = actual,
}
}
}
fn now_epoch_secs() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
/// Convert a bits-per-second cap into the byte budget for one fair epoch
/// (bps * FAIR_EPOCH_MS / 8000), granting at least one byte for nonzero caps.
fn bytes_per_epoch(bps: u64) -> u64 {
if bps == 0 {
return 0;
}
let numerator = bps.saturating_mul(FAIR_EPOCH_MS);
let bytes = numerator.saturating_div(8_000);
bytes.max(1)
}
fn current_epoch() -> u64 {
let start = limiter_epoch_start();
let elapsed_ms = start.elapsed().as_millis() as u64;
elapsed_ms / FAIR_EPOCH_MS
}
fn limiter_epoch_start() -> &'static Instant {
static START: OnceLock<Instant> = OnceLock::new();
START.get_or_init(Instant::now)
}
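The per-epoch budget math above can be checked in isolation. A standalone sketch, assuming `FAIR_EPOCH_MS` is 100 ms (the real constant is defined elsewhere in this module):

```rust
// Standalone sketch of the limiter's epoch-bucket arithmetic.
// FAIR_EPOCH_MS = 100 is an assumption for illustration only.
const FAIR_EPOCH_MS: u64 = 100;

fn bytes_per_epoch(bps: u64) -> u64 {
    if bps == 0 {
        return 0; // zero means "unlimited" upstream, no budget needed
    }
    // bits/s * ms-per-epoch / (8 bits * 1000 ms) = bytes per epoch
    (bps.saturating_mul(FAIR_EPOCH_MS) / 8_000).max(1)
}

fn main() {
    // 1 Mbit/s => 125_000 bytes/s => 12_500 bytes per 100 ms epoch.
    assert_eq!(bytes_per_epoch(1_000_000), 12_500);
    // Sub-epoch rates still grant at least one byte so flows make progress.
    assert_eq!(bytes_per_epoch(10), 1);
    assert_eq!(bytes_per_epoch(0), 0);
}
```

The `.max(1)` floor mirrors the real function: without it, very low caps would round to a zero-byte epoch budget and starve the flow entirely.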


@@ -175,6 +175,18 @@ pub struct Stats {
me_route_drop_queue_full: AtomicU64,
me_route_drop_queue_full_base: AtomicU64,
me_route_drop_queue_full_high: AtomicU64,
me_fair_pressure_state_gauge: AtomicU64,
me_fair_active_flows_gauge: AtomicU64,
me_fair_queued_bytes_gauge: AtomicU64,
me_fair_standing_flows_gauge: AtomicU64,
me_fair_backpressured_flows_gauge: AtomicU64,
me_fair_scheduler_rounds_total: AtomicU64,
me_fair_deficit_grants_total: AtomicU64,
me_fair_deficit_skips_total: AtomicU64,
me_fair_enqueue_rejects_total: AtomicU64,
me_fair_shed_drops_total: AtomicU64,
me_fair_penalties_total: AtomicU64,
me_fair_downstream_stalls_total: AtomicU64,
me_d2c_batches_total: AtomicU64,
me_d2c_batch_frames_total: AtomicU64,
me_d2c_batch_bytes_total: AtomicU64,
@@ -856,6 +868,78 @@ impl Stats {
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn set_me_fair_pressure_state_gauge(&self, value: u64) {
if self.telemetry_me_allows_normal() {
self.me_fair_pressure_state_gauge
.store(value, Ordering::Relaxed);
}
}
pub fn set_me_fair_active_flows_gauge(&self, value: u64) {
if self.telemetry_me_allows_normal() {
self.me_fair_active_flows_gauge
.store(value, Ordering::Relaxed);
}
}
pub fn set_me_fair_queued_bytes_gauge(&self, value: u64) {
if self.telemetry_me_allows_normal() {
self.me_fair_queued_bytes_gauge
.store(value, Ordering::Relaxed);
}
}
pub fn set_me_fair_standing_flows_gauge(&self, value: u64) {
if self.telemetry_me_allows_normal() {
self.me_fair_standing_flows_gauge
.store(value, Ordering::Relaxed);
}
}
pub fn set_me_fair_backpressured_flows_gauge(&self, value: u64) {
if self.telemetry_me_allows_normal() {
self.me_fair_backpressured_flows_gauge
.store(value, Ordering::Relaxed);
}
}
pub fn add_me_fair_scheduler_rounds_total(&self, value: u64) {
if self.telemetry_me_allows_normal() && value > 0 {
self.me_fair_scheduler_rounds_total
.fetch_add(value, Ordering::Relaxed);
}
}
pub fn add_me_fair_deficit_grants_total(&self, value: u64) {
if self.telemetry_me_allows_normal() && value > 0 {
self.me_fair_deficit_grants_total
.fetch_add(value, Ordering::Relaxed);
}
}
pub fn add_me_fair_deficit_skips_total(&self, value: u64) {
if self.telemetry_me_allows_normal() && value > 0 {
self.me_fair_deficit_skips_total
.fetch_add(value, Ordering::Relaxed);
}
}
pub fn add_me_fair_enqueue_rejects_total(&self, value: u64) {
if self.telemetry_me_allows_normal() && value > 0 {
self.me_fair_enqueue_rejects_total
.fetch_add(value, Ordering::Relaxed);
}
}
pub fn add_me_fair_shed_drops_total(&self, value: u64) {
if self.telemetry_me_allows_normal() && value > 0 {
self.me_fair_shed_drops_total
.fetch_add(value, Ordering::Relaxed);
}
}
pub fn add_me_fair_penalties_total(&self, value: u64) {
if self.telemetry_me_allows_normal() && value > 0 {
self.me_fair_penalties_total
.fetch_add(value, Ordering::Relaxed);
}
}
pub fn add_me_fair_downstream_stalls_total(&self, value: u64) {
if self.telemetry_me_allows_normal() && value > 0 {
self.me_fair_downstream_stalls_total
.fetch_add(value, Ordering::Relaxed);
}
}
pub fn increment_me_d2c_batches_total(&self) {
if self.telemetry_me_allows_normal() {
self.me_d2c_batches_total.fetch_add(1, Ordering::Relaxed);
@@ -1806,6 +1890,43 @@ impl Stats {
pub fn get_me_route_drop_queue_full_high(&self) -> u64 {
self.me_route_drop_queue_full_high.load(Ordering::Relaxed)
}
pub fn get_me_fair_pressure_state_gauge(&self) -> u64 {
self.me_fair_pressure_state_gauge.load(Ordering::Relaxed)
}
pub fn get_me_fair_active_flows_gauge(&self) -> u64 {
self.me_fair_active_flows_gauge.load(Ordering::Relaxed)
}
pub fn get_me_fair_queued_bytes_gauge(&self) -> u64 {
self.me_fair_queued_bytes_gauge.load(Ordering::Relaxed)
}
pub fn get_me_fair_standing_flows_gauge(&self) -> u64 {
self.me_fair_standing_flows_gauge.load(Ordering::Relaxed)
}
pub fn get_me_fair_backpressured_flows_gauge(&self) -> u64 {
self.me_fair_backpressured_flows_gauge
.load(Ordering::Relaxed)
}
pub fn get_me_fair_scheduler_rounds_total(&self) -> u64 {
self.me_fair_scheduler_rounds_total.load(Ordering::Relaxed)
}
pub fn get_me_fair_deficit_grants_total(&self) -> u64 {
self.me_fair_deficit_grants_total.load(Ordering::Relaxed)
}
pub fn get_me_fair_deficit_skips_total(&self) -> u64 {
self.me_fair_deficit_skips_total.load(Ordering::Relaxed)
}
pub fn get_me_fair_enqueue_rejects_total(&self) -> u64 {
self.me_fair_enqueue_rejects_total.load(Ordering::Relaxed)
}
pub fn get_me_fair_shed_drops_total(&self) -> u64 {
self.me_fair_shed_drops_total.load(Ordering::Relaxed)
}
pub fn get_me_fair_penalties_total(&self) -> u64 {
self.me_fair_penalties_total.load(Ordering::Relaxed)
}
pub fn get_me_fair_downstream_stalls_total(&self) -> u64 {
self.me_fair_downstream_stalls_total.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batches_total(&self) -> u64 {
self.me_d2c_batches_total.load(Ordering::Relaxed)
}


@@ -11,6 +11,7 @@ use crc32fast::Hasher;
const MIN_APP_DATA: usize = 64;
const MAX_APP_DATA: usize = MAX_TLS_CIPHERTEXT_SIZE;
const MAX_TICKET_RECORDS: usize = 4;
fn jitter_and_clamp_sizes(sizes: &[usize], rng: &SecureRandom) -> Vec<usize> {
sizes
@@ -62,6 +63,53 @@ fn ensure_payload_capacity(mut sizes: Vec<usize>, payload_len: usize) -> Vec<usi
sizes
}
fn emulated_app_data_sizes(cached: &CachedTlsData) -> Vec<usize> {
match cached.behavior_profile.source {
TlsProfileSource::Raw | TlsProfileSource::Merged => {
if !cached.behavior_profile.app_data_record_sizes.is_empty() {
return cached.behavior_profile.app_data_record_sizes.clone();
}
}
TlsProfileSource::Default | TlsProfileSource::Rustls => {}
}
let mut sizes = cached.app_data_records_sizes.clone();
if sizes.is_empty() {
sizes.push(cached.total_app_data_len.max(1024));
}
sizes
}
fn emulated_change_cipher_spec_count(cached: &CachedTlsData) -> usize {
usize::from(cached.behavior_profile.change_cipher_spec_count.max(1))
}
fn emulated_ticket_record_sizes(
cached: &CachedTlsData,
new_session_tickets: u8,
rng: &SecureRandom,
) -> Vec<usize> {
let mut sizes = match cached.behavior_profile.source {
TlsProfileSource::Raw | TlsProfileSource::Merged => {
cached.behavior_profile.ticket_record_sizes.clone()
}
TlsProfileSource::Default | TlsProfileSource::Rustls => Vec::new(),
};
let target_count = sizes
.len()
.max(usize::from(
new_session_tickets.min(MAX_TICKET_RECORDS as u8),
))
.min(MAX_TICKET_RECORDS);
while sizes.len() < target_count {
sizes.push(rng.range(48) + 48);
}
sizes
}
fn build_compact_cert_info_payload(cert_info: &ParsedCertificateInfo) -> Option<Vec<u8>> {
let mut fields = Vec::new();
@@ -180,39 +228,21 @@ pub fn build_emulated_server_hello(
server_hello.extend_from_slice(&message);
// --- ChangeCipherSpec ---
let change_cipher_spec_count = emulated_change_cipher_spec_count(cached);
let mut change_cipher_spec = Vec::with_capacity(change_cipher_spec_count * 6);
for _ in 0..change_cipher_spec_count {
change_cipher_spec.extend_from_slice(&[
TLS_RECORD_CHANGE_CIPHER,
TLS_VERSION[0],
TLS_VERSION[1],
0x00,
0x01,
0x01,
]);
}
// --- ApplicationData (fake encrypted records) ---
let mut sizes = jitter_and_clamp_sizes(&emulated_app_data_sizes(cached), rng);
let compact_payload = cached
.cert_info
.as_ref()
@@ -299,10 +329,7 @@ pub fn build_emulated_server_hello(
// --- Combine ---
// Optional NewSessionTicket mimic records (opaque ApplicationData for fingerprint).
let mut tickets = Vec::new();
for ticket_len in emulated_ticket_record_sizes(cached, new_session_tickets, rng) {
let mut rec = Vec::with_capacity(5 + ticket_len);
rec.push(TLS_RECORD_APPLICATION);
rec.extend_from_slice(&TLS_VERSION);
@@ -310,7 +337,6 @@ pub fn build_emulated_server_hello(
rec.extend_from_slice(&rng.bytes(ticket_len));
tickets.extend_from_slice(&rec);
}
let mut response = Vec::with_capacity(
server_hello.len() + change_cipher_spec.len() + app_data.len() + tickets.len(),
@@ -334,6 +360,10 @@ pub fn build_emulated_server_hello(
#[path = "tests/emulator_security_tests.rs"]
mod security_tests;
#[cfg(test)]
#[path = "tests/emulator_profile_fidelity_security_tests.rs"]
mod emulator_profile_fidelity_security_tests;
#[cfg(test)]
mod tests {
use std::time::SystemTime;
@@ -478,7 +508,7 @@ mod tests {
}
#[test]
fn test_build_emulated_server_hello_replays_tail_records_for_profiled_tls() {
let mut cached = make_cached(None);
cached.app_data_records_sizes = vec![27, 3905, 537, 69];
cached.total_app_data_len = 4538;
@@ -500,11 +530,19 @@ mod tests {
let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize;
let ccs_start = 5 + hello_len;
let mut pos = ccs_start + 6;
let mut app_lengths = Vec::new();
while pos + 5 <= response.len() {
assert_eq!(response[pos], TLS_RECORD_APPLICATION);
let record_len = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
app_lengths.push(record_len);
pos += 5 + record_len;
}
assert_eq!(app_lengths.len(), 4);
assert_eq!(app_lengths[0], 64);
assert_eq!(app_lengths[3], 69);
assert!(app_lengths[1] >= 64);
assert!(app_lengths[2] >= 64);
}
}


@@ -1,6 +1,7 @@
#![allow(clippy::too_many_arguments)]
use dashmap::DashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::{Duration, Instant};
@@ -793,6 +794,51 @@ async fn connect_tcp_with_upstream(
))
}
fn socket_addrs_from_upstream_stream(
stream: &UpstreamStream,
) -> (Option<SocketAddr>, Option<SocketAddr>) {
match stream {
UpstreamStream::Tcp(tcp) => (tcp.local_addr().ok(), tcp.peer_addr().ok()),
UpstreamStream::Shadowsocks(_) => (None, None),
}
}
fn build_tls_fetch_proxy_header(
proxy_protocol: u8,
src_addr: Option<SocketAddr>,
dst_addr: Option<SocketAddr>,
) -> Option<Vec<u8>> {
match proxy_protocol {
0 => None,
2 => {
let header = match (src_addr, dst_addr) {
(Some(src @ SocketAddr::V4(_)), Some(dst @ SocketAddr::V4(_)))
| (Some(src @ SocketAddr::V6(_)), Some(dst @ SocketAddr::V6(_))) => {
ProxyProtocolV2Builder::new().with_addrs(src, dst).build()
}
_ => ProxyProtocolV2Builder::new().build(),
};
Some(header)
}
_ => {
let header = match (src_addr, dst_addr) {
(Some(SocketAddr::V4(src)), Some(SocketAddr::V4(dst))) => {
ProxyProtocolV1Builder::new()
.tcp4(src.into(), dst.into())
.build()
}
(Some(SocketAddr::V6(src)), Some(SocketAddr::V6(dst))) => {
ProxyProtocolV1Builder::new()
.tcp6(src.into(), dst.into())
.build()
}
_ => ProxyProtocolV1Builder::new().build(),
};
Some(header)
}
}
}
fn encode_tls13_certificate_message(cert_chain_der: &[Vec<u8>]) -> Option<Vec<u8>> {
if cert_chain_der.is_empty() {
return None;
@@ -824,7 +870,7 @@ async fn fetch_via_raw_tls_stream<S>(
mut stream: S,
sni: &str,
connect_timeout: Duration,
proxy_header: Option<Vec<u8>>,
profile: TlsFetchProfile,
grease_enabled: bool,
deterministic: bool,
@@ -835,11 +881,7 @@ where
let rng = SecureRandom::new();
let client_hello = build_client_hello(sni, &rng, profile, grease_enabled, deterministic);
timeout(connect_timeout, async {
if let Some(header) = proxy_header.as_ref() {
stream.write_all(&header).await?;
}
stream.write_all(&client_hello).await?;
@@ -921,11 +963,12 @@ async fn fetch_via_raw_tls(
sock = %sock_path,
"Raw TLS fetch using mask unix socket"
);
let proxy_header = build_tls_fetch_proxy_header(proxy_protocol, None, None);
return fetch_via_raw_tls_stream(
stream,
sni,
connect_timeout,
proxy_header,
profile,
grease_enabled,
deterministic,
@@ -956,11 +999,13 @@ async fn fetch_via_raw_tls(
let stream =
connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route)
.await?;
let (src_addr, dst_addr) = socket_addrs_from_upstream_stream(&stream);
let proxy_header = build_tls_fetch_proxy_header(proxy_protocol, src_addr, dst_addr);
fetch_via_raw_tls_stream(
stream,
sni,
connect_timeout,
proxy_header,
profile,
grease_enabled,
deterministic,
@@ -972,17 +1017,13 @@ async fn fetch_via_rustls_stream<S>(
mut stream: S,
host: &str,
sni: &str,
proxy_header: Option<Vec<u8>>,
) -> Result<TlsFetchResult>
where
S: AsyncRead + AsyncWrite + Unpin,
{
// rustls handshake path for certificate and basic negotiated metadata.
if let Some(header) = proxy_header.as_ref() {
stream.write_all(&header).await?;
stream.flush().await?;
}
@@ -1082,7 +1123,8 @@ async fn fetch_via_rustls(
sock = %sock_path,
"Rustls fetch using mask unix socket"
);
let proxy_header = build_tls_fetch_proxy_header(proxy_protocol, None, None);
return fetch_via_rustls_stream(stream, host, sni, proxy_header).await;
}
Ok(Err(e)) => {
warn!(
@@ -1108,7 +1150,9 @@ async fn fetch_via_rustls(
let stream =
connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route)
.await?;
let (src_addr, dst_addr) = socket_addrs_from_upstream_stream(&stream);
let proxy_header = build_tls_fetch_proxy_header(proxy_protocol, src_addr, dst_addr);
fetch_via_rustls_stream(stream, host, sni, proxy_header).await
}
/// Fetch real TLS metadata with an adaptive multi-profile strategy.
@@ -1278,11 +1322,13 @@ pub async fn fetch_real_tls(
#[cfg(test)]
mod tests {
use std::net::SocketAddr;
use std::time::{Duration, Instant};
use super::{
ProfileCacheValue, TlsFetchStrategy, build_client_hello, build_tls_fetch_proxy_header,
derive_behavior_profile, encode_tls13_certificate_message, order_profiles, profile_cache,
profile_cache_key,
};
use crate::config::TlsFetchProfile;
use crate::crypto::SecureRandom;
@@ -1423,4 +1469,48 @@ mod tests {
assert_eq!(first, second);
}
#[test]
fn test_build_tls_fetch_proxy_header_v2_with_tcp_addrs() {
let src: SocketAddr = "198.51.100.10:42000".parse().expect("valid src");
let dst: SocketAddr = "203.0.113.20:443".parse().expect("valid dst");
let header = build_tls_fetch_proxy_header(2, Some(src), Some(dst)).expect("header");
assert_eq!(
&header[..12],
&[
0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a
]
);
assert_eq!(header[12], 0x21);
assert_eq!(header[13], 0x11);
assert_eq!(u16::from_be_bytes([header[14], header[15]]), 12);
assert_eq!(&header[16..20], &[198, 51, 100, 10]);
assert_eq!(&header[20..24], &[203, 0, 113, 20]);
assert_eq!(u16::from_be_bytes([header[24], header[25]]), 42000);
assert_eq!(u16::from_be_bytes([header[26], header[27]]), 443);
}
#[test]
fn test_build_tls_fetch_proxy_header_v2_mixed_family_falls_back_to_local_command() {
let src: SocketAddr = "198.51.100.10:42000".parse().expect("valid src");
let dst: SocketAddr = "[2001:db8::20]:443".parse().expect("valid dst");
let header = build_tls_fetch_proxy_header(2, Some(src), Some(dst)).expect("header");
assert_eq!(header[12], 0x20);
assert_eq!(header[13], 0x00);
assert_eq!(u16::from_be_bytes([header[14], header[15]]), 0);
}
#[test]
fn test_build_tls_fetch_proxy_header_v1_with_tcp_addrs() {
let src: SocketAddr = "198.51.100.10:42000".parse().expect("valid src");
let dst: SocketAddr = "203.0.113.20:443".parse().expect("valid dst");
let header = build_tls_fetch_proxy_header(1, Some(src), Some(dst)).expect("header");
assert_eq!(
header,
b"PROXY TCP4 198.51.100.10 203.0.113.20 42000 443\r\n"
);
}
}
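For reference, the v2 header these tests decode follows the PROXY protocol v2 wire format: a 12-byte signature, a version/command byte, a family/transport byte, a 2-byte big-endian address-block length, then source/destination addresses and ports. A hand-rolled sketch of the TCP4 variant (`proxy_v2_tcp4` is illustrative, not the crate's `ProxyProtocolV2Builder`):

```rust
// Assemble a 28-byte PROXY protocol v2 header for TCP over IPv4 by hand.
fn proxy_v2_tcp4(src: [u8; 4], src_port: u16, dst: [u8; 4], dst_port: u16) -> Vec<u8> {
    let mut h = Vec::with_capacity(28);
    // 12-byte protocol signature.
    h.extend_from_slice(&[
        0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a,
    ]);
    h.push(0x21); // version 2, command PROXY
    h.push(0x11); // AF_INET, SOCK_STREAM
    h.extend_from_slice(&12u16.to_be_bytes()); // address block: 4+4+2+2 bytes
    h.extend_from_slice(&src);
    h.extend_from_slice(&dst);
    h.extend_from_slice(&src_port.to_be_bytes());
    h.extend_from_slice(&dst_port.to_be_bytes());
    h
}

fn main() {
    let h = proxy_v2_tcp4([198, 51, 100, 10], 42000, [203, 0, 113, 20], 443);
    assert_eq!(h.len(), 28);
    assert_eq!(h[12], 0x21); // PROXY command, not LOCAL (0x20)
    assert_eq!(u16::from_be_bytes([h[14], h[15]]), 12);
    assert_eq!(u16::from_be_bytes([h[26], h[27]]), 443);
}
```

This is why the mixed-family case above falls back to the LOCAL command (`0x20`, family `0x00`, length 0): the spec has no address block that pairs an IPv4 source with an IPv6 destination.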


@@ -0,0 +1,95 @@
use std::time::SystemTime;
use crate::crypto::SecureRandom;
use crate::protocol::constants::{
TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE,
};
use crate::tls_front::emulator::build_emulated_server_hello;
use crate::tls_front::types::{
CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsProfileSource,
};
fn make_cached() -> CachedTlsData {
CachedTlsData {
server_hello_template: ParsedServerHello {
version: [0x03, 0x03],
random: [0u8; 32],
session_id: Vec::new(),
cipher_suite: [0x13, 0x01],
compression: 0,
extensions: Vec::new(),
},
cert_info: None,
cert_payload: None,
app_data_records_sizes: vec![1200, 900, 220, 180],
total_app_data_len: 2500,
behavior_profile: TlsBehaviorProfile {
change_cipher_spec_count: 2,
app_data_record_sizes: vec![1200, 900],
ticket_record_sizes: vec![220, 180],
source: TlsProfileSource::Merged,
},
fetched_at: SystemTime::now(),
domain: "example.com".to_string(),
}
}
fn record_lengths_by_type(response: &[u8], wanted_type: u8) -> Vec<usize> {
let mut out = Vec::new();
let mut pos = 0usize;
while pos + 5 <= response.len() {
let record_type = response[pos];
let record_len = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize;
if pos + 5 + record_len > response.len() {
break;
}
if record_type == wanted_type {
out.push(record_len);
}
pos += 5 + record_len;
}
out
}
#[test]
fn emulated_server_hello_replays_profile_change_cipher_spec_count() {
let cached = make_cached();
let rng = SecureRandom::new();
let response = build_emulated_server_hello(
b"secret",
&[0x71; 32],
&[0x72; 16],
&cached,
false,
&rng,
None,
0,
);
assert_eq!(response[0], TLS_RECORD_HANDSHAKE);
let ccs_records = record_lengths_by_type(&response, TLS_RECORD_CHANGE_CIPHER);
assert_eq!(ccs_records.len(), 2);
assert!(ccs_records.iter().all(|len| *len == 1));
}
#[test]
fn emulated_server_hello_replays_profile_ticket_tail_lengths() {
let cached = make_cached();
let rng = SecureRandom::new();
let response = build_emulated_server_hello(
b"secret",
&[0x81; 32],
&[0x82; 16],
&cached,
false,
&rng,
None,
0,
);
let app_records = record_lengths_by_type(&response, TLS_RECORD_APPLICATION);
assert!(app_records.len() >= 4);
assert_eq!(&app_records[app_records.len() - 2..], &[220, 180]);
}
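`record_lengths_by_type` walks standard TLS record framing: each record is a 1-byte content type, a 2-byte version, a 2-byte big-endian length, then the payload. A minimal standalone version of that walk, independent of the test helper above:

```rust
// Split a byte buffer into (content_type, payload_len) pairs using
// the 5-byte TLS record header: type(1) || version(2) || length(2, BE).
fn split_records(buf: &[u8]) -> Vec<(u8, usize)> {
    let mut out = Vec::new();
    let mut pos = 0usize;
    while pos + 5 <= buf.len() {
        let len = u16::from_be_bytes([buf[pos + 3], buf[pos + 4]]) as usize;
        if pos + 5 + len > buf.len() {
            break; // truncated record: stop rather than overrun
        }
        out.push((buf[pos], len));
        pos += 5 + len;
    }
    out
}

fn main() {
    // One ChangeCipherSpec (0x14) followed by one ApplicationData (0x17) record.
    let mut buf = vec![0x14, 0x03, 0x03, 0x00, 0x01, 0x01];
    buf.extend_from_slice(&[0x17, 0x03, 0x03, 0x00, 0x02, 0xaa, 0xbb]);
    assert_eq!(split_records(&buf), vec![(0x14u8, 1usize), (0x17u8, 2usize)]);
}
```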


@@ -0,0 +1,13 @@
//! Backpressure-driven fairness control for ME reader routing.
//!
//! This module keeps fairness decisions worker-local:
//! each reader loop owns one scheduler instance and mutates it without locks.
mod model;
mod pressure;
mod scheduler;
#[cfg(test)]
pub(crate) use model::PressureState;
pub(crate) use model::{AdmissionDecision, DispatchAction, DispatchFeedback, SchedulerDecision};
pub(crate) use scheduler::{WorkerFairnessConfig, WorkerFairnessSnapshot, WorkerFairnessState};


@@ -0,0 +1,140 @@
use std::time::Instant;
use bytes::Bytes;
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u8)]
pub(crate) enum PressureState {
Normal = 0,
Pressured = 1,
Shedding = 2,
Saturated = 3,
}
impl PressureState {
pub(crate) fn as_u8(self) -> u8 {
self as u8
}
}
impl Default for PressureState {
fn default() -> Self {
Self::Normal
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FlowPressureClass {
Healthy,
Bursty,
Backpressured,
Standing,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum StandingQueueState {
Transient,
Standing,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum FlowSchedulerState {
Idle,
Active,
Backpressured,
Penalized,
SheddingCandidate,
}
#[derive(Debug, Clone)]
pub(crate) struct QueuedFrame {
pub(crate) conn_id: u64,
pub(crate) flags: u32,
pub(crate) data: Bytes,
pub(crate) enqueued_at: Instant,
}
impl QueuedFrame {
#[inline]
pub(crate) fn queued_bytes(&self) -> u64 {
self.data.len() as u64
}
}
#[derive(Debug, Clone)]
pub(crate) struct FlowFairnessState {
pub(crate) _flow_id: u64,
pub(crate) _worker_id: u16,
pub(crate) pending_bytes: u64,
pub(crate) deficit_bytes: i64,
pub(crate) queue_started_at: Option<Instant>,
pub(crate) last_drain_at: Option<Instant>,
pub(crate) recent_drain_bytes: u64,
pub(crate) consecutive_stalls: u8,
pub(crate) consecutive_skips: u8,
pub(crate) penalty_score: u16,
pub(crate) pressure_class: FlowPressureClass,
pub(crate) standing_state: StandingQueueState,
pub(crate) scheduler_state: FlowSchedulerState,
pub(crate) bucket_id: usize,
pub(crate) in_active_ring: bool,
}
impl FlowFairnessState {
pub(crate) fn new(flow_id: u64, worker_id: u16, bucket_id: usize) -> Self {
Self {
_flow_id: flow_id,
_worker_id: worker_id,
pending_bytes: 0,
deficit_bytes: 0,
queue_started_at: None,
last_drain_at: None,
recent_drain_bytes: 0,
consecutive_stalls: 0,
consecutive_skips: 0,
penalty_score: 0,
pressure_class: FlowPressureClass::Healthy,
standing_state: StandingQueueState::Transient,
scheduler_state: FlowSchedulerState::Idle,
bucket_id,
in_active_ring: false,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum AdmissionDecision {
Admit,
RejectWorkerCap,
RejectFlowCap,
RejectBucketCap,
RejectSaturated,
RejectStandingFlow,
}
#[derive(Debug, Clone)]
pub(crate) enum SchedulerDecision {
Idle,
Dispatch(DispatchCandidate),
}
#[derive(Debug, Clone)]
pub(crate) struct DispatchCandidate {
pub(crate) frame: QueuedFrame,
pub(crate) pressure_state: PressureState,
pub(crate) flow_class: FlowPressureClass,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum DispatchFeedback {
Routed,
QueueFull,
ChannelClosed,
NoConn,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum DispatchAction {
Continue,
CloseFlow,
}
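The `deficit_bytes` field above is the bookkeeping for a deficit-round-robin style scheduler. As a rough illustration of the pattern only (`Flow`, `drr_round`, and the quantum value are hypothetical, not this module's actual scheduler API), one round credits each backlogged flow a quantum and drains head frames while they fit:

```rust
use std::collections::VecDeque;

// Hypothetical per-flow state: a byte deficit and queued frame lengths.
struct Flow {
    deficit: i64,
    frames: VecDeque<u64>, // each entry is a frame's byte length
}

// One deficit-round-robin pass: add `quantum` to each backlogged flow's
// deficit, then dispatch head frames whose length fits within the deficit.
fn drr_round(flows: &mut [Flow], quantum: i64) -> Vec<u64> {
    let mut dispatched = Vec::new();
    for f in flows.iter_mut() {
        if f.frames.is_empty() {
            continue;
        }
        f.deficit += quantum;
        while let Some(&len) = f.frames.front() {
            if (len as i64) > f.deficit {
                break; // head frame too large for this round's credit
            }
            f.deficit -= len as i64;
            dispatched.push(f.frames.pop_front().expect("front checked"));
        }
        if f.frames.is_empty() {
            f.deficit = 0; // idle flows accrue no standing credit
        }
    }
    dispatched
}

fn main() {
    let mut flows = vec![
        Flow { deficit: 0, frames: VecDeque::from([1500u64, 1500]) },
        Flow { deficit: 0, frames: VecDeque::from([300u64]) },
    ];
    // With a 1500-byte quantum, each flow sends one frame this round.
    assert_eq!(drr_round(&mut flows, 1500), vec![1500, 300]);
    assert_eq!(flows[0].frames.len(), 1); // second large frame waits
}
```

The reset-on-empty step is the classic DRR rule that keeps long-idle flows from banking credit; the real scheduler layers penalties and shedding on top of this core loop.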


@@ -0,0 +1,214 @@
use std::time::{Duration, Instant};
use super::model::PressureState;
#[derive(Debug, Clone, Copy)]
pub(crate) struct PressureSignals {
pub(crate) active_flows: usize,
pub(crate) total_queued_bytes: u64,
pub(crate) standing_flows: usize,
pub(crate) backpressured_flows: usize,
}
#[derive(Debug, Clone)]
pub(crate) struct PressureConfig {
pub(crate) evaluate_every_rounds: u32,
pub(crate) transition_hysteresis_rounds: u8,
pub(crate) standing_ratio_pressured_pct: u8,
pub(crate) standing_ratio_shedding_pct: u8,
pub(crate) standing_ratio_saturated_pct: u8,
pub(crate) queue_ratio_pressured_pct: u8,
pub(crate) queue_ratio_shedding_pct: u8,
pub(crate) queue_ratio_saturated_pct: u8,
pub(crate) reject_window: Duration,
pub(crate) rejects_pressured: u32,
pub(crate) rejects_shedding: u32,
pub(crate) rejects_saturated: u32,
pub(crate) stalls_pressured: u32,
pub(crate) stalls_shedding: u32,
pub(crate) stalls_saturated: u32,
}
impl Default for PressureConfig {
fn default() -> Self {
Self {
evaluate_every_rounds: 8,
transition_hysteresis_rounds: 3,
standing_ratio_pressured_pct: 20,
standing_ratio_shedding_pct: 35,
standing_ratio_saturated_pct: 50,
queue_ratio_pressured_pct: 65,
queue_ratio_shedding_pct: 82,
queue_ratio_saturated_pct: 94,
reject_window: Duration::from_secs(2),
rejects_pressured: 32,
rejects_shedding: 96,
rejects_saturated: 256,
stalls_pressured: 32,
stalls_shedding: 96,
stalls_saturated: 256,
}
}
}
#[derive(Debug)]
pub(crate) struct PressureEvaluator {
state: PressureState,
candidate_state: PressureState,
candidate_hits: u8,
rounds_since_eval: u32,
window_started_at: Instant,
admission_rejects_window: u32,
route_stalls_window: u32,
}
impl PressureEvaluator {
pub(crate) fn new(now: Instant) -> Self {
Self {
state: PressureState::Normal,
candidate_state: PressureState::Normal,
candidate_hits: 0,
rounds_since_eval: 0,
window_started_at: now,
admission_rejects_window: 0,
route_stalls_window: 0,
}
}
#[inline]
pub(crate) fn state(&self) -> PressureState {
self.state
}
pub(crate) fn note_admission_reject(&mut self, now: Instant, cfg: &PressureConfig) {
self.rotate_window_if_needed(now, cfg);
self.admission_rejects_window = self.admission_rejects_window.saturating_add(1);
}
pub(crate) fn note_route_stall(&mut self, now: Instant, cfg: &PressureConfig) {
self.rotate_window_if_needed(now, cfg);
self.route_stalls_window = self.route_stalls_window.saturating_add(1);
}
pub(crate) fn maybe_evaluate(
&mut self,
now: Instant,
cfg: &PressureConfig,
max_total_queued_bytes: u64,
signals: PressureSignals,
force: bool,
) -> PressureState {
self.rotate_window_if_needed(now, cfg);
self.rounds_since_eval = self.rounds_since_eval.saturating_add(1);
if !force && self.rounds_since_eval < cfg.evaluate_every_rounds.max(1) {
return self.state;
}
self.rounds_since_eval = 0;
let target = self.derive_target_state(cfg, max_total_queued_bytes, signals);
if target == self.state {
self.candidate_state = target;
self.candidate_hits = 0;
return self.state;
}
if self.candidate_state == target {
self.candidate_hits = self.candidate_hits.saturating_add(1);
} else {
self.candidate_state = target;
self.candidate_hits = 1;
}
if self.candidate_hits >= cfg.transition_hysteresis_rounds.max(1) {
self.state = target;
self.candidate_hits = 0;
}
self.state
}
fn derive_target_state(
&self,
cfg: &PressureConfig,
max_total_queued_bytes: u64,
signals: PressureSignals,
) -> PressureState {
let queue_ratio_pct = if max_total_queued_bytes == 0 {
100
} else {
((signals.total_queued_bytes.saturating_mul(100)) / max_total_queued_bytes).min(100)
as u8
};
let standing_ratio_pct = if signals.active_flows == 0 {
0
} else {
((signals.standing_flows.saturating_mul(100)) / signals.active_flows).min(100) as u8
};
let mut pressure_score = 0u8;
if queue_ratio_pct >= cfg.queue_ratio_pressured_pct {
pressure_score = pressure_score.max(1);
}
if queue_ratio_pct >= cfg.queue_ratio_shedding_pct {
pressure_score = pressure_score.max(2);
}
if queue_ratio_pct >= cfg.queue_ratio_saturated_pct {
pressure_score = pressure_score.max(3);
}
if standing_ratio_pct >= cfg.standing_ratio_pressured_pct {
pressure_score = pressure_score.max(1);
}
if standing_ratio_pct >= cfg.standing_ratio_shedding_pct {
pressure_score = pressure_score.max(2);
}
if standing_ratio_pct >= cfg.standing_ratio_saturated_pct {
pressure_score = pressure_score.max(3);
}
if self.admission_rejects_window >= cfg.rejects_pressured {
pressure_score = pressure_score.max(1);
}
if self.admission_rejects_window >= cfg.rejects_shedding {
pressure_score = pressure_score.max(2);
}
if self.admission_rejects_window >= cfg.rejects_saturated {
pressure_score = pressure_score.max(3);
}
if self.route_stalls_window >= cfg.stalls_pressured {
pressure_score = pressure_score.max(1);
}
if self.route_stalls_window >= cfg.stalls_shedding {
pressure_score = pressure_score.max(2);
}
if self.route_stalls_window >= cfg.stalls_saturated {
pressure_score = pressure_score.max(3);
}
if signals.backpressured_flows > signals.active_flows.saturating_div(2)
&& signals.active_flows > 0
{
pressure_score = pressure_score.max(2);
}
match pressure_score {
0 => PressureState::Normal,
1 => PressureState::Pressured,
2 => PressureState::Shedding,
_ => PressureState::Saturated,
}
}
fn rotate_window_if_needed(&mut self, now: Instant, cfg: &PressureConfig) {
if now.saturating_duration_since(self.window_started_at) < cfg.reject_window {
return;
}
self.window_started_at = now;
self.admission_rejects_window = 0;
self.route_stalls_window = 0;
}
}
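`PressureEvaluator::maybe_evaluate` only commits a state change after the same target state has been observed for `transition_hysteresis_rounds` consecutive evaluations, which prevents the pressure state from flapping on a single noisy sample. A minimal standalone sketch of that hysteresis rule (type and method names here are illustrative, not the crate's API):

```rust
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum State { Normal, Pressured }

struct Hysteresis {
    state: State,     // committed state
    candidate: State, // pending transition target
    hits: u8,         // consecutive confirmations of the candidate
    threshold: u8,
}

impl Hysteresis {
    fn new(threshold: u8) -> Self {
        Self { state: State::Normal, candidate: State::Normal, hits: 0, threshold }
    }

    // Feed one evaluation result; returns the (possibly unchanged) state.
    fn observe(&mut self, target: State) -> State {
        if target == self.state {
            // Back at the current state: drop any pending transition.
            self.candidate = target;
            self.hits = 0;
            return self.state;
        }
        if self.candidate == target {
            self.hits = self.hits.saturating_add(1);
        } else {
            self.candidate = target;
            self.hits = 1;
        }
        if self.hits >= self.threshold.max(1) {
            self.state = target;
            self.hits = 0;
        }
        self.state
    }
}

fn main() {
    let mut h = Hysteresis::new(3);
    // Two pressured observations are not enough to flip the state...
    assert_eq!(h.observe(State::Pressured), State::Normal);
    assert_eq!(h.observe(State::Pressured), State::Normal);
    // ...but the third consecutive confirmation commits the transition.
    assert_eq!(h.observe(State::Pressured), State::Pressured);
    println!("hysteresis ok");
}
```

The same shape explains the `candidate_state`/`candidate_hits` fields above: a lone spike toward `Saturated` resets as soon as a calmer sample arrives.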


@@ -0,0 +1,556 @@
use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};
use bytes::Bytes;
use super::model::{
AdmissionDecision, DispatchAction, DispatchCandidate, DispatchFeedback, FlowFairnessState,
FlowPressureClass, FlowSchedulerState, PressureState, QueuedFrame, SchedulerDecision,
StandingQueueState,
};
use super::pressure::{PressureConfig, PressureEvaluator, PressureSignals};
#[derive(Debug, Clone)]
pub(crate) struct WorkerFairnessConfig {
pub(crate) worker_id: u16,
pub(crate) max_active_flows: usize,
pub(crate) max_total_queued_bytes: u64,
pub(crate) max_flow_queued_bytes: u64,
pub(crate) base_quantum_bytes: u32,
pub(crate) pressured_quantum_bytes: u32,
pub(crate) penalized_quantum_bytes: u32,
pub(crate) standing_queue_min_age: Duration,
pub(crate) standing_queue_min_backlog_bytes: u64,
pub(crate) standing_stall_threshold: u8,
pub(crate) max_consecutive_stalls_before_shed: u8,
pub(crate) max_consecutive_stalls_before_close: u8,
pub(crate) soft_bucket_count: usize,
pub(crate) soft_bucket_share_pct: u8,
pub(crate) pressure: PressureConfig,
}
impl Default for WorkerFairnessConfig {
fn default() -> Self {
Self {
worker_id: 0,
max_active_flows: 4096,
max_total_queued_bytes: 16 * 1024 * 1024,
max_flow_queued_bytes: 512 * 1024,
base_quantum_bytes: 32 * 1024,
pressured_quantum_bytes: 16 * 1024,
penalized_quantum_bytes: 8 * 1024,
standing_queue_min_age: Duration::from_millis(250),
standing_queue_min_backlog_bytes: 64 * 1024,
standing_stall_threshold: 3,
max_consecutive_stalls_before_shed: 4,
max_consecutive_stalls_before_close: 16,
soft_bucket_count: 64,
soft_bucket_share_pct: 25,
pressure: PressureConfig::default(),
}
}
}
struct FlowEntry {
fairness: FlowFairnessState,
queue: VecDeque<QueuedFrame>,
}
impl FlowEntry {
fn new(flow_id: u64, worker_id: u16, bucket_id: usize) -> Self {
Self {
fairness: FlowFairnessState::new(flow_id, worker_id, bucket_id),
queue: VecDeque::new(),
}
}
}
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct WorkerFairnessSnapshot {
pub(crate) pressure_state: PressureState,
pub(crate) active_flows: usize,
pub(crate) total_queued_bytes: u64,
pub(crate) standing_flows: usize,
pub(crate) backpressured_flows: usize,
pub(crate) scheduler_rounds: u64,
pub(crate) deficit_grants: u64,
pub(crate) deficit_skips: u64,
pub(crate) enqueue_rejects: u64,
pub(crate) shed_drops: u64,
pub(crate) fairness_penalties: u64,
pub(crate) downstream_stalls: u64,
}
pub(crate) struct WorkerFairnessState {
config: WorkerFairnessConfig,
pressure: PressureEvaluator,
flows: HashMap<u64, FlowEntry>,
active_ring: VecDeque<u64>,
total_queued_bytes: u64,
bucket_queued_bytes: Vec<u64>,
bucket_active_flows: Vec<usize>,
standing_flow_count: usize,
backpressured_flow_count: usize,
scheduler_rounds: u64,
deficit_grants: u64,
deficit_skips: u64,
enqueue_rejects: u64,
shed_drops: u64,
fairness_penalties: u64,
downstream_stalls: u64,
}
impl WorkerFairnessState {
pub(crate) fn new(config: WorkerFairnessConfig, now: Instant) -> Self {
let bucket_count = config.soft_bucket_count.max(1);
Self {
config,
pressure: PressureEvaluator::new(now),
flows: HashMap::new(),
active_ring: VecDeque::new(),
total_queued_bytes: 0,
bucket_queued_bytes: vec![0; bucket_count],
bucket_active_flows: vec![0; bucket_count],
standing_flow_count: 0,
backpressured_flow_count: 0,
scheduler_rounds: 0,
deficit_grants: 0,
deficit_skips: 0,
enqueue_rejects: 0,
shed_drops: 0,
fairness_penalties: 0,
downstream_stalls: 0,
}
}
pub(crate) fn pressure_state(&self) -> PressureState {
self.pressure.state()
}
pub(crate) fn snapshot(&self) -> WorkerFairnessSnapshot {
WorkerFairnessSnapshot {
pressure_state: self.pressure.state(),
active_flows: self.flows.len(),
total_queued_bytes: self.total_queued_bytes,
standing_flows: self.standing_flow_count,
backpressured_flows: self.backpressured_flow_count,
scheduler_rounds: self.scheduler_rounds,
deficit_grants: self.deficit_grants,
deficit_skips: self.deficit_skips,
enqueue_rejects: self.enqueue_rejects,
shed_drops: self.shed_drops,
fairness_penalties: self.fairness_penalties,
downstream_stalls: self.downstream_stalls,
}
}
pub(crate) fn enqueue_data(
&mut self,
conn_id: u64,
flags: u32,
data: Bytes,
now: Instant,
) -> AdmissionDecision {
let frame = QueuedFrame {
conn_id,
flags,
data,
enqueued_at: now,
};
let frame_bytes = frame.queued_bytes();
if self.pressure.state() == PressureState::Saturated {
self.pressure
.note_admission_reject(now, &self.config.pressure);
self.enqueue_rejects = self.enqueue_rejects.saturating_add(1);
return AdmissionDecision::RejectSaturated;
}
if self.total_queued_bytes.saturating_add(frame_bytes) > self.config.max_total_queued_bytes
{
self.pressure
.note_admission_reject(now, &self.config.pressure);
self.enqueue_rejects = self.enqueue_rejects.saturating_add(1);
self.evaluate_pressure(now, true);
return AdmissionDecision::RejectWorkerCap;
}
if !self.flows.contains_key(&conn_id) && self.flows.len() >= self.config.max_active_flows {
self.pressure
.note_admission_reject(now, &self.config.pressure);
self.enqueue_rejects = self.enqueue_rejects.saturating_add(1);
self.evaluate_pressure(now, true);
return AdmissionDecision::RejectWorkerCap;
}
let bucket_id = self.bucket_for(conn_id);
let bucket_cap = self
.config
.max_total_queued_bytes
.saturating_mul(self.config.soft_bucket_share_pct.max(1) as u64)
.saturating_div(100)
.max(self.config.max_flow_queued_bytes);
if self.bucket_queued_bytes[bucket_id].saturating_add(frame_bytes) > bucket_cap {
self.pressure
.note_admission_reject(now, &self.config.pressure);
self.enqueue_rejects = self.enqueue_rejects.saturating_add(1);
self.evaluate_pressure(now, true);
return AdmissionDecision::RejectBucketCap;
}
let entry = if let Some(flow) = self.flows.get_mut(&conn_id) {
flow
} else {
self.bucket_active_flows[bucket_id] =
self.bucket_active_flows[bucket_id].saturating_add(1);
self.flows.insert(
conn_id,
FlowEntry::new(conn_id, self.config.worker_id, bucket_id),
);
self.flows
.get_mut(&conn_id)
.expect("flow inserted must be retrievable")
};
if entry.fairness.pending_bytes.saturating_add(frame_bytes)
> self.config.max_flow_queued_bytes
{
self.pressure
.note_admission_reject(now, &self.config.pressure);
self.enqueue_rejects = self.enqueue_rejects.saturating_add(1);
self.evaluate_pressure(now, true);
return AdmissionDecision::RejectFlowCap;
}
if self.pressure.state() >= PressureState::Shedding
&& entry.fairness.standing_state == StandingQueueState::Standing
{
self.pressure
.note_admission_reject(now, &self.config.pressure);
self.enqueue_rejects = self.enqueue_rejects.saturating_add(1);
self.evaluate_pressure(now, true);
return AdmissionDecision::RejectStandingFlow;
}
entry.fairness.pending_bytes = entry.fairness.pending_bytes.saturating_add(frame_bytes);
if entry.fairness.queue_started_at.is_none() {
entry.fairness.queue_started_at = Some(now);
}
entry.queue.push_back(frame);
self.total_queued_bytes = self.total_queued_bytes.saturating_add(frame_bytes);
self.bucket_queued_bytes[bucket_id] =
self.bucket_queued_bytes[bucket_id].saturating_add(frame_bytes);
if !entry.fairness.in_active_ring {
entry.fairness.in_active_ring = true;
self.active_ring.push_back(conn_id);
}
self.evaluate_pressure(now, true);
AdmissionDecision::Admit
}
pub(crate) fn next_decision(&mut self, now: Instant) -> SchedulerDecision {
self.scheduler_rounds = self.scheduler_rounds.saturating_add(1);
self.evaluate_pressure(now, false);
let active_len = self.active_ring.len();
for _ in 0..active_len {
let Some(conn_id) = self.active_ring.pop_front() else {
break;
};
let mut candidate = None;
let mut requeue_active = false;
let mut drained_bytes = 0u64;
let mut bucket_id = 0usize;
let pressure_state = self.pressure.state();
if let Some(flow) = self.flows.get_mut(&conn_id) {
bucket_id = flow.fairness.bucket_id;
if flow.queue.is_empty() {
flow.fairness.in_active_ring = false;
flow.fairness.scheduler_state = FlowSchedulerState::Idle;
flow.fairness.pending_bytes = 0;
flow.fairness.queue_started_at = None;
continue;
}
Self::classify_flow(&self.config, pressure_state, now, &mut flow.fairness);
let quantum =
Self::effective_quantum_bytes(&self.config, pressure_state, &flow.fairness);
flow.fairness.deficit_bytes = flow
.fairness
.deficit_bytes
.saturating_add(i64::from(quantum));
self.deficit_grants = self.deficit_grants.saturating_add(1);
let front_len = flow.queue.front().map_or(0, |front| front.queued_bytes());
if flow.fairness.deficit_bytes < front_len as i64 {
flow.fairness.consecutive_skips =
flow.fairness.consecutive_skips.saturating_add(1);
self.deficit_skips = self.deficit_skips.saturating_add(1);
requeue_active = true;
} else if let Some(frame) = flow.queue.pop_front() {
drained_bytes = frame.queued_bytes();
flow.fairness.pending_bytes =
flow.fairness.pending_bytes.saturating_sub(drained_bytes);
flow.fairness.deficit_bytes = flow
.fairness
.deficit_bytes
.saturating_sub(drained_bytes as i64);
flow.fairness.consecutive_skips = 0;
flow.fairness.queue_started_at =
flow.queue.front().map(|front| front.enqueued_at);
requeue_active = !flow.queue.is_empty();
if !requeue_active {
flow.fairness.scheduler_state = FlowSchedulerState::Idle;
flow.fairness.in_active_ring = false;
}
candidate = Some(DispatchCandidate {
pressure_state,
flow_class: flow.fairness.pressure_class,
frame,
});
}
}
if drained_bytes > 0 {
self.total_queued_bytes = self.total_queued_bytes.saturating_sub(drained_bytes);
self.bucket_queued_bytes[bucket_id] =
self.bucket_queued_bytes[bucket_id].saturating_sub(drained_bytes);
}
if requeue_active {
if let Some(flow) = self.flows.get_mut(&conn_id) {
flow.fairness.in_active_ring = true;
}
self.active_ring.push_back(conn_id);
}
if let Some(candidate) = candidate {
return SchedulerDecision::Dispatch(candidate);
}
}
SchedulerDecision::Idle
}
pub(crate) fn apply_dispatch_feedback(
&mut self,
conn_id: u64,
candidate: DispatchCandidate,
feedback: DispatchFeedback,
now: Instant,
) -> DispatchAction {
match feedback {
DispatchFeedback::Routed => {
if let Some(flow) = self.flows.get_mut(&conn_id) {
flow.fairness.last_drain_at = Some(now);
flow.fairness.recent_drain_bytes = flow
.fairness
.recent_drain_bytes
.saturating_add(candidate.frame.queued_bytes());
flow.fairness.consecutive_stalls = 0;
if flow.fairness.scheduler_state != FlowSchedulerState::Idle {
flow.fairness.scheduler_state = FlowSchedulerState::Active;
}
}
self.evaluate_pressure(now, false);
DispatchAction::Continue
}
DispatchFeedback::QueueFull => {
self.pressure.note_route_stall(now, &self.config.pressure);
self.downstream_stalls = self.downstream_stalls.saturating_add(1);
let Some(flow) = self.flows.get_mut(&conn_id) else {
self.evaluate_pressure(now, true);
return DispatchAction::Continue;
};
flow.fairness.consecutive_stalls =
flow.fairness.consecutive_stalls.saturating_add(1);
flow.fairness.scheduler_state = FlowSchedulerState::Backpressured;
flow.fairness.pressure_class = FlowPressureClass::Backpressured;
let state = self.pressure.state();
let should_shed_frame = matches!(state, PressureState::Saturated)
|| (matches!(state, PressureState::Shedding)
&& flow.fairness.standing_state == StandingQueueState::Standing
&& flow.fairness.consecutive_stalls
>= self.config.max_consecutive_stalls_before_shed);
if should_shed_frame {
self.shed_drops = self.shed_drops.saturating_add(1);
self.fairness_penalties = self.fairness_penalties.saturating_add(1);
} else {
let frame_bytes = candidate.frame.queued_bytes();
flow.queue.push_front(candidate.frame);
flow.fairness.pending_bytes =
flow.fairness.pending_bytes.saturating_add(frame_bytes);
flow.fairness.queue_started_at =
flow.queue.front().map(|front| front.enqueued_at);
self.total_queued_bytes = self.total_queued_bytes.saturating_add(frame_bytes);
self.bucket_queued_bytes[flow.fairness.bucket_id] = self.bucket_queued_bytes
[flow.fairness.bucket_id]
.saturating_add(frame_bytes);
if !flow.fairness.in_active_ring {
flow.fairness.in_active_ring = true;
self.active_ring.push_back(conn_id);
}
}
if flow.fairness.consecutive_stalls
>= self.config.max_consecutive_stalls_before_close
&& self.pressure.state() == PressureState::Saturated
{
self.remove_flow(conn_id);
self.evaluate_pressure(now, true);
return DispatchAction::CloseFlow;
}
self.evaluate_pressure(now, true);
DispatchAction::Continue
}
DispatchFeedback::ChannelClosed | DispatchFeedback::NoConn => {
self.remove_flow(conn_id);
self.evaluate_pressure(now, true);
DispatchAction::CloseFlow
}
}
}
pub(crate) fn remove_flow(&mut self, conn_id: u64) {
let Some(entry) = self.flows.remove(&conn_id) else {
return;
};
self.bucket_active_flows[entry.fairness.bucket_id] =
self.bucket_active_flows[entry.fairness.bucket_id].saturating_sub(1);
let mut reclaimed = 0u64;
for frame in entry.queue {
reclaimed = reclaimed.saturating_add(frame.queued_bytes());
}
self.total_queued_bytes = self.total_queued_bytes.saturating_sub(reclaimed);
self.bucket_queued_bytes[entry.fairness.bucket_id] =
self.bucket_queued_bytes[entry.fairness.bucket_id].saturating_sub(reclaimed);
}
fn evaluate_pressure(&mut self, now: Instant, force: bool) {
let mut standing = 0usize;
let mut backpressured = 0usize;
for flow in self.flows.values_mut() {
Self::classify_flow(&self.config, self.pressure.state(), now, &mut flow.fairness);
if flow.fairness.standing_state == StandingQueueState::Standing {
standing = standing.saturating_add(1);
}
if matches!(
flow.fairness.scheduler_state,
FlowSchedulerState::Backpressured
| FlowSchedulerState::Penalized
| FlowSchedulerState::SheddingCandidate
) {
backpressured = backpressured.saturating_add(1);
}
}
self.standing_flow_count = standing;
self.backpressured_flow_count = backpressured;
let _ = self.pressure.maybe_evaluate(
now,
&self.config.pressure,
self.config.max_total_queued_bytes,
PressureSignals {
active_flows: self.flows.len(),
total_queued_bytes: self.total_queued_bytes,
standing_flows: standing,
backpressured_flows: backpressured,
},
force,
);
}
fn classify_flow(
config: &WorkerFairnessConfig,
pressure_state: PressureState,
now: Instant,
fairness: &mut FlowFairnessState,
) {
if fairness.pending_bytes == 0 {
fairness.pressure_class = FlowPressureClass::Healthy;
fairness.standing_state = StandingQueueState::Transient;
fairness.scheduler_state = FlowSchedulerState::Idle;
fairness.penalty_score = fairness.penalty_score.saturating_sub(1);
return;
}
let queue_age = fairness
.queue_started_at
.map(|ts| now.saturating_duration_since(ts))
.unwrap_or_default();
let drain_stalled = fairness
.last_drain_at
.map(|ts| now.saturating_duration_since(ts) >= config.standing_queue_min_age)
.unwrap_or(true);
let standing = fairness.pending_bytes >= config.standing_queue_min_backlog_bytes
&& queue_age >= config.standing_queue_min_age
&& (fairness.consecutive_stalls >= config.standing_stall_threshold || drain_stalled);
if standing {
fairness.standing_state = StandingQueueState::Standing;
fairness.pressure_class = FlowPressureClass::Standing;
fairness.penalty_score = fairness.penalty_score.saturating_add(1);
fairness.scheduler_state = if pressure_state >= PressureState::Shedding {
FlowSchedulerState::SheddingCandidate
} else {
FlowSchedulerState::Penalized
};
return;
}
fairness.standing_state = StandingQueueState::Transient;
if fairness.consecutive_stalls > 0 {
fairness.pressure_class = FlowPressureClass::Backpressured;
fairness.scheduler_state = FlowSchedulerState::Backpressured;
} else if fairness.pending_bytes >= config.standing_queue_min_backlog_bytes {
fairness.pressure_class = FlowPressureClass::Bursty;
fairness.scheduler_state = FlowSchedulerState::Active;
} else {
fairness.pressure_class = FlowPressureClass::Healthy;
fairness.scheduler_state = FlowSchedulerState::Active;
}
fairness.penalty_score = fairness.penalty_score.saturating_sub(1);
}
fn effective_quantum_bytes(
config: &WorkerFairnessConfig,
pressure_state: PressureState,
fairness: &FlowFairnessState,
) -> u32 {
let penalized = matches!(
fairness.scheduler_state,
FlowSchedulerState::Penalized | FlowSchedulerState::SheddingCandidate
);
if penalized {
return config.penalized_quantum_bytes.max(1);
}
match pressure_state {
PressureState::Normal => config.base_quantum_bytes.max(1),
PressureState::Pressured => config.pressured_quantum_bytes.max(1),
PressureState::Shedding => config.pressured_quantum_bytes.max(1),
PressureState::Saturated => config.penalized_quantum_bytes.max(1),
}
}
fn bucket_for(&self, conn_id: u64) -> usize {
(conn_id as usize) % self.bucket_queued_bytes.len().max(1)
}
}
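`WorkerFairnessState::next_decision` is a deficit round-robin: each visit grants a flow a byte quantum, and the flow may only dispatch its head-of-line frame once the accumulated deficit covers the frame size, so one flow with large frames cannot starve neighbors with small ones. A standalone sketch of that core loop under simplified assumptions (names and sizes are illustrative):

```rust
use std::collections::VecDeque;

struct Flow {
    id: u64,
    deficit: i64,
    frames: VecDeque<usize>, // queued frame sizes in bytes
}

// One scheduler pass: visit each flow once, dispatch at most one frame per
// visit, and return the owners of dispatched frames in dispatch order.
fn drr_pass(flows: &mut [Flow], quantum: i64) -> Vec<u64> {
    let mut dispatched = Vec::new();
    for flow in flows.iter_mut() {
        let Some(&front) = flow.frames.front() else { continue };
        flow.deficit += quantum; // grant this round's quantum
        if flow.deficit >= front as i64 {
            flow.frames.pop_front();
            flow.deficit -= front as i64;
            dispatched.push(flow.id);
        }
        // Otherwise the flow banks its deficit and waits for a later round.
    }
    dispatched
}

fn main() {
    let mut flows = vec![
        Flow { id: 1, deficit: 0, frames: VecDeque::from([512, 512]) },
        Flow { id: 2, deficit: 0, frames: VecDeque::from([2048]) }, // large frame
    ];
    // With a 1 KiB quantum, the small flow progresses every round while the
    // large frame waits until flow 2 has banked enough deficit.
    assert_eq!(drr_pass(&mut flows, 1024), vec![1]);    // flow 2: 1024 < 2048
    assert_eq!(drr_pass(&mut flows, 1024), vec![1, 2]); // flow 2: 2048 >= 2048
    println!("drr ok");
}
```

The real scheduler layers three extras on top of this loop: the quantum shrinks under pressure or for penalized flows (`effective_quantum_bytes`), skipped flows are re-queued on the active ring, and drained bytes are reconciled against the worker and bucket budgets.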


@@ -2,6 +2,10 @@
mod codec;
mod config_updater;
mod fairness;
#[cfg(test)]
#[path = "tests/fairness_security_tests.rs"]
mod fairness_security_tests;
mod handshake;
mod health;
#[cfg(test)]


@@ -20,11 +20,15 @@ use crate::protocol::constants::*;
use crate::stats::Stats;
use super::codec::{RpcChecksumMode, WriterCommand, rpc_crc};
use super::fairness::{
AdmissionDecision, DispatchAction, DispatchFeedback, SchedulerDecision, WorkerFairnessConfig,
WorkerFairnessSnapshot, WorkerFairnessState,
};
use super::registry::RouteResult;
use super::{ConnRegistry, MeResponse};
const DATA_ROUTE_MAX_ATTEMPTS: usize = 3;
const DATA_ROUTE_QUEUE_FULL_STARVATION_THRESHOLD: u8 = 3;
const FAIRNESS_DRAIN_BUDGET_PER_LOOP: usize = 128;
fn should_close_on_route_result_for_data(result: RouteResult) -> bool {
matches!(result, RouteResult::NoConn | RouteResult::ChannelClosed)
@@ -77,6 +81,118 @@ async fn route_data_with_retry(
}
}
#[inline]
fn route_feedback(result: RouteResult) -> DispatchFeedback {
match result {
RouteResult::Routed => DispatchFeedback::Routed,
RouteResult::NoConn => DispatchFeedback::NoConn,
RouteResult::ChannelClosed => DispatchFeedback::ChannelClosed,
RouteResult::QueueFullBase | RouteResult::QueueFullHigh => DispatchFeedback::QueueFull,
}
}
fn report_route_drop(result: RouteResult, stats: &Stats) {
match result {
RouteResult::NoConn => stats.increment_me_route_drop_no_conn(),
RouteResult::ChannelClosed => stats.increment_me_route_drop_channel_closed(),
RouteResult::QueueFullBase => {
stats.increment_me_route_drop_queue_full();
stats.increment_me_route_drop_queue_full_base();
}
RouteResult::QueueFullHigh => {
stats.increment_me_route_drop_queue_full();
stats.increment_me_route_drop_queue_full_high();
}
RouteResult::Routed => {}
}
}
fn apply_fairness_metrics_delta(
stats: &Stats,
prev: &mut WorkerFairnessSnapshot,
current: WorkerFairnessSnapshot,
) {
stats.set_me_fair_active_flows_gauge(current.active_flows as u64);
stats.set_me_fair_queued_bytes_gauge(current.total_queued_bytes);
stats.set_me_fair_standing_flows_gauge(current.standing_flows as u64);
stats.set_me_fair_backpressured_flows_gauge(current.backpressured_flows as u64);
stats.set_me_fair_pressure_state_gauge(current.pressure_state.as_u8() as u64);
stats.add_me_fair_scheduler_rounds_total(
current
.scheduler_rounds
.saturating_sub(prev.scheduler_rounds),
);
stats.add_me_fair_deficit_grants_total(
current.deficit_grants.saturating_sub(prev.deficit_grants),
);
stats.add_me_fair_deficit_skips_total(current.deficit_skips.saturating_sub(prev.deficit_skips));
stats.add_me_fair_enqueue_rejects_total(
current.enqueue_rejects.saturating_sub(prev.enqueue_rejects),
);
stats.add_me_fair_shed_drops_total(current.shed_drops.saturating_sub(prev.shed_drops));
stats.add_me_fair_penalties_total(
current
.fairness_penalties
.saturating_sub(prev.fairness_penalties),
);
stats.add_me_fair_downstream_stalls_total(
current
.downstream_stalls
.saturating_sub(prev.downstream_stalls),
);
*prev = current;
}
async fn drain_fairness_scheduler(
fairness: &mut WorkerFairnessState,
reg: &ConnRegistry,
tx: &mpsc::Sender<WriterCommand>,
data_route_queue_full_streak: &mut HashMap<u64, u8>,
route_wait_ms: u64,
stats: &Stats,
) {
for _ in 0..FAIRNESS_DRAIN_BUDGET_PER_LOOP {
let now = Instant::now();
let SchedulerDecision::Dispatch(candidate) = fairness.next_decision(now) else {
break;
};
let cid = candidate.frame.conn_id;
let _pressure_state = candidate.pressure_state;
let _flow_class = candidate.flow_class;
let routed = route_data_with_retry(
reg,
cid,
candidate.frame.flags,
candidate.frame.data.clone(),
route_wait_ms,
)
.await;
if matches!(routed, RouteResult::Routed) {
data_route_queue_full_streak.remove(&cid);
} else {
report_route_drop(routed, stats);
}
let action = fairness.apply_dispatch_feedback(cid, candidate, route_feedback(routed), now);
if is_data_route_queue_full(routed) {
let streak = data_route_queue_full_streak.entry(cid).or_insert(0);
*streak = streak.saturating_add(1);
if should_close_on_queue_full_streak(*streak) {
fairness.remove_flow(cid);
data_route_queue_full_streak.remove(&cid);
reg.unregister(cid).await;
send_close_conn(tx, cid).await;
continue;
}
}
if action == DispatchAction::CloseFlow || should_close_on_route_result_for_data(routed) {
fairness.remove_flow(cid);
data_route_queue_full_streak.remove(&cid);
reg.unregister(cid).await;
send_close_conn(tx, cid).await;
}
}
}
pub(crate) async fn reader_loop(
mut rd: tokio::io::ReadHalf<TcpStream>,
dk: [u8; 32],
@@ -98,7 +214,21 @@ pub(crate) async fn reader_loop(
let mut raw = enc_leftover;
let mut expected_seq: i32 = 0;
let mut data_route_queue_full_streak = HashMap::<u64, u8>::new();
let mut fairness = WorkerFairnessState::new(
WorkerFairnessConfig {
worker_id: (writer_id as u16).saturating_add(1),
max_active_flows: reg.route_channel_capacity().saturating_mul(4).max(256),
max_total_queued_bytes: (reg.route_channel_capacity() as u64)
.saturating_mul(16 * 1024)
.max(4 * 1024 * 1024),
max_flow_queued_bytes: (reg.route_channel_capacity() as u64)
.saturating_mul(2 * 1024)
.clamp(64 * 1024, 2 * 1024 * 1024),
..WorkerFairnessConfig::default()
},
Instant::now(),
);
let mut fairness_snapshot = fairness.snapshot();
loop {
let mut tmp = [0u8; 65_536];
let n = tokio::select! {
@@ -181,36 +311,20 @@ pub(crate) async fn reader_loop(
let data = body.slice(12..);
trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS");
let admission = fairness.enqueue_data(cid, flags, data, Instant::now());
if !matches!(admission, AdmissionDecision::Admit) {
stats.increment_me_route_drop_queue_full();
stats.increment_me_route_drop_queue_full_high();
let streak = data_route_queue_full_streak.entry(cid).or_insert(0);
*streak = streak.saturating_add(1);
if should_close_on_queue_full_streak(*streak)
|| matches!(
admission,
AdmissionDecision::RejectSaturated
| AdmissionDecision::RejectStandingFlow
)
{
fairness.remove_flow(cid);
data_route_queue_full_streak.remove(&cid);
reg.unregister(cid).await;
send_close_conn(&tx, cid).await;
}
}
@@ -249,12 +363,14 @@ pub(crate) async fn reader_loop(
let _ = reg.route_nowait(cid, MeResponse::Close).await;
reg.unregister(cid).await;
data_route_queue_full_streak.remove(&cid);
fairness.remove_flow(cid);
} else if pt == RPC_CLOSE_CONN_U32 && body.len() >= 8 {
let cid = u64::from_le_bytes(body[0..8].try_into().unwrap());
debug!(cid, "RPC_CLOSE_CONN from ME");
let _ = reg.route_nowait(cid, MeResponse::Close).await;
reg.unregister(cid).await;
data_route_queue_full_streak.remove(&cid);
fairness.remove_flow(cid);
} else if pt == RPC_PING_U32 && body.len() >= 8 {
let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap());
trace!(ping_id, "RPC_PING -> RPC_PONG");
@@ -310,6 +426,19 @@ pub(crate) async fn reader_loop(
"Unknown RPC"
);
}
let route_wait_ms = reader_route_data_wait_ms.load(Ordering::Relaxed);
drain_fairness_scheduler(
&mut fairness,
reg.as_ref(),
&tx,
&mut data_route_queue_full_streak,
route_wait_ms,
stats.as_ref(),
)
.await;
let current_snapshot = fairness.snapshot();
apply_fairness_metrics_delta(stats.as_ref(), &mut fairness_snapshot, current_snapshot);
}
}
}
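`apply_fairness_metrics_delta` exports the scheduler's monotonically increasing snapshot counters as deltas, adding only the growth since the previous snapshot so the external counters never double-count across loop iterations. A standalone sketch of the pattern with stand-in types (names are illustrative, not the crate's `Stats` API):

```rust
#[derive(Clone, Copy, Default)]
struct Snapshot {
    enqueue_rejects: u64,
    shed_drops: u64,
}

#[derive(Default)]
struct ExportedCounters {
    enqueue_rejects_total: u64,
    shed_drops_total: u64,
}

fn apply_delta(out: &mut ExportedCounters, prev: &mut Snapshot, current: Snapshot) {
    // saturating_sub guards against a reset or reordered snapshot
    // producing a huge bogus delta.
    out.enqueue_rejects_total += current.enqueue_rejects.saturating_sub(prev.enqueue_rejects);
    out.shed_drops_total += current.shed_drops.saturating_sub(prev.shed_drops);
    *prev = current; // remember this snapshot for the next export
}

fn main() {
    let mut out = ExportedCounters::default();
    let mut prev = Snapshot::default();
    apply_delta(&mut out, &mut prev, Snapshot { enqueue_rejects: 5, shed_drops: 1 });
    apply_delta(&mut out, &mut prev, Snapshot { enqueue_rejects: 8, shed_drops: 1 });
    // Totals reflect cumulative growth, not re-added snapshot values.
    assert_eq!(out.enqueue_rejects_total, 8);
    assert_eq!(out.shed_drops_total, 1);
    println!("delta export ok");
}
```

Gauges (active flows, queued bytes, pressure state) are set directly from the current snapshot instead, since they represent instantaneous values rather than accumulations.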


@@ -140,6 +140,10 @@ impl ConnRegistry {
}
}
pub fn route_channel_capacity(&self) -> usize {
self.route_channel_capacity
}
#[cfg(test)]
pub fn new() -> Self {
Self::with_route_channel_capacity(4096)


@@ -0,0 +1,185 @@
use std::time::{Duration, Instant};
use bytes::Bytes;
use crate::transport::middle_proxy::fairness::{
AdmissionDecision, DispatchAction, DispatchFeedback, PressureState, SchedulerDecision,
WorkerFairnessConfig, WorkerFairnessState,
};
fn enqueue_payload(size: usize) -> Bytes {
Bytes::from(vec![0xAB; size])
}
#[test]
fn fairness_rejects_when_worker_budget_is_exhausted() {
let now = Instant::now();
let mut fairness = WorkerFairnessState::new(
WorkerFairnessConfig {
max_total_queued_bytes: 1024,
max_flow_queued_bytes: 1024,
..WorkerFairnessConfig::default()
},
now,
);
assert_eq!(
fairness.enqueue_data(1, 0, enqueue_payload(700), now),
AdmissionDecision::Admit
);
assert_eq!(
fairness.enqueue_data(2, 0, enqueue_payload(400), now),
AdmissionDecision::RejectWorkerCap
);
let snapshot = fairness.snapshot();
assert!(snapshot.total_queued_bytes <= 1024);
assert_eq!(snapshot.enqueue_rejects, 1);
}
#[test]
fn fairness_marks_standing_queue_after_stall_and_age_threshold() {
let mut now = Instant::now();
let mut fairness = WorkerFairnessState::new(
WorkerFairnessConfig {
standing_queue_min_age: Duration::from_millis(50),
standing_queue_min_backlog_bytes: 256,
standing_stall_threshold: 1,
max_flow_queued_bytes: 4096,
max_total_queued_bytes: 4096,
..WorkerFairnessConfig::default()
},
now,
);
assert_eq!(
fairness.enqueue_data(11, 0, enqueue_payload(512), now),
AdmissionDecision::Admit
);
now += Duration::from_millis(100);
let SchedulerDecision::Dispatch(candidate) = fairness.next_decision(now) else {
panic!("expected dispatch candidate");
};
let action = fairness.apply_dispatch_feedback(11, candidate, DispatchFeedback::QueueFull, now);
assert!(matches!(action, DispatchAction::Continue));
let snapshot = fairness.snapshot();
assert_eq!(snapshot.standing_flows, 1);
assert!(snapshot.backpressured_flows >= 1);
}
#[test]
fn fairness_keeps_fast_flow_progress_under_slow_neighbor() {
let mut now = Instant::now();
let mut fairness = WorkerFairnessState::new(
WorkerFairnessConfig {
max_total_queued_bytes: 64 * 1024,
max_flow_queued_bytes: 32 * 1024,
..WorkerFairnessConfig::default()
},
now,
);
for _ in 0..16 {
assert_eq!(
fairness.enqueue_data(1, 0, enqueue_payload(512), now),
AdmissionDecision::Admit
);
assert_eq!(
fairness.enqueue_data(2, 0, enqueue_payload(512), now),
AdmissionDecision::Admit
);
}
let mut fast_routed = 0u64;
for _ in 0..128 {
now += Duration::from_millis(5);
let SchedulerDecision::Dispatch(candidate) = fairness.next_decision(now) else {
break;
};
let cid = candidate.frame.conn_id;
// Flow 2 simulates a stalled downstream; flow 1 is always routed.
let feedback = if cid == 2 {
DispatchFeedback::QueueFull
} else {
fast_routed = fast_routed.saturating_add(1);
DispatchFeedback::Routed
};
let _ = fairness.apply_dispatch_feedback(cid, candidate, feedback, now);
}
let snapshot = fairness.snapshot();
assert!(fast_routed > 0, "fast flow must continue making progress");
assert!(snapshot.total_queued_bytes <= 64 * 1024);
}

// Pressure-state transitions require several confirming evaluation rounds,
// so a short queue spike must not flip the state before the hysteresis
// threshold (3 rounds here) is satisfied.
#[test]
fn fairness_pressure_hysteresis_prevents_instant_flapping() {
let mut now = Instant::now();
let mut cfg = WorkerFairnessConfig::default();
cfg.max_total_queued_bytes = 4096;
cfg.max_flow_queued_bytes = 4096;
cfg.pressure.evaluate_every_rounds = 1;
cfg.pressure.transition_hysteresis_rounds = 3;
cfg.pressure.queue_ratio_pressured_pct = 40;
cfg.pressure.queue_ratio_shedding_pct = 60;
cfg.pressure.queue_ratio_saturated_pct = 80;
let mut fairness = WorkerFairnessState::new(cfg, now);
for _ in 0..4 {
assert_eq!(
fairness.enqueue_data(9, 0, enqueue_payload(900), now),
AdmissionDecision::Admit
);
}
for _ in 0..2 {
now += Duration::from_millis(1);
let _ = fairness.next_decision(now);
}
assert_eq!(
fairness.pressure_state(),
PressureState::Normal,
"state must not flip before hysteresis confirmations"
);
}

// Deterministic fuzz: random flows, payload sizes, and dispatch feedback
// must never drive total queued bytes past the configured worker budget.
#[test]
fn fairness_randomized_sequence_preserves_memory_bounds() {
let mut now = Instant::now();
let mut fairness = WorkerFairnessState::new(
WorkerFairnessConfig {
max_total_queued_bytes: 32 * 1024,
max_flow_queued_bytes: 4 * 1024,
..WorkerFairnessConfig::default()
},
now,
);
// Xorshift-style PRNG with a fixed seed, so failures are reproducible.
let mut seed = 0xC0FFEE_u64;
for _ in 0..4096 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let flow = (seed % 32) + 1;
let size = ((seed >> 8) % 512 + 64) as usize;
let _ = fairness.enqueue_data(flow, 0, enqueue_payload(size), now);
now += Duration::from_millis(1);
if let SchedulerDecision::Dispatch(candidate) = fairness.next_decision(now) {
let feedback = if seed & 0x1 == 0 {
DispatchFeedback::Routed
} else {
DispatchFeedback::QueueFull
};
let _ =
fairness.apply_dispatch_feedback(candidate.frame.conn_id, candidate, feedback, now);
}
let snapshot = fairness.snapshot();
assert!(snapshot.total_queued_bytes <= 32 * 1024);
}
}