Merge pull request #745 from telemt/flow

API PATCH fixes + skip IP tracking when unique-IP limits are disabled + bound hot-path pressure in ME relay and handshake + bound ME route fairness and the IP-cleanup backlog + bound relay queues by bytes
This commit is contained in:
Alexey
2026-04-25 14:45:34 +03:00
committed by GitHub
19 changed files with 754 additions and 99 deletions

2
Cargo.lock generated
View File

@@ -2791,7 +2791,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417"
[[package]]
name = "telemt"
version = "3.4.6"
version = "3.4.7"
dependencies = [
"aes",
"anyhow",

View File

@@ -1,6 +1,6 @@
[package]
name = "telemt"
version = "3.4.6"
version = "3.4.7"
edition = "2024"
[features]

View File

@@ -28,6 +28,7 @@ mod config_store;
mod events;
mod http_utils;
mod model;
mod patch;
mod runtime_edge;
mod runtime_init;
mod runtime_min;

View File

@@ -5,6 +5,7 @@ use chrono::{DateTime, Utc};
use hyper::StatusCode;
use serde::{Deserialize, Serialize};
use super::patch::{Patch, patch_field};
use crate::crypto::SecureRandom;
const MAX_USERNAME_LEN: usize = 64;
@@ -507,11 +508,16 @@ pub(super) struct CreateUserRequest {
#[derive(Deserialize)]
pub(super) struct PatchUserRequest {
pub(super) secret: Option<String>,
pub(super) user_ad_tag: Option<String>,
pub(super) max_tcp_conns: Option<usize>,
pub(super) expiration_rfc3339: Option<String>,
pub(super) data_quota_bytes: Option<u64>,
pub(super) max_unique_ips: Option<usize>,
#[serde(default, deserialize_with = "patch_field")]
pub(super) user_ad_tag: Patch<String>,
#[serde(default, deserialize_with = "patch_field")]
pub(super) max_tcp_conns: Patch<usize>,
#[serde(default, deserialize_with = "patch_field")]
pub(super) expiration_rfc3339: Patch<String>,
#[serde(default, deserialize_with = "patch_field")]
pub(super) data_quota_bytes: Patch<u64>,
#[serde(default, deserialize_with = "patch_field")]
pub(super) max_unique_ips: Patch<usize>,
}
#[derive(Default, Deserialize)]
@@ -530,6 +536,20 @@ pub(super) fn parse_optional_expiration(
Ok(Some(parsed.with_timezone(&Utc)))
}
/// Converts a raw `Patch<String>` expiration into a parsed `Patch<DateTime<Utc>>`.
///
/// `Unchanged` and `Remove` pass through untouched. A `Set` value must be a
/// valid RFC3339 timestamp and is normalized to UTC; anything else yields a
/// 400-style `ApiFailure`.
pub(super) fn parse_patch_expiration(
    value: &Patch<String>,
) -> Result<Patch<DateTime<Utc>>, ApiFailure> {
    // Only the Set arm carries a string that needs parsing.
    let raw = match value {
        Patch::Unchanged => return Ok(Patch::Unchanged),
        Patch::Remove => return Ok(Patch::Remove),
        Patch::Set(raw) => raw,
    };
    match DateTime::parse_from_rfc3339(raw) {
        Ok(parsed) => Ok(Patch::Set(parsed.with_timezone(&Utc))),
        Err(_) => Err(ApiFailure::bad_request(
            "expiration_rfc3339 must be valid RFC3339",
        )),
    }
}
pub(super) fn is_valid_user_secret(secret: &str) -> bool {
secret.len() == 32 && secret.chars().all(|c| c.is_ascii_hexdigit())
}

130
src/api/patch.rs Normal file
View File

@@ -0,0 +1,130 @@
use serde::Deserialize;
/// Three-state field for JSON Merge Patch semantics on the `PATCH /v1/users/{user}`
/// endpoint.
///
/// `Unchanged` is produced when the JSON body omits the field entirely and tells the
/// handler to leave the corresponding configuration entry untouched. `Remove` is
/// produced when the JSON body sets the field to `null` and instructs the handler to
/// drop the entry from the corresponding access HashMap. `Set` carries an explicit
/// new value, including zero, which is preserved verbatim in the configuration.
#[derive(Debug)]
pub(super) enum Patch<T> {
    // Field omitted from the JSON body: keep the stored value as-is.
    Unchanged,
    // Field explicitly `null`: delete the stored value.
    Remove,
    // Field present with a value (zero included): overwrite the stored value.
    Set(T),
}

// Manual impl rather than `#[derive(Default)]`: the derive would add a
// `T: Default` bound, while this keeps `Patch<T>::default()` available for
// every `T` (e.g. `Patch<DateTime<Utc>>` as returned by
// `parse_patch_expiration`).
impl<T> Default for Patch<T> {
    fn default() -> Self {
        Self::Unchanged
    }
}
/// Serde deserializer adapter for fields that follow JSON Merge Patch semantics.
///
/// Pair this with `#[serde(default, deserialize_with = "patch_field")]` on a
/// `Patch<T>` field. An omitted field falls back to `Patch::Unchanged` via
/// `Default`; an explicit JSON `null` becomes `Patch::Remove`; any other value
/// becomes `Patch::Set(v)`.
pub(super) fn patch_field<'de, D, T>(deserializer: D) -> Result<Patch<T>, D::Error>
where
    D: serde::Deserializer<'de>,
    T: serde::Deserialize<'de>,
{
    // `Option` deserialization distinguishes `null` (None) from a real value;
    // an omitted field never reaches this function at all.
    let parsed = Option::<T>::deserialize(deserializer)?;
    Ok(match parsed {
        Some(value) => Patch::Set(value),
        None => Patch::Remove,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::model::{PatchUserRequest, parse_patch_expiration};
    use chrono::{TimeZone, Utc};
    use serde::Deserialize;

    /// Minimal wrapper exposing a single `Patch<u64>` field wired through
    /// `patch_field`, mirroring how `PatchUserRequest` declares its fields.
    #[derive(Deserialize)]
    struct Holder {
        #[serde(default, deserialize_with = "patch_field")]
        value: Patch<u64>,
    }

    // Parses a JSON document into `Holder`; panics on malformed input.
    fn parse(json: &str) -> Holder {
        serde_json::from_str(json).expect("valid json")
    }

    // Omitted field -> serde falls back to `Default` -> `Unchanged`.
    #[test]
    fn omitted_field_yields_unchanged() {
        let h = parse("{}");
        assert!(matches!(h.value, Patch::Unchanged));
    }

    // Explicit `null` must be distinguishable from omission.
    #[test]
    fn explicit_null_yields_remove() {
        let h = parse(r#"{"value": null}"#);
        assert!(matches!(h.value, Patch::Remove));
    }

    #[test]
    fn explicit_value_yields_set() {
        let h = parse(r#"{"value": 42}"#);
        assert!(matches!(h.value, Patch::Set(42)));
    }

    // Zero is a real value, not a removal sentinel.
    #[test]
    fn explicit_zero_yields_set_zero() {
        let h = parse(r#"{"value": 0}"#);
        assert!(matches!(h.value, Patch::Set(0)));
    }

    // `Unchanged`/`Remove` carry no string, so parsing must pass them through.
    #[test]
    fn parse_patch_expiration_passes_unchanged_and_remove_through() {
        assert!(matches!(
            parse_patch_expiration(&Patch::Unchanged),
            Ok(Patch::Unchanged)
        ));
        assert!(matches!(
            parse_patch_expiration(&Patch::Remove),
            Ok(Patch::Remove)
        ));
    }

    // A valid RFC3339 `Set` value is parsed and normalized to UTC.
    #[test]
    fn parse_patch_expiration_parses_set_value() {
        let parsed =
            parse_patch_expiration(&Patch::Set("2030-01-02T03:04:05Z".into())).expect("valid");
        match parsed {
            Patch::Set(dt) => {
                assert_eq!(dt, Utc.with_ymd_and_hms(2030, 1, 2, 3, 4, 5).unwrap());
            }
            other => panic!("expected Patch::Set, got {:?}", other),
        }
    }

    #[test]
    fn parse_patch_expiration_rejects_invalid_set_value() {
        assert!(parse_patch_expiration(&Patch::Set("not-a-date".into())).is_err());
    }

    // End-to-end: one request mixing all three Patch states plus the
    // plain-Option `secret` field.
    #[test]
    fn patch_user_request_deserializes_mixed_states() {
        let raw = r#"{
            "secret": "00112233445566778899aabbccddeeff",
            "max_tcp_conns": 0,
            "max_unique_ips": null,
            "data_quota_bytes": 1024
        }"#;
        let req: PatchUserRequest = serde_json::from_str(raw).expect("valid json");
        assert_eq!(
            req.secret.as_deref(),
            Some("00112233445566778899aabbccddeeff")
        );
        assert!(matches!(req.max_tcp_conns, Patch::Set(0)));
        assert!(matches!(req.max_unique_ips, Patch::Remove));
        assert!(matches!(req.data_quota_bytes, Patch::Set(1024)));
        assert!(matches!(req.expiration_rfc3339, Patch::Unchanged));
        assert!(matches!(req.user_ad_tag, Patch::Unchanged));
    }
}

View File

@@ -14,8 +14,9 @@ use super::config_store::{
use super::model::{
ApiFailure, CreateUserRequest, CreateUserResponse, PatchUserRequest, RotateSecretRequest,
UserInfo, UserLinks, is_valid_ad_tag, is_valid_user_secret, is_valid_username,
parse_optional_expiration, random_user_secret,
parse_optional_expiration, parse_patch_expiration, random_user_secret,
};
use super::patch::Patch;
pub(super) async fn create_user(
body: CreateUserRequest,
@@ -182,14 +183,14 @@ pub(super) async fn patch_user(
"secret must be exactly 32 hex characters",
));
}
if let Some(ad_tag) = body.user_ad_tag.as_ref()
if let Patch::Set(ad_tag) = &body.user_ad_tag
&& !is_valid_ad_tag(ad_tag)
{
return Err(ApiFailure::bad_request(
"user_ad_tag must be exactly 32 hex characters",
));
}
let expiration = parse_optional_expiration(body.expiration_rfc3339.as_deref())?;
let expiration = parse_patch_expiration(&body.expiration_rfc3339)?;
let _guard = shared.mutation_lock.lock().await;
let mut cfg = load_config_from_disk(&shared.config_path).await?;
ensure_expected_revision(&shared.config_path, expected_revision.as_deref()).await?;
@@ -205,38 +206,71 @@ pub(super) async fn patch_user(
if let Some(secret) = body.secret {
cfg.access.users.insert(user.to_string(), secret);
}
if let Some(ad_tag) = body.user_ad_tag {
match body.user_ad_tag {
Patch::Unchanged => {}
Patch::Remove => {
cfg.access.user_ad_tags.remove(user);
}
Patch::Set(ad_tag) => {
cfg.access.user_ad_tags.insert(user.to_string(), ad_tag);
}
if let Some(limit) = body.max_tcp_conns {
}
match body.max_tcp_conns {
Patch::Unchanged => {}
Patch::Remove => {
cfg.access.user_max_tcp_conns.remove(user);
}
Patch::Set(limit) => {
cfg.access
.user_max_tcp_conns
.insert(user.to_string(), limit);
}
if let Some(expiration) = expiration {
}
match expiration {
Patch::Unchanged => {}
Patch::Remove => {
cfg.access.user_expirations.remove(user);
}
Patch::Set(expiration) => {
cfg.access
.user_expirations
.insert(user.to_string(), expiration);
}
if let Some(quota) = body.data_quota_bytes {
}
match body.data_quota_bytes {
Patch::Unchanged => {}
Patch::Remove => {
cfg.access.user_data_quota.remove(user);
}
Patch::Set(quota) => {
cfg.access.user_data_quota.insert(user.to_string(), quota);
}
let mut updated_limit = None;
if let Some(limit) = body.max_unique_ips {
}
// Capture how the per-user IP limit changed, so the in-memory ip_tracker
// can be synced (set or removed) after the config is persisted.
let max_unique_ips_change = match body.max_unique_ips {
Patch::Unchanged => None,
Patch::Remove => {
cfg.access.user_max_unique_ips.remove(user);
Some(None)
}
Patch::Set(limit) => {
cfg.access
.user_max_unique_ips
.insert(user.to_string(), limit);
updated_limit = Some(limit);
Some(Some(limit))
}
};
cfg.validate()
.map_err(|e| ApiFailure::bad_request(format!("config validation failed: {}", e)))?;
let revision = save_config_to_disk(&shared.config_path, &cfg).await?;
drop(_guard);
if let Some(limit) = updated_limit {
shared.ip_tracker.set_user_limit(user, limit).await;
match max_unique_ips_change {
Some(Some(limit)) => shared.ip_tracker.set_user_limit(user, limit).await,
Some(None) => shared.ip_tracker.remove_user_limit(user).await,
None => {}
}
let (detected_ip_v4, detected_ip_v6) = shared.detected_link_ips();
let users = users_from_config(

View File

@@ -22,7 +22,7 @@ pub struct UserIpTracker {
limit_mode: Arc<RwLock<UserMaxUniqueIpsMode>>,
limit_window: Arc<RwLock<Duration>>,
last_compact_epoch_secs: Arc<AtomicU64>,
cleanup_queue: Arc<Mutex<Vec<(String, IpAddr)>>>,
cleanup_queue: Arc<Mutex<HashMap<(String, IpAddr), usize>>>,
cleanup_drain_lock: Arc<AsyncMutex<()>>,
}
@@ -45,17 +45,21 @@ impl UserIpTracker {
limit_mode: Arc::new(RwLock::new(UserMaxUniqueIpsMode::ActiveWindow)),
limit_window: Arc::new(RwLock::new(Duration::from_secs(30))),
last_compact_epoch_secs: Arc::new(AtomicU64::new(0)),
cleanup_queue: Arc::new(Mutex::new(Vec::new())),
cleanup_queue: Arc::new(Mutex::new(HashMap::new())),
cleanup_drain_lock: Arc::new(AsyncMutex::new(())),
}
}
pub fn enqueue_cleanup(&self, user: String, ip: IpAddr) {
match self.cleanup_queue.lock() {
Ok(mut queue) => queue.push((user, ip)),
Ok(mut queue) => {
let count = queue.entry((user, ip)).or_insert(0);
*count = count.saturating_add(1);
}
Err(poisoned) => {
let mut queue = poisoned.into_inner();
queue.push((user.clone(), ip));
let count = queue.entry((user.clone(), ip)).or_insert(0);
*count = count.saturating_add(1);
self.cleanup_queue.clear_poison();
tracing::warn!(
"UserIpTracker cleanup_queue lock poisoned; recovered and enqueued IP cleanup for {} ({})",
@@ -75,7 +79,9 @@ impl UserIpTracker {
}
#[cfg(test)]
pub(crate) fn cleanup_queue_mutex_for_tests(&self) -> Arc<Mutex<Vec<(String, IpAddr)>>> {
pub(crate) fn cleanup_queue_mutex_for_tests(
&self,
) -> Arc<Mutex<HashMap<(String, IpAddr), usize>>> {
Arc::clone(&self.cleanup_queue)
}
@@ -105,11 +111,14 @@ impl UserIpTracker {
};
let mut active_ips = self.active_ips.write().await;
for (user, ip) in to_remove {
for ((user, ip), pending_count) in to_remove {
if pending_count == 0 {
continue;
}
if let Some(user_ips) = active_ips.get_mut(&user) {
if let Some(count) = user_ips.get_mut(&ip) {
if *count > 1 {
*count -= 1;
if *count > pending_count {
*count -= pending_count;
} else {
user_ips.remove(&ip);
}

View File

@@ -31,16 +31,24 @@ struct UserConnectionReservation {
ip_tracker: Arc<UserIpTracker>,
user: String,
ip: IpAddr,
tracks_ip: bool,
active: bool,
}
impl UserConnectionReservation {
fn new(stats: Arc<Stats>, ip_tracker: Arc<UserIpTracker>, user: String, ip: IpAddr) -> Self {
fn new(
stats: Arc<Stats>,
ip_tracker: Arc<UserIpTracker>,
user: String,
ip: IpAddr,
tracks_ip: bool,
) -> Self {
Self {
stats,
ip_tracker,
user,
ip,
tracks_ip,
active: true,
}
}
@@ -49,7 +57,9 @@ impl UserConnectionReservation {
if !self.active {
return;
}
if self.tracks_ip {
self.ip_tracker.remove_ip(&self.user, self.ip).await;
}
self.active = false;
self.stats.decrement_user_curr_connects(&self.user);
}
@@ -62,8 +72,10 @@ impl Drop for UserConnectionReservation {
}
self.active = false;
self.stats.decrement_user_curr_connects(&self.user);
if self.tracks_ip {
self.ip_tracker.enqueue_cleanup(self.user.clone(), self.ip);
}
}
}
use crate::config::ProxyConfig;
@@ -1600,6 +1612,8 @@ impl RunningClientHandler {
});
}
let tracks_ip = ip_tracker.get_user_limit(user).await.is_some();
if tracks_ip {
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
Ok(()) => {}
Err(reason) => {
@@ -1615,12 +1629,14 @@ impl RunningClientHandler {
});
}
}
}
Ok(UserConnectionReservation::new(
stats,
ip_tracker,
user.to_string(),
peer_addr.ip(),
tracks_ip,
))
}
@@ -1663,10 +1679,10 @@ impl RunningClientHandler {
});
}
if ip_tracker.get_user_limit(user).await.is_some() {
match ip_tracker.check_and_add(user, peer_addr.ip()).await {
Ok(()) => {
ip_tracker.remove_ip(user, peer_addr.ip()).await;
stats.decrement_user_curr_connects(user);
}
Err(reason) => {
stats.decrement_user_curr_connects(user);
@@ -1681,7 +1697,9 @@ impl RunningClientHandler {
});
}
}
}
stats.decrement_user_curr_connects(user);
Ok(())
}
}

View File

@@ -55,6 +55,7 @@ const STICKY_HINT_MAX_ENTRIES: usize = 65_536;
const CANDIDATE_HINT_TRACK_CAP: usize = 64;
const OVERLOAD_CANDIDATE_BUDGET_HINTED: usize = 16;
const OVERLOAD_CANDIDATE_BUDGET_UNHINTED: usize = 8;
const EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD: usize = 64;
const RECENT_USER_RING_SCAN_LIMIT: usize = 32;
type HmacSha256 = Hmac<Sha256>;
@@ -551,6 +552,19 @@ fn auth_probe_note_saturation_in(shared: &ProxySharedState, now: Instant) {
}
}
/// Arms global auth-probe saturation after an expensive invalid-secret scan.
///
/// Fires only when the handshake was NOT already running under the overload
/// budget (`overload == false`) and the failed scan performed at least
/// `EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD` validation checks.
fn auth_probe_note_expensive_invalid_scan_in(
    shared: &ProxySharedState,
    now: Instant,
    validation_checks: usize,
    overload: bool,
) {
    let expensive_scan = validation_checks >= EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD;
    if expensive_scan && !overload {
        auth_probe_note_saturation_in(shared, now);
    }
}
fn auth_probe_record_failure_in(shared: &ProxySharedState, peer_ip: IpAddr, now: Instant) {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = &shared.handshake.auth_probe;
@@ -1378,7 +1392,14 @@ where
}
if !matched {
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
let failure_now = Instant::now();
auth_probe_note_expensive_invalid_scan_in(
shared,
failure_now,
validation_checks,
overload,
);
auth_probe_record_failure_in(shared, peer.ip(), failure_now);
maybe_apply_server_hello_delay(config).await;
debug!(
peer = %peer,
@@ -1753,7 +1774,14 @@ where
}
if !matched {
auth_probe_record_failure_in(shared, peer.ip(), Instant::now());
let failure_now = Instant::now();
auth_probe_note_expensive_invalid_scan_in(
shared,
failure_now,
validation_checks,
overload,
);
auth_probe_record_failure_in(shared, peer.ip(), failure_now);
maybe_apply_server_hello_delay(config).await;
debug!(
peer = %peer,

View File

@@ -12,7 +12,7 @@ use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{mpsc, oneshot, watch};
use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot, watch};
use tokio::time::timeout;
use tracing::{debug, info, trace, warn};
@@ -36,7 +36,11 @@ use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
enum C2MeCommand {
Data { payload: PooledBuffer, flags: u32 },
Data {
payload: PooledBuffer,
flags: u32,
_permit: OwnedSemaphorePermit,
},
Close,
}
@@ -47,6 +51,8 @@ const DESYNC_ERROR_CLASS: &str = "frame_too_large_crypto_desync";
const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
const C2ME_QUEUED_BYTE_PERMIT_UNIT: usize = 16 * 1024;
const C2ME_QUEUED_PERMITS_PER_SLOT: usize = 4;
const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1);
const TINY_FRAME_DEBT_PER_TINY: u32 = 8;
const TINY_FRAME_DEBT_LIMIT: u32 = 512;
@@ -571,6 +577,43 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
}
/// Converts a payload length into queued-byte permit units.
///
/// Every started `C2ME_QUEUED_BYTE_PERMIT_UNIT` chunk costs one permit; an
/// empty payload still costs one, and the result is clamped so the cast to
/// `u32` cannot wrap.
fn c2me_payload_permits(payload_len: usize) -> u32 {
    let units = payload_len.max(1).div_ceil(C2ME_QUEUED_BYTE_PERMIT_UNIT);
    units.min(u32::MAX as usize) as u32
}
/// Sizes the C2ME queued-byte semaphore.
///
/// The budget is `channel_capacity * C2ME_QUEUED_PERMITS_PER_SLOT`, but never
/// smaller than the permits a single maximum-size frame needs (so one full
/// frame can always be queued), and never zero.
fn c2me_queued_permit_budget(channel_capacity: usize, frame_limit: usize) -> usize {
    let per_slot_budget = channel_capacity.saturating_mul(C2ME_QUEUED_PERMITS_PER_SLOT);
    let single_frame_floor = c2me_payload_permits(frame_limit) as usize;
    per_slot_budget.max(single_frame_floor).max(1)
}
/// Reserves queued-byte permits for one C2ME payload before it is enqueued.
///
/// With a configured `send_timeout` the acquisition is bounded; a timeout
/// bumps the ME send-timeout counter and fails the relay. A closed semaphore
/// (budget torn down) fails in both modes.
async fn acquire_c2me_payload_permit(
    semaphore: &Arc<Semaphore>,
    payload_len: usize,
    send_timeout: Option<Duration>,
    stats: &Stats,
) -> Result<OwnedSemaphorePermit> {
    let permits = c2me_payload_permits(payload_len);
    // Owned acquisition so the permit can travel inside C2MeCommand::Data and
    // release capacity only when the frame is consumed or dropped.
    let acquire = semaphore.clone().acquire_many_owned(permits);
    if let Some(send_timeout) = send_timeout {
        match timeout(send_timeout, acquire).await {
            Ok(Ok(permit)) => Ok(permit),
            Ok(Err(_)) => Err(ProxyError::Proxy("ME sender byte budget closed".into())),
            Err(_) => {
                stats.increment_me_c2me_send_timeout_total();
                Err(ProxyError::Proxy("ME sender byte budget timeout".into()))
            }
        }
    } else {
        acquire
            .await
            .map_err(|_| ProxyError::Proxy("ME sender byte budget closed".into()))
    }
}
/// Returns the quota limit plus the allowed overshoot, saturating at
/// `u64::MAX` instead of wrapping.
fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 {
    limit.checked_add(overshoot).unwrap_or(u64::MAX)
}
@@ -1122,13 +1165,19 @@ where
0 => None,
timeout_ms => Some(Duration::from_millis(timeout_ms)),
};
let c2me_byte_budget = c2me_queued_permit_budget(c2me_channel_capacity, frame_limit);
let c2me_byte_semaphore = Arc::new(Semaphore::new(c2me_byte_budget));
let (c2me_tx, mut c2me_rx) = mpsc::channel::<C2MeCommand>(c2me_channel_capacity);
let me_pool_c2me = me_pool.clone();
let c2me_sender = tokio::spawn(async move {
let mut sent_since_yield = 0usize;
while let Some(cmd) = c2me_rx.recv().await {
match cmd {
C2MeCommand::Data { payload, flags } => {
C2MeCommand::Data {
payload,
flags,
_permit,
} => {
me_pool_c2me
.send_proxy_req(
conn_id,
@@ -1624,11 +1673,29 @@ where
if payload.len() >= 8 && payload[..8].iter().all(|b| *b == 0) {
flags |= RPC_FLAG_NOT_ENCRYPTED;
}
let payload_permit = match acquire_c2me_payload_permit(
&c2me_byte_semaphore,
payload.len(),
c2me_send_timeout,
stats.as_ref(),
)
.await
{
Ok(permit) => permit,
Err(e) => {
main_result = Err(e);
break;
}
};
// Keep client read loop lightweight: route heavy ME send path via a dedicated task.
if enqueue_c2me_command_in(
shared.as_ref(),
&c2me_tx,
C2MeCommand::Data { payload, flags },
C2MeCommand::Data {
payload,
flags,
_permit: payload_permit,
},
c2me_send_timeout,
stats.as_ref(),
)
@@ -2262,7 +2329,7 @@ where
W: AsyncWrite + Unpin + Send + 'static,
{
match response {
MeResponse::Data { flags, data } => {
MeResponse::Data { flags, data, .. } => {
if batched {
trace!(conn_id, bytes = data.len(), flags, "ME->C data (batched)");
} else {

View File

@@ -282,7 +282,7 @@ async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() {
assert_eq!(stats.get_user_curr_connects(&user), 1);
let reservation =
UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip);
UserConnectionReservation::new(stats.clone(), ip_tracker.clone(), user.clone(), ip, true);
// Drop the reservation synchronously without any tokio::spawn/await yielding!
drop(reservation);
@@ -320,6 +320,7 @@ async fn relay_task_abort_releases_user_gate_and_ip_reservation() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 8).await;
let mut cfg = ProxyConfig::default();
cfg.access.user_max_tcp_conns.insert(user.to_string(), 8);
@@ -437,6 +438,7 @@ async fn relay_cutover_releases_user_gate_and_ip_reservation() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 8).await;
let mut cfg = ProxyConfig::default();
cfg.access.user_max_tcp_conns.insert(user.to_string(), 8);
@@ -2879,6 +2881,7 @@ async fn explicit_reservation_release_cleans_user_and_ip_immediately() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 4).await;
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
user,
@@ -2917,6 +2920,7 @@ async fn explicit_reservation_release_does_not_double_decrement_on_drop() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 4).await;
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
user,
@@ -2947,6 +2951,7 @@ async fn drop_fallback_eventually_cleans_user_and_ip_reservation() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 1).await;
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
user,
@@ -3029,6 +3034,7 @@ async fn release_abort_storm_does_not_leak_user_or_ip_reservations() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, ATTEMPTS + 16).await;
for idx in 0..ATTEMPTS {
let peer = SocketAddr::new(
@@ -3079,6 +3085,7 @@ async fn release_abort_loop_preserves_immediate_same_ip_reacquire() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 1).await;
for _ in 0..ITERATIONS {
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
@@ -3137,6 +3144,7 @@ async fn adversarial_mixed_release_drop_abort_wave_converges_to_zero() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, RESERVATIONS + 8).await;
let mut reservations = Vec::with_capacity(RESERVATIONS);
for idx in 0..RESERVATIONS {
@@ -3217,6 +3225,8 @@ async fn parallel_users_abort_release_isolation_preserves_independent_cleanup()
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user_a, 64).await;
ip_tracker.set_user_limit(user_b, 64).await;
let mut tasks = tokio::task::JoinSet::new();
for idx in 0..64usize {
@@ -3278,6 +3288,7 @@ async fn concurrent_release_storm_leaves_zero_user_and_ip_footprint() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, RESERVATIONS + 8).await;
let mut reservations = Vec::with_capacity(RESERVATIONS);
for idx in 0..RESERVATIONS {
@@ -3332,6 +3343,7 @@ async fn relay_connect_error_releases_user_and_ip_before_return() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 8).await;
let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 1);
@@ -3427,6 +3439,7 @@ async fn mixed_release_and_drop_same_ip_preserves_counter_correctness() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 1).await;
let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static(
user,
@@ -3487,6 +3500,7 @@ async fn drop_one_of_two_same_ip_reservations_keeps_ip_active() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 1).await;
let reservation_a = RunningClientHandler::acquire_user_connection_reservation_static(
user,
@@ -3696,6 +3710,7 @@ async fn cross_thread_drop_uses_captured_runtime_for_ip_cleanup() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 8).await;
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
user,
@@ -3740,6 +3755,7 @@ async fn immediate_reacquire_after_cross_thread_drop_succeeds() {
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 1).await;
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
user,

View File

@@ -1252,6 +1252,97 @@ async fn tls_overload_budget_limits_candidate_scan_depth() {
);
}
// Regression test: a full invalid-secret scan across a large user table must
// arm the global auth-probe saturation state, and a follow-up invalid probe
// must then be capped by the unhinted overload candidate budget.
#[tokio::test]
async fn tls_expensive_invalid_scan_activates_saturation_budget() {
    // 80 users exceeds EXPENSIVE_INVALID_SCAN_SATURATION_THRESHOLD (64), so a
    // miss has to validate against every configured secret.
    let mut config = ProxyConfig::default();
    config.access.users.clear();
    config.access.ignore_time_skew = true;
    for idx in 0..80u8 {
        config.access.users.insert(
            format!("user-{idx}"),
            format!("{:032x}", u128::from(idx) + 1),
        );
    }
    config.rebuild_runtime_user_auth().unwrap();
    let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
    let rng = SecureRandom::new();
    let shared = ProxySharedState::new();
    // Handshake built from a secret that matches none of the 80 users above.
    let attacker_secret = [0xEFu8; 16];
    let handshake = make_valid_tls_handshake(&attacker_secret, 0);
    let first_peer: SocketAddr = "198.51.100.214:44326".parse().unwrap();
    let first = handle_tls_handshake_with_shared(
        &handshake,
        tokio::io::empty(),
        tokio::io::sink(),
        first_peer,
        &config,
        &replay_checker,
        &rng,
        None,
        shared.as_ref(),
    )
    .await;
    assert!(matches!(first, HandshakeResult::BadClient { .. }));
    assert!(
        auth_probe_saturation_state_for_testing_in_shared(shared.as_ref())
            .lock()
            .unwrap()
            .is_some(),
        "expensive invalid scan must activate global saturation"
    );
    // The first miss still scans all 80 candidates: saturation is armed only
    // after the expensive scan completes, not before it.
    assert_eq!(
        shared
            .handshake
            .auth_expensive_checks_total
            .load(Ordering::Relaxed),
        80,
        "first invalid probe preserves full first-hit compatibility before enabling saturation"
    );
    // Push blocked_until into the future so the second probe is guaranteed to
    // run while the saturation window is still active.
    {
        let mut saturation = auth_probe_saturation_state_for_testing_in_shared(shared.as_ref())
            .lock()
            .unwrap();
        let state = saturation.as_mut().expect("saturation must be present");
        state.blocked_until = Instant::now() + Duration::from_millis(200);
    }
    let second_peer: SocketAddr = "198.51.100.215:44326".parse().unwrap();
    let second = handle_tls_handshake_with_shared(
        &handshake,
        tokio::io::empty(),
        tokio::io::sink(),
        second_peer,
        &config,
        &replay_checker,
        &rng,
        None,
        shared.as_ref(),
    )
    .await;
    assert!(matches!(second, HandshakeResult::BadClient { .. }));
    assert_eq!(
        shared
            .handshake
            .auth_budget_exhausted_total
            .load(Ordering::Relaxed),
        1,
        "second invalid probe must be capped by overload budget"
    );
    // Under saturation only OVERLOAD_CANDIDATE_BUDGET_UNHINTED (8) extra
    // expensive checks are permitted on top of the initial 80.
    assert_eq!(
        shared
            .handshake
            .auth_expensive_checks_total
            .load(Ordering::Relaxed),
        80 + OVERLOAD_CANDIDATE_BUDGET_UNHINTED as u64,
        "saturation budget must bound follow-up invalid scans"
    );
}
#[tokio::test]
async fn mtproto_runtime_snapshot_prefers_preferred_user_hint() {
let mut config = ProxyConfig::default();

View File

@@ -70,6 +70,7 @@ async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() {
MeResponse::Data {
flags: 0,
data: payload.clone(),
route_permit: None,
},
&mut writer,
ProtoTag::Intermediate,
@@ -139,6 +140,7 @@ async fn me_writer_pre_write_quota_reject_happens_before_writer_poll() {
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xAA, 0xBB, 0xCC]),
route_permit: None,
},
&mut writer,
ProtoTag::Intermediate,

View File

@@ -12,6 +12,12 @@ fn make_pooled_payload(data: &[u8]) -> PooledBuffer {
payload
}
/// Builds a standalone one-permit reservation for constructing
/// `C2MeCommand::Data` fixtures in tests; the backing semaphore lives only as
/// long as the returned owned permit needs it.
fn make_c2me_permit() -> tokio::sync::OwnedSemaphorePermit {
    let budget = Arc::new(tokio::sync::Semaphore::new(1));
    budget
        .try_acquire_many_owned(1)
        .expect("test permit must be available")
}
#[test]
#[ignore = "Tracking for M-04: Verify should_emit_full_desync returns true on first occurrence and false on duplicate within window"]
fn should_emit_full_desync_filters_duplicates() {
@@ -107,6 +113,7 @@ async fn c2me_channel_full_path_yields_then_sends() {
tx.send(C2MeCommand::Data {
payload: make_pooled_payload(&[0xAA]),
flags: 1,
_permit: make_c2me_permit(),
})
.await
.expect("priming queue with one frame must succeed");
@@ -119,6 +126,7 @@ async fn c2me_channel_full_path_yields_then_sends() {
C2MeCommand::Data {
payload: make_pooled_payload(&[0xBB, 0xCC]),
flags: 2,
_permit: make_c2me_permit(),
},
None,
&stats,
@@ -138,7 +146,7 @@ async fn c2me_channel_full_path_yields_then_sends() {
.expect("receiver should observe primed frame")
.expect("first queued command must exist");
match first {
C2MeCommand::Data { payload, flags } => {
C2MeCommand::Data { payload, flags, .. } => {
assert_eq!(payload.as_ref(), &[0xAA]);
assert_eq!(flags, 1);
}
@@ -155,7 +163,7 @@ async fn c2me_channel_full_path_yields_then_sends() {
.expect("receiver should observe backpressure-resumed frame")
.expect("second queued command must exist");
match second {
C2MeCommand::Data { payload, flags } => {
C2MeCommand::Data { payload, flags, .. } => {
assert_eq!(payload.as_ref(), &[0xBB, 0xCC]);
assert_eq!(flags, 2);
}

View File

@@ -7,6 +7,7 @@ use std::time::{Duration, Instant};
use parking_lot::Mutex;
const CLEANUP_INTERVAL: Duration = Duration::from_secs(30);
const MAX_BEOBACHTEN_ENTRIES: usize = 65_536;
#[derive(Default)]
struct BeobachtenInner {
@@ -48,12 +49,23 @@ impl BeobachtenStore {
Self::cleanup_if_needed(&mut guard, now, ttl);
let key = (class.to_string(), ip);
let entry = guard.entries.entry(key).or_insert(BeobachtenEntry {
tries: 0,
last_seen: now,
});
if let Some(entry) = guard.entries.get_mut(&key) {
entry.tries = entry.tries.saturating_add(1);
entry.last_seen = now;
return;
}
if guard.entries.len() >= MAX_BEOBACHTEN_ENTRIES {
return;
}
guard.entries.insert(
key,
BeobachtenEntry {
tries: 1,
last_seen: now,
},
);
}
pub fn snapshot_text(&self, ttl: Duration) -> String {

View File

@@ -649,6 +649,25 @@ async fn duplicate_cleanup_entries_do_not_break_future_admission() {
);
}
// Three identical (user, ip) cleanup requests must coalesce into a single
// queued entry, and draining must empty the queue completely.
#[tokio::test]
async fn duplicate_cleanup_entries_are_coalesced_until_drain() {
    let tracker = UserIpTracker::new();
    let ip = ip_from_idx(7150);
    for _ in 0..3 {
        tracker.enqueue_cleanup("coalesced-cleanup".to_string(), ip);
    }
    assert_eq!(
        tracker.cleanup_queue_len_for_tests(),
        1,
        "duplicate queued cleanup entries must retain one allocation slot"
    );
    tracker.drain_cleanup_queue().await;
    assert_eq!(tracker.cleanup_queue_len_for_tests(), 0);
}
#[tokio::test]
async fn stress_repeated_queue_poison_recovery_preserves_admission_progress() {
let tracker = UserIpTracker::new();

View File

@@ -46,6 +46,7 @@ mod send_adversarial_tests;
mod wire;
use bytes::Bytes;
use tokio::sync::OwnedSemaphorePermit;
#[allow(unused_imports)]
pub use config_updater::{
@@ -68,9 +69,32 @@ pub use secret::{fetch_proxy_secret, fetch_proxy_secret_with_upstream};
pub(crate) use selftest::{bnd_snapshot, timeskew_snapshot, upstream_bnd_snapshots};
pub use wire::proto_flags_for_tag;
/// Holds D2C queued-byte capacity until a routed payload is consumed or dropped.
pub struct RouteBytePermit {
    // Owned semaphore permit; the reserved queued-byte units are released
    // back to the per-connection budget when this value is dropped.
    _permit: OwnedSemaphorePermit,
}

// Manual Debug: the inner permit exposes no useful state, so only the type
// name is printed (renders as `RouteBytePermit { .. }`).
impl std::fmt::Debug for RouteBytePermit {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RouteBytePermit").finish_non_exhaustive()
    }
}

impl RouteBytePermit {
    /// Wraps an already-acquired semaphore permit; dropping the returned
    /// value releases the reservation.
    pub(crate) fn new(permit: OwnedSemaphorePermit) -> Self {
        Self { _permit: permit }
    }
}
/// Response routed from middle proxy readers to client relay tasks.
#[derive(Debug)]
pub enum MeResponse {
Data { flags: u32, data: Bytes },
/// Downstream payload with its queued-byte reservation.
Data {
flags: u32,
data: Bytes,
route_permit: Option<RouteBytePermit>,
},
Ack(u32),
Close,
}

View File

@@ -84,6 +84,7 @@ async fn route_data_with_retry(
MeResponse::Data {
flags,
data: data.clone(),
route_permit: None,
},
timeout_ms,
)
@@ -639,7 +640,7 @@ mod tests {
let routed = route_data_with_retry(&reg, conn_id, 0, Bytes::from_static(b"a"), 20).await;
assert!(matches!(routed, RouteResult::Routed));
match rx.recv().await {
Some(MeResponse::Data { flags, data }) => {
Some(MeResponse::Data { flags, data, .. }) => {
assert_eq!(flags, 0);
assert_eq!(data, Bytes::from_static(b"a"));
}

View File

@@ -1,18 +1,22 @@
use std::collections::{HashMap, HashSet};
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use dashmap::DashMap;
use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::{Mutex, mpsc};
use tokio::sync::{Mutex, Semaphore, mpsc};
use super::MeResponse;
use super::codec::WriterCommand;
use super::{MeResponse, RouteBytePermit};
const ROUTE_BACKPRESSURE_BASE_TIMEOUT_MS: u64 = 25;
const ROUTE_BACKPRESSURE_HIGH_TIMEOUT_MS: u64 = 120;
const ROUTE_BACKPRESSURE_HIGH_WATERMARK_PCT: u8 = 80;
const ROUTE_QUEUED_BYTE_PERMIT_UNIT: usize = 16 * 1024;
const ROUTE_QUEUED_PERMITS_PER_SLOT: usize = 4;
const ROUTE_QUEUED_MAX_FRAME_PERMITS: usize = 1024;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouteResult {
@@ -53,6 +57,7 @@ pub(super) struct WriterActivitySnapshot {
struct RoutingTable {
map: DashMap<u64, mpsc::Sender<MeResponse>>,
byte_budget: DashMap<u64, Arc<Semaphore>>,
}
struct WriterTable {
@@ -105,6 +110,7 @@ pub struct ConnRegistry {
route_backpressure_base_timeout_ms: AtomicU64,
route_backpressure_high_timeout_ms: AtomicU64,
route_backpressure_high_watermark_pct: AtomicU8,
route_byte_permits_per_conn: usize,
}
impl ConnRegistry {
@@ -116,10 +122,23 @@ impl ConnRegistry {
}
/// Builds a registry with the given route-channel slot capacity (clamped to
/// at least 1). The per-connection byte-permit budget is derived from that
/// clamped capacity via `route_byte_permit_budget`.
pub fn with_route_channel_capacity(route_channel_capacity: usize) -> Self {
let route_channel_capacity = route_channel_capacity.max(1);
Self::with_route_limits(
route_channel_capacity,
Self::route_byte_permit_budget(route_channel_capacity),
)
}
fn with_route_limits(
route_channel_capacity: usize,
route_byte_permits_per_conn: usize,
) -> Self {
let start = rand::random::<u64>() | 1;
let route_channel_capacity = route_channel_capacity.max(1);
Self {
routing: RoutingTable {
map: DashMap::new(),
byte_budget: DashMap::new(),
},
writers: WriterTable {
map: DashMap::new(),
@@ -131,15 +150,30 @@ impl ConnRegistry {
inner: Mutex::new(BindingInner::new()),
},
next_id: AtomicU64::new(start),
route_channel_capacity: route_channel_capacity.max(1),
route_channel_capacity,
route_backpressure_base_timeout_ms: AtomicU64::new(ROUTE_BACKPRESSURE_BASE_TIMEOUT_MS),
route_backpressure_high_timeout_ms: AtomicU64::new(ROUTE_BACKPRESSURE_HIGH_TIMEOUT_MS),
route_backpressure_high_watermark_pct: AtomicU8::new(
ROUTE_BACKPRESSURE_HIGH_WATERMARK_PCT,
),
route_byte_permits_per_conn: route_byte_permits_per_conn.max(1),
}
}
/// Number of byte-budget permits a payload of `data_len` bytes consumes,
/// charged in `ROUTE_QUEUED_BYTE_PERMIT_UNIT` (16 KiB) units.
///
/// Zero-length payloads still cost one unit, and the unit count is clamped
/// so the final cast to `u32` cannot truncate.
fn route_data_permits(data_len: usize) -> u32 {
    let chargeable_bytes = data_len.max(1);
    let units = chargeable_bytes.div_ceil(ROUTE_QUEUED_BYTE_PERMIT_UNIT);
    units.min(u32::MAX as usize) as u32
}
/// Per-connection byte-permit budget derived from the route-channel capacity.
///
/// Grants `ROUTE_QUEUED_PERMITS_PER_SLOT` permits per channel slot (with
/// saturating arithmetic so huge capacities cannot overflow), but never less
/// than `ROUTE_QUEUED_MAX_FRAME_PERMITS` so a single maximum-size frame can
/// always be admitted.
fn route_byte_permit_budget(route_channel_capacity: usize) -> usize {
    // The `ROUTE_QUEUED_MAX_FRAME_PERMITS` floor (1024) already guarantees a
    // result >= 1, so the previous trailing `.max(1)` clamp was redundant.
    route_channel_capacity
        .saturating_mul(ROUTE_QUEUED_PERMITS_PER_SLOT)
        .max(ROUTE_QUEUED_MAX_FRAME_PERMITS)
}
pub fn route_channel_capacity(&self) -> usize {
self.route_channel_capacity
}
@@ -149,6 +183,14 @@ impl ConnRegistry {
Self::with_route_channel_capacity(4096)
}
// Test-only constructor that pins the per-connection byte-permit budget
// directly, instead of deriving it from the channel capacity, so tests can
// exhaust the byte budget before the slot count.
#[cfg(test)]
fn with_route_byte_permits_for_tests(
route_channel_capacity: usize,
route_byte_permits_per_conn: usize,
) -> Self {
Self::with_route_limits(route_channel_capacity, route_byte_permits_per_conn)
}
pub fn update_route_backpressure_policy(
&self,
base_timeout_ms: u64,
@@ -170,6 +212,10 @@ impl ConnRegistry {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = mpsc::channel(self.route_channel_capacity);
self.routing.map.insert(id, tx);
self.routing.byte_budget.insert(
id,
Arc::new(Semaphore::new(self.route_byte_permits_per_conn)),
);
(id, rx)
}
@@ -186,6 +232,7 @@ impl ConnRegistry {
/// Unregister connection, returning associated writer_id if any.
pub async fn unregister(&self, id: u64) -> Option<u64> {
self.routing.map.remove(&id);
self.routing.byte_budget.remove(&id);
self.hot_binding.map.remove(&id);
let mut binding = self.binding.inner.lock().await;
binding.meta.remove(&id);
@@ -206,6 +253,64 @@ impl ConnRegistry {
None
}
/// Reserves queued-byte budget for a `MeResponse::Data` payload before it is
/// enqueued toward the client relay task.
///
/// Non-`Data` responses (`Ack`, `Close`) and `Data` responses that already
/// carry a permit pass through unchanged. `timeout_ms` selects the acquire
/// mode: `Some(0)` = non-blocking try, `Some(t)` = wait at most `t` ms
/// (clamped to >= 1), `None` = wait indefinitely.
///
/// Errors map to routing outcomes: `NoConn` when the connection has no byte
/// budget entry (already unregistered), `QueueFullHigh` when the budget cannot
/// be reserved in time, `ChannelClosed` when the semaphore was closed.
async fn attach_route_byte_permit(
&self,
id: u64,
resp: MeResponse,
timeout_ms: Option<u64>,
) -> std::result::Result<MeResponse, RouteResult> {
// Only Data payloads consume byte budget; everything else passes through.
let MeResponse::Data {
flags,
data,
route_permit,
} = resp
else {
return Ok(resp);
};
// Already reserved upstream (e.g. by a retry path): don't double-charge.
if route_permit.is_some() {
return Ok(MeResponse::Data {
flags,
data,
route_permit,
});
}
// Clone the Arc out of the DashMap entry so the map guard is dropped
// before any await below.
let Some(semaphore) = self
.routing
.byte_budget
.get(&id)
.map(|entry| entry.value().clone())
else {
return Err(RouteResult::NoConn);
};
let permits = Self::route_data_permits(data.len());
let permit = match timeout_ms {
// Non-blocking: budget exhausted right now means "queue full".
Some(0) => semaphore
.try_acquire_many_owned(permits)
.map_err(|_| RouteResult::QueueFullHigh)?,
// Bounded wait: closed semaphore vs. timeout are reported distinctly.
Some(timeout_ms) => {
let acquire = semaphore.acquire_many_owned(permits);
match tokio::time::timeout(Duration::from_millis(timeout_ms.max(1)), acquire).await
{
Ok(Ok(permit)) => permit,
Ok(Err(_)) => return Err(RouteResult::ChannelClosed),
Err(_) => return Err(RouteResult::QueueFullHigh),
}
}
// Unbounded wait: only fails if the semaphore is closed.
None => semaphore
.acquire_many_owned(permits)
.await
.map_err(|_| RouteResult::ChannelClosed)?,
};
// Attach the reservation so it is released when the payload is consumed
// or dropped by the receiving relay task.
Ok(MeResponse::Data {
flags,
data,
route_permit: Some(RouteBytePermit::new(permit)),
})
}
#[allow(dead_code)]
pub async fn route(&self, id: u64, resp: MeResponse) -> RouteResult {
let tx = self.routing.map.get(&id).map(|entry| entry.value().clone());
@@ -214,15 +319,23 @@ impl ConnRegistry {
return RouteResult::NoConn;
};
let base_timeout_ms = self
.route_backpressure_base_timeout_ms
.load(Ordering::Relaxed)
.max(1);
let resp = match self
.attach_route_byte_permit(id, resp, Some(base_timeout_ms))
.await
{
Ok(resp) => resp,
Err(result) => return result,
};
match tx.try_send(resp) {
Ok(()) => RouteResult::Routed,
Err(TrySendError::Closed(_)) => RouteResult::ChannelClosed,
Err(TrySendError::Full(resp)) => {
// Absorb short bursts without dropping/closing the session immediately.
let base_timeout_ms = self
.route_backpressure_base_timeout_ms
.load(Ordering::Relaxed)
.max(1);
let high_timeout_ms = self
.route_backpressure_high_timeout_ms
.load(Ordering::Relaxed)
@@ -266,6 +379,10 @@ impl ConnRegistry {
let Some(tx) = tx else {
return RouteResult::NoConn;
};
let resp = match self.attach_route_byte_permit(id, resp, Some(0)).await {
Ok(resp) => resp,
Err(result) => return result,
};
match tx.try_send(resp) {
Ok(()) => RouteResult::Routed,
@@ -289,6 +406,13 @@ impl ConnRegistry {
let Some(tx) = tx else {
return RouteResult::NoConn;
};
let resp = match self
.attach_route_byte_permit(id, resp, Some(timeout_ms))
.await
{
Ok(resp) => resp,
Err(result) => return result,
};
match tx.try_send(resp) {
Ok(()) => RouteResult::Routed,
@@ -541,8 +665,10 @@ impl ConnRegistry {
mod tests {
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use super::ConnMeta;
use super::ConnRegistry;
use bytes::Bytes;
use super::{ConnMeta, ConnRegistry, RouteResult};
use crate::transport::middle_proxy::MeResponse;
#[tokio::test]
async fn writer_activity_snapshot_tracks_writer_and_dc_load() {
@@ -608,6 +734,55 @@ mod tests {
assert_eq!(snapshot.active_sessions_by_target_dc.get(&4), Some(&1));
}
// With 4 channel slots but only 1 byte permit, the byte budget — not the slot
// count — must be the binding limit, and draining the queue must return the
// budget.
#[tokio::test]
async fn route_data_is_bounded_by_byte_permits_before_channel_capacity() {
let registry = ConnRegistry::with_route_byte_permits_for_tests(4, 1);
let (conn_id, mut rx) = registry.register().await;
// First 1-byte payload consumes the single byte permit.
let routed = registry
.route_nowait(
conn_id,
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xAA]),
route_permit: None,
},
)
.await;
assert!(matches!(routed, RouteResult::Routed));
// Second payload must be rejected even though 3 channel slots remain free.
let blocked = registry
.route_nowait(
conn_id,
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xBB]),
route_permit: None,
},
)
.await;
assert!(
matches!(blocked, RouteResult::QueueFullHigh),
"byte budget must reject data before count capacity is exhausted"
);
// Consuming (and dropping) the queued payload releases its RouteBytePermit.
drop(rx.recv().await);
let routed_after_drain = registry
.route_nowait(
conn_id,
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xCC]),
route_permit: None,
},
)
.await;
assert!(
matches!(routed_after_drain, RouteResult::Routed),
"receiving queued data must release byte permits"
);
}
#[tokio::test]
async fn bind_writer_rebinds_conn_atomically() {
let registry = ConnRegistry::new();