diff --git a/Cargo.lock b/Cargo.lock index 49ea79f..9b5ec98 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2780,7 +2780,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" [[package]] name = "telemt" -version = "3.4.0" +version = "3.4.1" dependencies = [ "aes", "anyhow", diff --git a/Cargo.toml b/Cargo.toml index bb45cb9..009f679 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "telemt" -version = "3.4.0" +version = "3.4.1" edition = "2024" [features] @@ -98,4 +98,3 @@ harness = false [profile.release] lto = "fat" codegen-units = 1 - diff --git a/docs/Architecture/Fronting-splitting/TLS_FRONT_PROFILE_FIDELITY.en.md b/docs/Architecture/Fronting-splitting/TLS_FRONT_PROFILE_FIDELITY.en.md new file mode 100644 index 0000000..21d6558 --- /dev/null +++ b/docs/Architecture/Fronting-splitting/TLS_FRONT_PROFILE_FIDELITY.en.md @@ -0,0 +1,225 @@ +# TLS Front Profile Fidelity + +## Overview + +This document describes how Telemt reuses captured TLS behavior in the FakeTLS server flight and how to validate the result on a real deployment. + +When TLS front emulation is enabled, Telemt can capture useful server-side TLS behavior from the selected origin and reuse that behavior in the emulated success path. The goal is not to reproduce the origin byte-for-byte, but to reduce stable synthetic traits and make the emitted server flight structurally closer to the captured profile. + +## Why this change exists + +The project already captures useful server-side TLS behavior in the TLS front fetch path: + +- `change_cipher_spec_count` +- `app_data_record_sizes` +- `ticket_record_sizes` + +Before this change, the emulator used only part of that information. This left a gap between captured origin behavior and emitted FakeTLS server flight. + +## What is implemented + +- The emulator now replays the observed `ChangeCipherSpec` count from the fetched behavior profile. 
+- The emulator now replays observed ticket-like tail ApplicationData record sizes when raw or merged TLS profile data is available. +- The emulator now preserves more of the profiled encrypted-flight structure instead of collapsing it into a smaller synthetic shape. +- The emulator still falls back to the previous synthetic behavior when the cached profile does not contain raw TLS behavior information. +- Operator-configured `tls_new_session_tickets` still works as an additive fallback when the profile does not provide enough tail records. + +## Practical benefit + +- Reduced distinguishability between profiled origin TLS behavior and emulated TLS behavior. +- Lower chance of stable server-flight fingerprints caused by fixed CCS count or synthetic-only tail record sizes. +- Better reuse of already captured TLS profile data without changing MTProto logic, KDF routing, or transport architecture. + +## Limitations + +This mechanism does not aim to make Telemt byte-identical to the origin server. + +It also does not change: + +- MTProto business logic; +- KDF routing behavior; +- the overall transport architecture. + +The practical goal is narrower: + +- reuse more captured profile data; +- reduce fixed synthetic behavior in the server flight; +- preserve a valid FakeTLS success path while changing the emitted shape on the wire. + +## Validation targets + +- Correct count of emulated `ChangeCipherSpec` records. +- Correct replay of observed ticket-tail record sizes. +- No regression in existing ALPN and payload-placement behavior. + +## How to validate the result + +Recommended validation consists of two layers: + +- focused unit and security tests for CCS-count replay and ticket-tail replay; +- real packet-capture comparison for a selected origin and a successful FakeTLS session. 
+ +When testing on the network, the expected result is: + +- a valid FakeTLS and MTProto success path is preserved; +- the early encrypted server flight changes shape when richer profile data is available; +- the change is visible on the wire without changing MTProto logic or transport architecture. + +This validation is intended to show better reuse of captured TLS profile data. +It is not intended to prove byte-level equivalence with the real origin server. + +## How to test on a real deployment + +The strongest practical validation is a side-by-side trace comparison between: + +- a real TLS origin server used as `mask_host`; +- a Telemt FakeTLS success-path connection for the same SNI; +- optional captures from different Telemt builds or configurations. + +The purpose of the comparison is to inspect the shape of the server flight: + +- record order; +- count of `ChangeCipherSpec` records; +- count and grouping of early encrypted `ApplicationData` records; +- lengths of tail or continuation `ApplicationData` records. + +## Recommended environment + +Use a Linux host or Docker container for the cleanest reproduction. + +Recommended setup: + +1. One Telemt instance. +2. One real HTTPS origin as `mask_host`. +3. One Telegram client configured with an `ee` proxy link for the Telemt instance. +4. `tcpdump` or Wireshark available for capture analysis. + +## Step-by-step test procedure + +### 1. Prepare the origin + +1. Choose a real HTTPS origin. +2. Set both `censorship.tls_domain` and `censorship.mask_host` to that hostname. +3. 
Confirm that a direct TLS request works: + +```bash +openssl s_client -connect ORIGIN_IP:443 -servername YOUR_DOMAIN >, pub user_data_quota: std::collections::HashMap, + pub user_rate_limits: std::collections::HashMap, + pub cidr_rate_limits: + std::collections::HashMap, pub user_max_unique_ips: std::collections::HashMap, pub user_max_unique_ips_global_each: usize, pub user_max_unique_ips_mode: crate::config::UserMaxUniqueIpsMode, @@ -245,6 +248,8 @@ impl HotFields { user_max_tcp_conns_global_each: cfg.access.user_max_tcp_conns_global_each, user_expirations: cfg.access.user_expirations.clone(), user_data_quota: cfg.access.user_data_quota.clone(), + user_rate_limits: cfg.access.user_rate_limits.clone(), + cidr_rate_limits: cfg.access.cidr_rate_limits.clone(), user_max_unique_ips: cfg.access.user_max_unique_ips.clone(), user_max_unique_ips_global_each: cfg.access.user_max_unique_ips_global_each, user_max_unique_ips_mode: cfg.access.user_max_unique_ips_mode, @@ -545,6 +550,8 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig { cfg.access.user_max_tcp_conns_global_each = new.access.user_max_tcp_conns_global_each; cfg.access.user_expirations = new.access.user_expirations.clone(); cfg.access.user_data_quota = new.access.user_data_quota.clone(); + cfg.access.user_rate_limits = new.access.user_rate_limits.clone(); + cfg.access.cidr_rate_limits = new.access.cidr_rate_limits.clone(); cfg.access.user_max_unique_ips = new.access.user_max_unique_ips.clone(); cfg.access.user_max_unique_ips_global_each = new.access.user_max_unique_ips_global_each; cfg.access.user_max_unique_ips_mode = new.access.user_max_unique_ips_mode; @@ -1183,6 +1190,18 @@ fn log_changes( new_hot.user_data_quota.len() ); } + if old_hot.user_rate_limits != new_hot.user_rate_limits { + info!( + "config reload: user_rate_limits updated ({} entries)", + new_hot.user_rate_limits.len() + ); + } + if old_hot.cidr_rate_limits != new_hot.cidr_rate_limits { + info!( + "config reload: 
cidr_rate_limits updated ({} entries)", + new_hot.cidr_rate_limits.len() + ); + } if old_hot.user_max_unique_ips != new_hot.user_max_unique_ips { info!( "config reload: user_max_unique_ips updated ({} entries)", diff --git a/src/config/load.rs b/src/config/load.rs index e5c8202..d15773c 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -861,6 +861,22 @@ impl ProxyConfig { )); } + for (user, limit) in &config.access.user_rate_limits { + if limit.up_bps == 0 && limit.down_bps == 0 { + return Err(ProxyError::Config(format!( + "access.user_rate_limits.{user} must set at least one non-zero direction" + ))); + } + } + + for (cidr, limit) in &config.access.cidr_rate_limits { + if limit.up_bps == 0 && limit.down_bps == 0 { + return Err(ProxyError::Config(format!( + "access.cidr_rate_limits.{cidr} must set at least one non-zero direction" + ))); + } + } + if config.general.me_reinit_every_secs == 0 { return Err(ProxyError::Config( "general.me_reinit_every_secs must be > 0".to_string(), diff --git a/src/config/types.rs b/src/config/types.rs index 35b8d46..9f7e0f4 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -1826,6 +1826,21 @@ pub struct AccessConfig { #[serde(default)] pub user_data_quota: HashMap, + /// Per-user transport rate limits in bits-per-second. + /// + /// Each entry supports independent upload (`up_bps`) and download + /// (`down_bps`) ceilings. A value of `0` in one direction means + /// "unlimited" for that direction. + #[serde(default)] + pub user_rate_limits: HashMap, + + /// Per-CIDR aggregate transport rate limits in bits-per-second. + /// + /// Matching uses longest-prefix-wins semantics. A value of `0` in one + /// direction means "unlimited" for that direction. 
+ #[serde(default)] + pub cidr_rate_limits: HashMap, + #[serde(default)] pub user_max_unique_ips: HashMap, @@ -1859,6 +1874,8 @@ impl Default for AccessConfig { user_max_tcp_conns_global_each: default_user_max_tcp_conns_global_each(), user_expirations: HashMap::new(), user_data_quota: HashMap::new(), + user_rate_limits: HashMap::new(), + cidr_rate_limits: HashMap::new(), user_max_unique_ips: HashMap::new(), user_max_unique_ips_global_each: default_user_max_unique_ips_global_each(), user_max_unique_ips_mode: UserMaxUniqueIpsMode::default(), @@ -1870,6 +1887,14 @@ impl Default for AccessConfig { } } +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct RateLimitBps { + #[serde(default)] + pub up_bps: u64, + #[serde(default)] + pub down_bps: u64, +} + // ============= Aux Structures ============= #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs index 6f7de16..66fca0f 100644 --- a/src/daemon/mod.rs +++ b/src/daemon/mod.rs @@ -8,6 +8,7 @@ use std::io::{self, Read, Write}; use std::os::unix::fs::OpenOptionsExt; use std::path::{Path, PathBuf}; +use nix::errno::Errno; use nix::fcntl::{Flock, FlockArg}; use nix::unistd::{self, ForkResult, Gid, Pid, Uid, chdir, close, fork, getpid, setsid}; use tracing::{debug, info, warn}; @@ -157,15 +158,15 @@ fn redirect_stdio_to_devnull() -> Result<(), DaemonError> { unsafe { // Redirect stdin (fd 0) if libc::dup2(devnull_fd, 0) < 0 { - return Err(DaemonError::RedirectFailed(nix::errno::Errno::last())); + return Err(DaemonError::RedirectFailed(Errno::last())); } // Redirect stdout (fd 1) if libc::dup2(devnull_fd, 1) < 0 { - return Err(DaemonError::RedirectFailed(nix::errno::Errno::last())); + return Err(DaemonError::RedirectFailed(Errno::last())); } // Redirect stderr (fd 2) if libc::dup2(devnull_fd, 2) < 0 { - return Err(DaemonError::RedirectFailed(nix::errno::Errno::last())); + return Err(DaemonError::RedirectFailed(Errno::last())); } } @@ 
-337,6 +338,27 @@ fn is_process_running(pid: i32) -> bool { nix::sys::signal::kill(Pid::from_raw(pid), None).is_ok() } +// macOS gates nix::unistd::setgroups differently in the current dependency set, +// so call libc directly there while preserving the original nix path elsewhere. +fn set_supplementary_groups(gid: Gid) -> Result<(), nix::Error> { + #[cfg(target_os = "macos")] + { + let groups = [gid.as_raw()]; + let rc = unsafe { + libc::setgroups( + i32::try_from(groups.len()).expect("single supplementary group must fit in c_int"), + groups.as_ptr(), + ) + }; + if rc == 0 { Ok(()) } else { Err(Errno::last()) } + } + + #[cfg(not(target_os = "macos"))] + { + unistd::setgroups(&[gid]) + } +} + /// Drops privileges to the specified user and group. /// /// This should be called after binding privileged ports but before entering @@ -368,7 +390,7 @@ pub fn drop_privileges( if let Some(gid) = target_gid { unistd::setgid(gid).map_err(DaemonError::PrivilegeDrop)?; - unistd::setgroups(&[gid]).map_err(DaemonError::PrivilegeDrop)?; + set_supplementary_groups(gid).map_err(DaemonError::PrivilegeDrop)?; info!(gid = gid.as_raw(), "Dropped group privileges"); } diff --git a/src/maestro/mod.rs b/src/maestro/mod.rs index f141331..4c5b98a 100644 --- a/src/maestro/mod.rs +++ b/src/maestro/mod.rs @@ -664,6 +664,11 @@ async fn run_telemt_core( )); let buffer_pool = Arc::new(BufferPool::with_config(64 * 1024, 4096)); + let shared_state = ProxySharedState::new(); + shared_state.traffic_limiter.apply_policy( + config.access.user_rate_limits.clone(), + config.access.cidr_rate_limits.clone(), + ); connectivity::run_startup_connectivity( &config, @@ -695,6 +700,7 @@ async fn run_telemt_core( beobachten.clone(), api_config_tx.clone(), me_pool.clone(), + shared_state.clone(), ) .await; let config_rx = runtime_watches.config_rx; @@ -711,7 +717,6 @@ async fn run_telemt_core( ) .await; let _admission_tx_hold = admission_tx; - let shared_state = ProxySharedState::new(); 
conntrack_control::spawn_conntrack_controller( config_rx.clone(), stats.clone(), diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index 5b3f2e0..da059bd 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -51,6 +51,7 @@ pub(crate) async fn spawn_runtime_tasks( beobachten: Arc, api_config_tx: watch::Sender>, me_pool_for_policy: Option>, + shared_state: Arc, ) -> RuntimeWatches { let um_clone = upstream_manager.clone(); let dc_overrides_for_health = config.dc_overrides.clone(); @@ -182,6 +183,41 @@ pub(crate) async fn spawn_runtime_tasks( } }); + let limiter = shared_state.traffic_limiter.clone(); + limiter.apply_policy( + config.access.user_rate_limits.clone(), + config.access.cidr_rate_limits.clone(), + ); + let mut config_rx_rate_limits = config_rx.clone(); + tokio::spawn(async move { + let mut prev_user_limits = config_rx_rate_limits + .borrow() + .access + .user_rate_limits + .clone(); + let mut prev_cidr_limits = config_rx_rate_limits + .borrow() + .access + .cidr_rate_limits + .clone(); + loop { + if config_rx_rate_limits.changed().await.is_err() { + break; + } + let cfg = config_rx_rate_limits.borrow_and_update().clone(); + if prev_user_limits != cfg.access.user_rate_limits + || prev_cidr_limits != cfg.access.cidr_rate_limits + { + limiter.apply_policy( + cfg.access.user_rate_limits.clone(), + cfg.access.cidr_rate_limits.clone(), + ); + prev_user_limits = cfg.access.user_rate_limits.clone(); + prev_cidr_limits = cfg.access.cidr_rate_limits.clone(); + } + } + }); + let beobachten_writer = beobachten.clone(); let config_rx_beobachten = config_rx.clone(); tokio::spawn(async move { diff --git a/src/metrics.rs b/src/metrics.rs index 1b920a8..ba44a5f 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -575,6 +575,139 @@ async fn render_metrics( } ); + let limiter_metrics = shared_state.traffic_limiter.metrics_snapshot(); + let _ = writeln!( + out, + "# HELP telemt_rate_limiter_throttle_total Traffic limiter 
throttle events by scope and direction" + ); + let _ = writeln!(out, "# TYPE telemt_rate_limiter_throttle_total counter"); + let _ = writeln!( + out, + "telemt_rate_limiter_throttle_total{{scope=\"user\",direction=\"up\"}} {}", + if core_enabled { + limiter_metrics.user_throttle_up_total + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_rate_limiter_throttle_total{{scope=\"user\",direction=\"down\"}} {}", + if core_enabled { + limiter_metrics.user_throttle_down_total + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_rate_limiter_throttle_total{{scope=\"cidr\",direction=\"up\"}} {}", + if core_enabled { + limiter_metrics.cidr_throttle_up_total + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_rate_limiter_throttle_total{{scope=\"cidr\",direction=\"down\"}} {}", + if core_enabled { + limiter_metrics.cidr_throttle_down_total + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_rate_limiter_wait_ms_total Traffic limiter accumulated wait time in milliseconds by scope and direction" + ); + let _ = writeln!(out, "# TYPE telemt_rate_limiter_wait_ms_total counter"); + let _ = writeln!( + out, + "telemt_rate_limiter_wait_ms_total{{scope=\"user\",direction=\"up\"}} {}", + if core_enabled { + limiter_metrics.user_wait_up_ms_total + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_rate_limiter_wait_ms_total{{scope=\"user\",direction=\"down\"}} {}", + if core_enabled { + limiter_metrics.user_wait_down_ms_total + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_rate_limiter_wait_ms_total{{scope=\"cidr\",direction=\"up\"}} {}", + if core_enabled { + limiter_metrics.cidr_wait_up_ms_total + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_rate_limiter_wait_ms_total{{scope=\"cidr\",direction=\"down\"}} {}", + if core_enabled { + limiter_metrics.cidr_wait_down_ms_total + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_rate_limiter_active_leases Active relay leases under rate limiting by 
scope" + ); + let _ = writeln!(out, "# TYPE telemt_rate_limiter_active_leases gauge"); + let _ = writeln!( + out, + "telemt_rate_limiter_active_leases{{scope=\"user\"}} {}", + if core_enabled { + limiter_metrics.user_active_leases + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_rate_limiter_active_leases{{scope=\"cidr\"}} {}", + if core_enabled { + limiter_metrics.cidr_active_leases + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_rate_limiter_policy_entries Active rate-limit policy entries by scope" + ); + let _ = writeln!(out, "# TYPE telemt_rate_limiter_policy_entries gauge"); + let _ = writeln!( + out, + "telemt_rate_limiter_policy_entries{{scope=\"user\"}} {}", + if core_enabled { + limiter_metrics.user_policy_entries + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_rate_limiter_policy_entries{{scope=\"cidr\"}} {}", + if core_enabled { + limiter_metrics.cidr_policy_entries + } else { + 0 + } + ); + let _ = writeln!( out, "# HELP telemt_upstream_connect_attempt_total Upstream connect attempts across all requests" @@ -1177,6 +1310,143 @@ async fn render_metrics( 0 } ); + let _ = writeln!( + out, + "# HELP telemt_me_fair_pressure_state Worker-local fairness pressure state" + ); + let _ = writeln!(out, "# TYPE telemt_me_fair_pressure_state gauge"); + let _ = writeln!( + out, + "telemt_me_fair_pressure_state {}", + if me_allows_normal { + stats.get_me_fair_pressure_state_gauge() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_fair_active_flows Fair-scheduler active flow count" + ); + let _ = writeln!(out, "# TYPE telemt_me_fair_active_flows gauge"); + let _ = writeln!( + out, + "telemt_me_fair_active_flows {}", + if me_allows_normal { + stats.get_me_fair_active_flows_gauge() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_fair_queued_bytes Fair-scheduler queued bytes" + ); + let _ = writeln!(out, "# TYPE telemt_me_fair_queued_bytes gauge"); + let _ = writeln!( + out, + 
"telemt_me_fair_queued_bytes {}", + if me_allows_normal { + stats.get_me_fair_queued_bytes_gauge() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_fair_flow_state_gauge Fair-scheduler flow health classes" + ); + let _ = writeln!(out, "# TYPE telemt_me_fair_flow_state_gauge gauge"); + let _ = writeln!( + out, + "telemt_me_fair_flow_state_gauge{{class=\"standing\"}} {}", + if me_allows_normal { + stats.get_me_fair_standing_flows_gauge() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_fair_flow_state_gauge{{class=\"backpressured\"}} {}", + if me_allows_normal { + stats.get_me_fair_backpressured_flows_gauge() + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_me_fair_events_total Fair-scheduler event counters" + ); + let _ = writeln!(out, "# TYPE telemt_me_fair_events_total counter"); + let _ = writeln!( + out, + "telemt_me_fair_events_total{{event=\"scheduler_round\"}} {}", + if me_allows_normal { + stats.get_me_fair_scheduler_rounds_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_fair_events_total{{event=\"deficit_grant\"}} {}", + if me_allows_normal { + stats.get_me_fair_deficit_grants_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_fair_events_total{{event=\"deficit_skip\"}} {}", + if me_allows_normal { + stats.get_me_fair_deficit_skips_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_fair_events_total{{event=\"enqueue_reject\"}} {}", + if me_allows_normal { + stats.get_me_fair_enqueue_rejects_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_fair_events_total{{event=\"shed_drop\"}} {}", + if me_allows_normal { + stats.get_me_fair_shed_drops_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_me_fair_events_total{{event=\"penalty\"}} {}", + if me_allows_normal { + stats.get_me_fair_penalties_total() + } else { + 0 + } + ); + let _ = writeln!( + out, + 
"telemt_me_fair_events_total{{event=\"downstream_stall\"}} {}", + if me_allows_normal { + stats.get_me_fair_downstream_stalls_total() + } else { + 0 + } + ); let _ = writeln!( out, diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 2c4fe45..6bd2101 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -316,6 +316,9 @@ where stats.increment_user_connects(user); let _direct_connection_lease = stats.acquire_direct_connection_lease(); + let traffic_lease = shared + .traffic_limiter + .acquire_lease(user, success.peer.ip()); let buffer_pool_trim = Arc::clone(&buffer_pool); let relay_activity_timeout = if shared.conntrack_pressure_active() { @@ -329,7 +332,7 @@ where } else { Duration::from_secs(1800) }; - let relay_result = crate::proxy::relay::relay_bidirectional_with_activity_timeout( + let relay_result = crate::proxy::relay::relay_bidirectional_with_activity_timeout_and_lease( client_reader, client_writer, tg_reader, @@ -340,6 +343,7 @@ where Arc::clone(&stats), config.access.user_data_quota.get(user).copied(), buffer_pool, + traffic_lease, relay_activity_timeout, ); tokio::pin!(relay_result); diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index eb68f83..aff2eb7 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -28,6 +28,7 @@ use crate::proxy::route_mode::{ use crate::proxy::shared_state::{ ConntrackCloseEvent, ConntrackClosePublishResult, ConntrackCloseReason, ProxySharedState, }; +use crate::proxy::traffic_limiter::{RateDirection, TrafficLease, next_refill_delay}; use crate::stats::{ MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, QuotaReserveError, Stats, UserStats, }; @@ -286,6 +287,10 @@ impl RelayClientIdleState { self.last_client_frame_at = now; self.soft_idle_marked = false; } + + fn on_client_tiny_frame(&mut self, now: Instant) { + self.last_client_frame_at = now; + } } impl MeD2cFlushPolicy { @@ -595,6 +600,41 @@ async fn reserve_user_quota_with_yield( } } +async 
fn wait_for_traffic_budget( + lease: Option<&Arc>, + direction: RateDirection, + bytes: u64, +) { + if bytes == 0 { + return; + } + let Some(lease) = lease else { + return; + }; + + let mut remaining = bytes; + while remaining > 0 { + let consume = lease.try_consume(direction, remaining); + if consume.granted > 0 { + remaining = remaining.saturating_sub(consume.granted); + continue; + } + + let wait_started_at = Instant::now(); + tokio::time::sleep(next_refill_delay()).await; + let wait_ms = wait_started_at + .elapsed() + .as_millis() + .min(u128::from(u64::MAX)) as u64; + lease.observe_wait_ms( + direction, + consume.blocked_user, + consume.blocked_cidr, + wait_ms, + ); + } +} + fn classify_me_d2c_flush_reason( flush_immediately: bool, batch_frames: usize, @@ -985,6 +1025,7 @@ where let quota_limit = config.access.user_data_quota.get(&user).copied(); let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user)); let peer = success.peer; + let traffic_lease = shared.traffic_limiter.acquire_lease(&user, peer.ip()); let proto_tag = success.proto_tag; let pool_generation = me_pool.current_generation(); @@ -1120,6 +1161,7 @@ where let rng_clone = rng.clone(); let user_clone = user.clone(); let quota_user_stats_me_writer = quota_user_stats.clone(); + let traffic_lease_me_writer = traffic_lease.clone(); let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone(); let bytes_me2c_clone = bytes_me2c.clone(); let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config); @@ -1153,7 +1195,7 @@ where let first_is_downstream_activity = matches!(&first, MeResponse::Data { .. 
} | MeResponse::Ack(_)); - match process_me_writer_response( + match process_me_writer_response_with_traffic_lease( first, &mut writer, proto_tag, @@ -1164,6 +1206,7 @@ where quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, + traffic_lease_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1213,7 +1256,7 @@ where let next_is_downstream_activity = matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response( + match process_me_writer_response_with_traffic_lease( next, &mut writer, proto_tag, @@ -1224,6 +1267,7 @@ where quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, + traffic_lease_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1276,7 +1320,7 @@ where Ok(Some(next)) => { let next_is_downstream_activity = matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response( + match process_me_writer_response_with_traffic_lease( next, &mut writer, proto_tag, @@ -1287,6 +1331,7 @@ where quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, + traffic_lease_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1341,7 +1386,7 @@ where let extra_is_downstream_activity = matches!(&extra, MeResponse::Data { .. 
} | MeResponse::Ack(_)); - match process_me_writer_response( + match process_me_writer_response_with_traffic_lease( extra, &mut writer, proto_tag, @@ -1352,6 +1397,7 @@ where quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, + traffic_lease_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1542,6 +1588,12 @@ where match payload_result { Ok(Some((payload, quickack))) => { trace!(conn_id, bytes = payload.len(), "C->ME frame"); + wait_for_traffic_budget( + traffic_lease.as_ref(), + RateDirection::Up, + payload.len() as u64, + ) + .await; forensics.bytes_c2me = forensics .bytes_c2me .saturating_add(payload.len() as u64); @@ -1762,40 +1814,6 @@ where let downstream_ms = last_downstream_activity_ms.load(Ordering::Relaxed); let hard_deadline = hard_deadline(idle_policy, idle_state, session_started_at, downstream_ms); - if now >= hard_deadline { - clear_relay_idle_candidate_in(shared, forensics.conn_id); - stats.increment_relay_idle_hard_close_total(); - let client_idle_secs = now - .saturating_duration_since(idle_state.last_client_frame_at) - .as_secs(); - let downstream_idle_secs = now - .saturating_duration_since( - session_started_at + Duration::from_millis(downstream_ms), - ) - .as_secs(); - warn!( - trace_id = format_args!("0x{:016x}", forensics.trace_id), - conn_id = forensics.conn_id, - user = %forensics.user, - read_label, - client_idle_secs, - downstream_idle_secs, - soft_idle_secs = idle_policy.soft_idle.as_secs(), - hard_idle_secs = idle_policy.hard_idle.as_secs(), - grace_secs = idle_policy.grace_after_downstream_activity.as_secs(), - "Middle-relay hard idle close" - ); - return Err(ProxyError::Io(std::io::Error::new( - std::io::ErrorKind::TimedOut, - format!( - "middle-relay hard idle timeout while reading {read_label}: client_idle_secs={client_idle_secs}, downstream_idle_secs={downstream_idle_secs}, soft_idle_secs={}, hard_idle_secs={}, grace_secs={}", - 
idle_policy.soft_idle.as_secs(), - idle_policy.hard_idle.as_secs(), - idle_policy.grace_after_downstream_activity.as_secs(), - ), - ))); - } - if !idle_state.soft_idle_marked && now.saturating_duration_since(idle_state.last_client_frame_at) >= idle_policy.soft_idle @@ -1850,7 +1868,45 @@ where ), ))); } - Err(_) => {} + Err(_) => { + let now = Instant::now(); + let downstream_ms = last_downstream_activity_ms.load(Ordering::Relaxed); + let hard_deadline = + hard_deadline(idle_policy, idle_state, session_started_at, downstream_ms); + if now >= hard_deadline { + clear_relay_idle_candidate_in(shared, forensics.conn_id); + stats.increment_relay_idle_hard_close_total(); + let client_idle_secs = now + .saturating_duration_since(idle_state.last_client_frame_at) + .as_secs(); + let downstream_idle_secs = now + .saturating_duration_since( + session_started_at + Duration::from_millis(downstream_ms), + ) + .as_secs(); + warn!( + trace_id = format_args!("0x{:016x}", forensics.trace_id), + conn_id = forensics.conn_id, + user = %forensics.user, + read_label, + client_idle_secs, + downstream_idle_secs, + soft_idle_secs = idle_policy.soft_idle.as_secs(), + hard_idle_secs = idle_policy.hard_idle.as_secs(), + grace_secs = idle_policy.grace_after_downstream_activity.as_secs(), + "Middle-relay hard idle close" + ); + return Err(ProxyError::Io(std::io::Error::new( + std::io::ErrorKind::TimedOut, + format!( + "middle-relay hard idle timeout while reading {read_label}: client_idle_secs={client_idle_secs}, downstream_idle_secs={downstream_idle_secs}, soft_idle_secs={}, hard_idle_secs={}, grace_secs={}", + idle_policy.soft_idle.as_secs(), + idle_policy.hard_idle.as_secs(), + idle_policy.grace_after_downstream_activity.as_secs(), + ), + ))); + } + } } } @@ -1941,6 +1997,7 @@ where }; if len == 0 { + idle_state.on_client_tiny_frame(Instant::now()); idle_state.tiny_frame_debt = idle_state .tiny_frame_debt .saturating_add(TINY_FRAME_DEBT_PER_TINY); @@ -2160,6 +2217,46 @@ async fn 
process_me_writer_response( ack_flush_immediate: bool, batched: bool, ) -> Result +where + W: AsyncWrite + Unpin + Send + 'static, +{ + process_me_writer_response_with_traffic_lease( + response, + client_writer, + proto_tag, + rng, + frame_buf, + stats, + user, + quota_user_stats, + quota_limit, + quota_soft_overshoot_bytes, + None, + bytes_me2c, + conn_id, + ack_flush_immediate, + batched, + ) + .await +} + +async fn process_me_writer_response_with_traffic_lease( + response: MeResponse, + client_writer: &mut CryptoWriter, + proto_tag: ProtoTag, + rng: &SecureRandom, + frame_buf: &mut Vec, + stats: &Stats, + user: &str, + quota_user_stats: Option<&UserStats>, + quota_limit: Option, + quota_soft_overshoot_bytes: u64, + traffic_lease: Option<&Arc>, + bytes_me2c: &AtomicU64, + conn_id: u64, + ack_flush_immediate: bool, + batched: bool, +) -> Result where W: AsyncWrite + Unpin + Send + 'static, { @@ -2183,6 +2280,7 @@ where }); } } + wait_for_traffic_budget(traffic_lease, RateDirection::Down, data_len).await; let write_mode = match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) @@ -2220,6 +2318,7 @@ where } else { trace!(conn_id, confirm, "ME->C quickack"); } + wait_for_traffic_budget(traffic_lease, RateDirection::Down, 4).await; write_client_ack(client_writer, proto_tag, confirm).await?; stats.increment_me_d2c_ack_frames_total(); diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index c4ce09c..4e1827e 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -68,6 +68,7 @@ pub mod relay; pub mod route_mode; pub mod session_eviction; pub mod shared_state; +pub mod traffic_limiter; pub use client::ClientHandler; #[allow(unused_imports)] diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index f612cb1..c9e6a98 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -52,6 +52,7 @@ //! 
- `SharedCounters` (atomics) let the watchdog read stats without locking use crate::error::{ProxyError, Result}; +use crate::proxy::traffic_limiter::{RateDirection, TrafficLease, next_refill_delay}; use crate::stats::{Stats, UserStats}; use crate::stream::BufferPool; use std::io; @@ -61,7 +62,7 @@ use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::task::{Context, Poll}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes}; -use tokio::time::Instant; +use tokio::time::{Instant, Sleep}; use tracing::{debug, trace, warn}; // ============= Constants ============= @@ -210,12 +211,24 @@ struct StatsIo { stats: Arc, user: String, user_stats: Arc, + traffic_lease: Option>, + c2s_rate_debt_bytes: u64, + c2s_wait: RateWaitState, + s2c_wait: RateWaitState, quota_limit: Option, quota_exceeded: Arc, quota_bytes_since_check: u64, epoch: Instant, } +#[derive(Default)] +struct RateWaitState { + sleep: Option>>, + started_at: Option, + blocked_user: bool, + blocked_cidr: bool, +} + impl StatsIo { fn new( inner: S, @@ -225,6 +238,28 @@ impl StatsIo { quota_limit: Option, quota_exceeded: Arc, epoch: Instant, + ) -> Self { + Self::new_with_traffic_lease( + inner, + counters, + stats, + user, + None, + quota_limit, + quota_exceeded, + epoch, + ) + } + + fn new_with_traffic_lease( + inner: S, + counters: Arc, + stats: Arc, + user: String, + traffic_lease: Option>, + quota_limit: Option, + quota_exceeded: Arc, + epoch: Instant, ) -> Self { // Mark initial activity so the watchdog doesn't fire before data flows counters.touch(Instant::now(), epoch); @@ -235,12 +270,88 @@ impl StatsIo { stats, user, user_stats, + traffic_lease, + c2s_rate_debt_bytes: 0, + c2s_wait: RateWaitState::default(), + s2c_wait: RateWaitState::default(), quota_limit, quota_exceeded, quota_bytes_since_check: 0, epoch, } } + + fn record_wait( + wait: &mut RateWaitState, + lease: Option<&Arc>, + direction: RateDirection, + ) { + let 
Some(started_at) = wait.started_at.take() else { + return; + }; + let wait_ms = started_at.elapsed().as_millis().min(u128::from(u64::MAX)) as u64; + if let Some(lease) = lease { + lease.observe_wait_ms(direction, wait.blocked_user, wait.blocked_cidr, wait_ms); + } + wait.blocked_user = false; + wait.blocked_cidr = false; + } + + fn arm_wait(wait: &mut RateWaitState, blocked_user: bool, blocked_cidr: bool) { + if wait.sleep.is_none() { + wait.sleep = Some(Box::pin(tokio::time::sleep(next_refill_delay()))); + wait.started_at = Some(Instant::now()); + } + wait.blocked_user |= blocked_user; + wait.blocked_cidr |= blocked_cidr; + } + + fn poll_wait( + wait: &mut RateWaitState, + cx: &mut Context<'_>, + lease: Option<&Arc>, + direction: RateDirection, + ) -> Poll<()> { + let Some(sleep) = wait.sleep.as_mut() else { + return Poll::Ready(()); + }; + if sleep.as_mut().poll(cx).is_pending() { + return Poll::Pending; + } + wait.sleep = None; + Self::record_wait(wait, lease, direction); + Poll::Ready(()) + } + + fn settle_c2s_rate_debt(&mut self, cx: &mut Context<'_>) -> Poll<()> { + let Some(lease) = self.traffic_lease.as_ref() else { + self.c2s_rate_debt_bytes = 0; + return Poll::Ready(()); + }; + + while self.c2s_rate_debt_bytes > 0 { + let consume = lease.try_consume(RateDirection::Up, self.c2s_rate_debt_bytes); + if consume.granted > 0 { + self.c2s_rate_debt_bytes = self.c2s_rate_debt_bytes.saturating_sub(consume.granted); + continue; + } + Self::arm_wait( + &mut self.c2s_wait, + consume.blocked_user, + consume.blocked_cidr, + ); + if Self::poll_wait(&mut self.c2s_wait, cx, Some(lease), RateDirection::Up).is_pending() + { + return Poll::Pending; + } + } + + if Self::poll_wait(&mut self.c2s_wait, cx, Some(lease), RateDirection::Up).is_pending() { + return Poll::Pending; + } + + Poll::Ready(()) + } } #[derive(Debug)] @@ -286,6 +397,25 @@ fn should_immediate_quota_check(remaining_before: u64, charge_bytes: u64) -> boo remaining_before <= QUOTA_NEAR_LIMIT_BYTES || 
charge_bytes >= QUOTA_LARGE_CHARGE_BYTES } +fn refund_reserved_quota_bytes(user_stats: &UserStats, reserved_bytes: u64) { + if reserved_bytes == 0 { + return; + } + let mut current = user_stats.quota_used.load(Ordering::Relaxed); + loop { + let next = current.saturating_sub(reserved_bytes); + match user_stats.quota_used.compare_exchange_weak( + current, + next, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => return, + Err(observed) => current = observed, + } + } +} + impl AsyncRead for StatsIo { fn poll_read( self: Pin<&mut Self>, @@ -296,6 +426,9 @@ impl AsyncRead for StatsIo { if this.quota_exceeded.load(Ordering::Acquire) { return Poll::Ready(Err(quota_io_error())); } + if this.settle_c2s_rate_debt(cx).is_pending() { + return Poll::Pending; + } let mut remaining_before = None; if let Some(limit) = this.quota_limit { @@ -377,6 +510,11 @@ impl AsyncRead for StatsIo { .add_user_octets_from_handle(this.user_stats.as_ref(), n_to_charge); this.stats .increment_user_msgs_from_handle(this.user_stats.as_ref()); + if this.traffic_lease.is_some() { + this.c2s_rate_debt_bytes = + this.c2s_rate_debt_bytes.saturating_add(n_to_charge); + let _ = this.settle_c2s_rate_debt(cx); + } trace!(user = %this.user, bytes = n, "C->S"); } @@ -398,28 +536,66 @@ impl AsyncWrite for StatsIo { return Poll::Ready(Err(quota_io_error())); } + let mut shaper_reserved_bytes = 0u64; + let mut write_buf = buf; + if let Some(lease) = this.traffic_lease.as_ref() { + if !buf.is_empty() { + loop { + let consume = lease.try_consume(RateDirection::Down, buf.len() as u64); + if consume.granted > 0 { + shaper_reserved_bytes = consume.granted; + if consume.granted < buf.len() as u64 { + write_buf = &buf[..consume.granted as usize]; + } + let _ = Self::poll_wait( + &mut this.s2c_wait, + cx, + Some(lease), + RateDirection::Down, + ); + break; + } + + Self::arm_wait( + &mut this.s2c_wait, + consume.blocked_user, + consume.blocked_cidr, + ); + if Self::poll_wait(&mut this.s2c_wait, cx, Some(lease), 
RateDirection::Down) + .is_pending() + { + return Poll::Pending; + } + } + } else { + let _ = Self::poll_wait(&mut this.s2c_wait, cx, Some(lease), RateDirection::Down); + } + } + let mut remaining_before = None; let mut reserved_bytes = 0u64; - let mut write_buf = buf; if let Some(limit) = this.quota_limit { - if !buf.is_empty() { + if !write_buf.is_empty() { let mut reserve_rounds = 0usize; while reserved_bytes == 0 { let used_before = this.user_stats.quota_used(); let remaining = limit.saturating_sub(used_before); if remaining == 0 { + if let Some(lease) = this.traffic_lease.as_ref() { + lease.refund(RateDirection::Down, shaper_reserved_bytes); + } this.quota_exceeded.store(true, Ordering::Release); return Poll::Ready(Err(quota_io_error())); } remaining_before = Some(remaining); - let desired = remaining.min(buf.len() as u64); + let desired = remaining.min(write_buf.len() as u64); let mut saw_contention = false; for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { match this.user_stats.quota_try_reserve(desired, limit) { Ok(_) => { reserved_bytes = desired; - write_buf = &buf[..desired as usize]; + write_buf = &write_buf[..desired as usize]; break; } Err(crate::stats::QuotaReserveError::LimitExceeded) => { @@ -434,6 +610,9 @@ impl AsyncWrite for StatsIo { if reserved_bytes == 0 { reserve_rounds = reserve_rounds.saturating_add(1); if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS { + if let Some(lease) = this.traffic_lease.as_ref() { + lease.refund(RateDirection::Down, shaper_reserved_bytes); + } this.quota_exceeded.store(true, Ordering::Release); return Poll::Ready(Err(quota_io_error())); } @@ -446,6 +625,9 @@ impl AsyncWrite for StatsIo { let used_before = this.user_stats.quota_used(); let remaining = limit.saturating_sub(used_before); if remaining == 0 { + if let Some(lease) = this.traffic_lease.as_ref() { + lease.refund(RateDirection::Down, shaper_reserved_bytes); + } this.quota_exceeded.store(true, Ordering::Release); return Poll::Ready(Err(quota_io_error())); } @@ -456,23 
+638,20 @@ impl AsyncWrite for StatsIo { match Pin::new(&mut this.inner).poll_write(cx, write_buf) { Poll::Ready(Ok(n)) => { if reserved_bytes > n as u64 { - let refund = reserved_bytes - n as u64; - let mut current = this.user_stats.quota_used.load(Ordering::Relaxed); - loop { - let next = current.saturating_sub(refund); - match this.user_stats.quota_used.compare_exchange_weak( - current, - next, - Ordering::Relaxed, - Ordering::Relaxed, - ) { - Ok(_) => break, - Err(observed) => current = observed, - } - } + refund_reserved_quota_bytes( + this.user_stats.as_ref(), + reserved_bytes - n as u64, + ); + } + if shaper_reserved_bytes > n as u64 + && let Some(lease) = this.traffic_lease.as_ref() + { + lease.refund(RateDirection::Down, shaper_reserved_bytes - n as u64); } - if n > 0 { + if let Some(lease) = this.traffic_lease.as_ref() { + Self::record_wait(&mut this.s2c_wait, Some(lease), RateDirection::Down); + } let n_to_charge = n as u64; // S→C: data written to client @@ -512,37 +691,23 @@ impl AsyncWrite for StatsIo { } Poll::Ready(Err(err)) => { if reserved_bytes > 0 { - let mut current = this.user_stats.quota_used.load(Ordering::Relaxed); - loop { - let next = current.saturating_sub(reserved_bytes); - match this.user_stats.quota_used.compare_exchange_weak( - current, - next, - Ordering::Relaxed, - Ordering::Relaxed, - ) { - Ok(_) => break, - Err(observed) => current = observed, - } - } + refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes); + } + if shaper_reserved_bytes > 0 + && let Some(lease) = this.traffic_lease.as_ref() + { + lease.refund(RateDirection::Down, shaper_reserved_bytes); } Poll::Ready(Err(err)) } Poll::Pending => { if reserved_bytes > 0 { - let mut current = this.user_stats.quota_used.load(Ordering::Relaxed); - loop { - let next = current.saturating_sub(reserved_bytes); - match this.user_stats.quota_used.compare_exchange_weak( - current, - next, - Ordering::Relaxed, - Ordering::Relaxed, - ) { - Ok(_) => break, - Err(observed) => 
current = observed, - } - } + refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes); + } + if shaper_reserved_bytes > 0 + && let Some(lease) = this.traffic_lease.as_ref() + { + lease.refund(RateDirection::Down, shaper_reserved_bytes); } Poll::Pending } @@ -627,6 +792,43 @@ pub async fn relay_bidirectional_with_activity_timeout( _buffer_pool: Arc, activity_timeout: Duration, ) -> Result<()> +where + CR: AsyncRead + Unpin + Send + 'static, + CW: AsyncWrite + Unpin + Send + 'static, + SR: AsyncRead + Unpin + Send + 'static, + SW: AsyncWrite + Unpin + Send + 'static, +{ + relay_bidirectional_with_activity_timeout_and_lease( + client_reader, + client_writer, + server_reader, + server_writer, + c2s_buf_size, + s2c_buf_size, + user, + stats, + quota_limit, + _buffer_pool, + None, + activity_timeout, + ) + .await +} + +pub async fn relay_bidirectional_with_activity_timeout_and_lease( + client_reader: CR, + client_writer: CW, + server_reader: SR, + server_writer: SW, + c2s_buf_size: usize, + s2c_buf_size: usize, + user: &str, + stats: Arc, + quota_limit: Option, + _buffer_pool: Arc, + traffic_lease: Option>, + activity_timeout: Duration, +) -> Result<()> where CR: AsyncRead + Unpin + Send + 'static, CW: AsyncWrite + Unpin + Send + 'static, @@ -644,11 +846,12 @@ where let mut server = CombinedStream::new(server_reader, server_writer); // Wrap client with stats/activity tracking - let mut client = StatsIo::new( + let mut client = StatsIo::new_with_traffic_lease( client_combined, Arc::clone(&counters), Arc::clone(&stats), user_owned.clone(), + traffic_lease, quota_limit, Arc::clone("a_exceeded), epoch, diff --git a/src/proxy/shared_state.rs b/src/proxy/shared_state.rs index 4fef497..e204890 100644 --- a/src/proxy/shared_state.rs +++ b/src/proxy/shared_state.rs @@ -10,6 +10,7 @@ use tokio::sync::mpsc; use crate::proxy::handshake::{AuthProbeSaturationState, AuthProbeState}; use crate::proxy::middle_relay::{DesyncDedupRotationState, RelayIdleCandidateRegistry}; 
+use crate::proxy::traffic_limiter::TrafficLimiter; const HANDSHAKE_RECENT_USER_RING_LEN: usize = 64; @@ -65,6 +66,7 @@ pub(crate) struct MiddleRelaySharedState { pub(crate) struct ProxySharedState { pub(crate) handshake: HandshakeSharedState, pub(crate) middle_relay: MiddleRelaySharedState, + pub(crate) traffic_limiter: Arc, pub(crate) conntrack_pressure_active: AtomicBool, pub(crate) conntrack_close_tx: Mutex>>, } @@ -98,6 +100,7 @@ impl ProxySharedState { relay_idle_registry: Mutex::new(RelayIdleCandidateRegistry::default()), relay_idle_mark_seq: AtomicU64::new(0), }, + traffic_limiter: TrafficLimiter::new(), conntrack_pressure_active: AtomicBool::new(false), conntrack_close_tx: Mutex::new(None), }) diff --git a/src/proxy/traffic_limiter.rs b/src/proxy/traffic_limiter.rs new file mode 100644 index 0000000..0edfa0f --- /dev/null +++ b/src/proxy/traffic_limiter.rs @@ -0,0 +1,853 @@ +use std::collections::{HashMap, HashSet}; +use std::hash::{Hash, Hasher}; +use std::net::IpAddr; +use std::sync::Arc; +use std::sync::OnceLock; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; + +use arc_swap::ArcSwap; +use dashmap::DashMap; +use ipnetwork::IpNetwork; + +use crate::config::RateLimitBps; + +const REGISTRY_SHARDS: usize = 64; +const FAIR_EPOCH_MS: u64 = 20; +const MAX_BORROW_CHUNK_BYTES: u64 = 32 * 1024; +const CLEANUP_INTERVAL_SECS: u64 = 60; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RateDirection { + Up, + Down, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TrafficConsumeResult { + pub granted: u64, + pub blocked_user: bool, + pub blocked_cidr: bool, +} + +#[derive(Debug, Clone, Copy)] +pub struct TrafficLimiterMetricsSnapshot { + pub user_throttle_up_total: u64, + pub user_throttle_down_total: u64, + pub cidr_throttle_up_total: u64, + pub cidr_throttle_down_total: u64, + pub user_wait_up_ms_total: u64, + pub user_wait_down_ms_total: u64, + pub cidr_wait_up_ms_total: u64, + pub 
cidr_wait_down_ms_total: u64, + pub user_active_leases: u64, + pub cidr_active_leases: u64, + pub user_policy_entries: u64, + pub cidr_policy_entries: u64, +} + +#[derive(Default)] +struct ScopeMetrics { + throttle_up_total: AtomicU64, + throttle_down_total: AtomicU64, + wait_up_ms_total: AtomicU64, + wait_down_ms_total: AtomicU64, + active_leases: AtomicU64, + policy_entries: AtomicU64, +} + +impl ScopeMetrics { + fn throttle(&self, direction: RateDirection) { + match direction { + RateDirection::Up => { + self.throttle_up_total.fetch_add(1, Ordering::Relaxed); + } + RateDirection::Down => { + self.throttle_down_total.fetch_add(1, Ordering::Relaxed); + } + } + } + + fn wait_ms(&self, direction: RateDirection, wait_ms: u64) { + match direction { + RateDirection::Up => { + self.wait_up_ms_total.fetch_add(wait_ms, Ordering::Relaxed); + } + RateDirection::Down => { + self.wait_down_ms_total + .fetch_add(wait_ms, Ordering::Relaxed); + } + } + } +} + +#[derive(Default)] +struct AtomicRatePair { + up_bps: AtomicU64, + down_bps: AtomicU64, +} + +impl AtomicRatePair { + fn set(&self, limits: RateLimitBps) { + self.up_bps.store(limits.up_bps, Ordering::Relaxed); + self.down_bps.store(limits.down_bps, Ordering::Relaxed); + } + + fn get(&self, direction: RateDirection) -> u64 { + match direction { + RateDirection::Up => self.up_bps.load(Ordering::Relaxed), + RateDirection::Down => self.down_bps.load(Ordering::Relaxed), + } + } +} + +#[derive(Default)] +struct DirectionBucket { + epoch: AtomicU64, + used: AtomicU64, +} + +impl DirectionBucket { + fn sync_epoch(&self, epoch: u64) { + let current = self.epoch.load(Ordering::Relaxed); + if current == epoch { + return; + } + if current < epoch + && self + .epoch + .compare_exchange(current, epoch, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + self.used.store(0, Ordering::Relaxed); + } + } + + fn try_consume(&self, cap_bps: u64, requested: u64) -> u64 { + if requested == 0 { + return 0; + } + if cap_bps == 0 { + return 
requested; + } + + let epoch = current_epoch(); + self.sync_epoch(epoch); + let cap_epoch = bytes_per_epoch(cap_bps); + + loop { + let used = self.used.load(Ordering::Relaxed); + if used >= cap_epoch { + return 0; + } + let remaining = cap_epoch.saturating_sub(used); + let grant = requested.min(remaining); + if grant == 0 { + return 0; + } + let next = used.saturating_add(grant); + if self + .used + .compare_exchange_weak(used, next, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + return grant; + } + } + } + + fn refund(&self, bytes: u64) { + if bytes == 0 { + return; + } + decrement_atomic_saturating(&self.used, bytes); + } +} + +struct UserBucket { + rates: AtomicRatePair, + up: DirectionBucket, + down: DirectionBucket, + active_leases: AtomicU64, +} + +impl UserBucket { + fn new(limits: RateLimitBps) -> Self { + let rates = AtomicRatePair::default(); + rates.set(limits); + Self { + rates, + up: DirectionBucket::default(), + down: DirectionBucket::default(), + active_leases: AtomicU64::new(0), + } + } + + fn set_rates(&self, limits: RateLimitBps) { + self.rates.set(limits); + } + + fn try_consume(&self, direction: RateDirection, requested: u64) -> u64 { + let cap_bps = self.rates.get(direction); + match direction { + RateDirection::Up => self.up.try_consume(cap_bps, requested), + RateDirection::Down => self.down.try_consume(cap_bps, requested), + } + } + + fn refund(&self, direction: RateDirection, bytes: u64) { + match direction { + RateDirection::Up => self.up.refund(bytes), + RateDirection::Down => self.down.refund(bytes), + } + } +} + +#[derive(Default)] +struct CidrDirectionBucket { + epoch: AtomicU64, + used: AtomicU64, + active_users: AtomicU64, +} + +impl CidrDirectionBucket { + fn sync_epoch(&self, epoch: u64) { + let current = self.epoch.load(Ordering::Relaxed); + if current == epoch { + return; + } + if current < epoch + && self + .epoch + .compare_exchange(current, epoch, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + self.used.store(0, 
Ordering::Relaxed); + self.active_users.store(0, Ordering::Relaxed); + } + } + + fn try_consume( + &self, + user_state: &CidrUserDirectionState, + cap_epoch: u64, + requested: u64, + ) -> u64 { + if requested == 0 || cap_epoch == 0 { + return 0; + } + + let epoch = current_epoch(); + self.sync_epoch(epoch); + user_state.sync_epoch_and_mark_active(epoch, &self.active_users); + let active_users = self.active_users.load(Ordering::Relaxed).max(1); + let fair_share = cap_epoch.saturating_div(active_users).max(1); + + loop { + let total_used = self.used.load(Ordering::Relaxed); + if total_used >= cap_epoch { + return 0; + } + let total_remaining = cap_epoch.saturating_sub(total_used); + let user_used = user_state.used.load(Ordering::Relaxed); + let guaranteed_remaining = fair_share.saturating_sub(user_used); + + let grant = if guaranteed_remaining > 0 { + requested.min(guaranteed_remaining).min(total_remaining) + } else { + requested.min(total_remaining).min(MAX_BORROW_CHUNK_BYTES) + }; + + if grant == 0 { + return 0; + } + + let next_total = total_used.saturating_add(grant); + if self + .used + .compare_exchange_weak(total_used, next_total, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + user_state.used.fetch_add(grant, Ordering::Relaxed); + return grant; + } + } + } + + fn refund(&self, bytes: u64) { + if bytes == 0 { + return; + } + decrement_atomic_saturating(&self.used, bytes); + } +} + +#[derive(Default)] +struct CidrUserDirectionState { + epoch: AtomicU64, + used: AtomicU64, +} + +impl CidrUserDirectionState { + fn sync_epoch_and_mark_active(&self, epoch: u64, active_users: &AtomicU64) { + let current = self.epoch.load(Ordering::Relaxed); + if current == epoch { + return; + } + if current < epoch + && self + .epoch + .compare_exchange(current, epoch, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + self.used.store(0, Ordering::Relaxed); + active_users.fetch_add(1, Ordering::Relaxed); + } + } + + fn refund(&self, bytes: u64) { + if bytes == 0 { + 
return; + } + decrement_atomic_saturating(&self.used, bytes); + } +} + +struct CidrUserShare { + active_conns: AtomicU64, + up: CidrUserDirectionState, + down: CidrUserDirectionState, +} + +impl CidrUserShare { + fn new() -> Self { + Self { + active_conns: AtomicU64::new(0), + up: CidrUserDirectionState::default(), + down: CidrUserDirectionState::default(), + } + } +} + +struct CidrBucket { + rates: AtomicRatePair, + up: CidrDirectionBucket, + down: CidrDirectionBucket, + users: ShardedRegistry, + active_leases: AtomicU64, +} + +impl CidrBucket { + fn new(limits: RateLimitBps) -> Self { + let rates = AtomicRatePair::default(); + rates.set(limits); + Self { + rates, + up: CidrDirectionBucket::default(), + down: CidrDirectionBucket::default(), + users: ShardedRegistry::new(REGISTRY_SHARDS), + active_leases: AtomicU64::new(0), + } + } + + fn set_rates(&self, limits: RateLimitBps) { + self.rates.set(limits); + } + + fn acquire_user_share(&self, user: &str) -> Arc { + let share = self.users.get_or_insert_with(user, CidrUserShare::new); + share.active_conns.fetch_add(1, Ordering::Relaxed); + share + } + + fn release_user_share(&self, user: &str, share: &Arc) { + decrement_atomic_saturating(&share.active_conns, 1); + let share_for_remove = Arc::clone(share); + let _ = self.users.remove_if(user, |candidate| { + Arc::ptr_eq(candidate, &share_for_remove) + && candidate.active_conns.load(Ordering::Relaxed) == 0 + }); + } + + fn try_consume_for_user( + &self, + direction: RateDirection, + share: &CidrUserShare, + requested: u64, + ) -> u64 { + let cap_bps = self.rates.get(direction); + if cap_bps == 0 { + return requested; + } + let cap_epoch = bytes_per_epoch(cap_bps); + match direction { + RateDirection::Up => self.up.try_consume(&share.up, cap_epoch, requested), + RateDirection::Down => self.down.try_consume(&share.down, cap_epoch, requested), + } + } + + fn refund_for_user(&self, direction: RateDirection, share: &CidrUserShare, bytes: u64) { + match direction { + 
RateDirection::Up => { + self.up.refund(bytes); + share.up.refund(bytes); + } + RateDirection::Down => { + self.down.refund(bytes); + share.down.refund(bytes); + } + } + } + + fn cleanup_idle_users(&self) { + self.users + .retain(|_, share| share.active_conns.load(Ordering::Relaxed) > 0); + } +} + +#[derive(Clone)] +struct CidrRule { + key: String, + cidr: IpNetwork, + limits: RateLimitBps, + prefix_len: u8, +} + +#[derive(Default)] +struct PolicySnapshot { + user_limits: HashMap, + cidr_rules_v4: Vec, + cidr_rules_v6: Vec, + cidr_rule_keys: HashSet, +} + +impl PolicySnapshot { + fn match_cidr(&self, ip: IpAddr) -> Option<&CidrRule> { + match ip { + IpAddr::V4(_) => self + .cidr_rules_v4 + .iter() + .find(|rule| rule.cidr.contains(ip)), + IpAddr::V6(_) => self + .cidr_rules_v6 + .iter() + .find(|rule| rule.cidr.contains(ip)), + } + } +} + +struct ShardedRegistry { + shards: Box<[DashMap>]>, + mask: usize, +} + +impl ShardedRegistry { + fn new(shards: usize) -> Self { + let shard_count = shards.max(1).next_power_of_two(); + let mut items = Vec::with_capacity(shard_count); + for _ in 0..shard_count { + items.push(DashMap::>::new()); + } + Self { + shards: items.into_boxed_slice(), + mask: shard_count.saturating_sub(1), + } + } + + fn shard_index(&self, key: &str) -> usize { + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + key.hash(&mut hasher); + (hasher.finish() as usize) & self.mask + } + + fn get_or_insert_with(&self, key: &str, make: F) -> Arc + where + F: FnOnce() -> T, + { + let shard = &self.shards[self.shard_index(key)]; + match shard.entry(key.to_string()) { + dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), + dashmap::mapref::entry::Entry::Vacant(slot) => { + let value = Arc::new(make()); + slot.insert(Arc::clone(&value)); + value + } + } + } + + fn retain(&self, predicate: F) + where + F: Fn(&String, &Arc) -> bool + Copy, + { + for shard in &*self.shards { + shard.retain(|key, value| predicate(key, value)); + 
} + } + + fn remove_if(&self, key: &str, predicate: F) -> bool + where + F: Fn(&Arc) -> bool, + { + let shard = &self.shards[self.shard_index(key)]; + let should_remove = match shard.get(key) { + Some(entry) => predicate(entry.value()), + None => false, + }; + if !should_remove { + return false; + } + shard.remove(key).is_some() + } +} + +pub struct TrafficLease { + limiter: Arc, + user_bucket: Option>, + cidr_bucket: Option>, + cidr_user_key: Option, + cidr_user_share: Option>, +} + +impl TrafficLease { + pub fn try_consume(&self, direction: RateDirection, requested: u64) -> TrafficConsumeResult { + if requested == 0 { + return TrafficConsumeResult { + granted: 0, + blocked_user: false, + blocked_cidr: false, + }; + } + + let mut granted = requested; + if let Some(user_bucket) = self.user_bucket.as_ref() { + let user_granted = user_bucket.try_consume(direction, granted); + if user_granted == 0 { + self.limiter.observe_throttle(direction, true, false); + return TrafficConsumeResult { + granted: 0, + blocked_user: true, + blocked_cidr: false, + }; + } + granted = user_granted; + } + + if let (Some(cidr_bucket), Some(cidr_user_share)) = + (self.cidr_bucket.as_ref(), self.cidr_user_share.as_ref()) + { + let cidr_granted = + cidr_bucket.try_consume_for_user(direction, cidr_user_share, granted); + if cidr_granted < granted + && let Some(user_bucket) = self.user_bucket.as_ref() + { + user_bucket.refund(direction, granted.saturating_sub(cidr_granted)); + } + if cidr_granted == 0 { + self.limiter.observe_throttle(direction, false, true); + return TrafficConsumeResult { + granted: 0, + blocked_user: false, + blocked_cidr: true, + }; + } + granted = cidr_granted; + } + + TrafficConsumeResult { + granted, + blocked_user: false, + blocked_cidr: false, + } + } + + pub fn refund(&self, direction: RateDirection, bytes: u64) { + if bytes == 0 { + return; + } + + if let Some(user_bucket) = self.user_bucket.as_ref() { + user_bucket.refund(direction, bytes); + } + if let 
(Some(cidr_bucket), Some(cidr_user_share)) = + (self.cidr_bucket.as_ref(), self.cidr_user_share.as_ref()) + { + cidr_bucket.refund_for_user(direction, cidr_user_share, bytes); + } + } + + pub fn observe_wait_ms( + &self, + direction: RateDirection, + blocked_user: bool, + blocked_cidr: bool, + wait_ms: u64, + ) { + if wait_ms == 0 { + return; + } + self.limiter + .observe_wait(direction, blocked_user, blocked_cidr, wait_ms); + } +} + +impl Drop for TrafficLease { + fn drop(&mut self) { + if let Some(bucket) = self.user_bucket.as_ref() { + decrement_atomic_saturating(&bucket.active_leases, 1); + decrement_atomic_saturating(&self.limiter.user_scope.active_leases, 1); + } + + if let Some(bucket) = self.cidr_bucket.as_ref() { + if let (Some(user_key), Some(share)) = + (self.cidr_user_key.as_ref(), self.cidr_user_share.as_ref()) + { + bucket.release_user_share(user_key, share); + } + decrement_atomic_saturating(&bucket.active_leases, 1); + decrement_atomic_saturating(&self.limiter.cidr_scope.active_leases, 1); + } + } +} + +pub struct TrafficLimiter { + policy: ArcSwap, + user_buckets: ShardedRegistry, + cidr_buckets: ShardedRegistry, + user_scope: ScopeMetrics, + cidr_scope: ScopeMetrics, + last_cleanup_epoch_secs: AtomicU64, +} + +impl TrafficLimiter { + pub fn new() -> Arc { + Arc::new(Self { + policy: ArcSwap::from_pointee(PolicySnapshot::default()), + user_buckets: ShardedRegistry::new(REGISTRY_SHARDS), + cidr_buckets: ShardedRegistry::new(REGISTRY_SHARDS), + user_scope: ScopeMetrics::default(), + cidr_scope: ScopeMetrics::default(), + last_cleanup_epoch_secs: AtomicU64::new(0), + }) + } + + pub fn apply_policy( + &self, + user_limits: HashMap, + cidr_limits: HashMap, + ) { + let filtered_users = user_limits + .into_iter() + .filter(|(_, limit)| limit.up_bps > 0 || limit.down_bps > 0) + .collect::>(); + + let mut cidr_rules_v4 = Vec::new(); + let mut cidr_rules_v6 = Vec::new(); + let mut cidr_rule_keys = HashSet::new(); + for (cidr, limits) in cidr_limits { + if 
limits.up_bps == 0 && limits.down_bps == 0 { + continue; + } + let key = cidr.to_string(); + let rule = CidrRule { + key: key.clone(), + cidr, + limits, + prefix_len: cidr.prefix(), + }; + cidr_rule_keys.insert(key); + match rule.cidr { + IpNetwork::V4(_) => cidr_rules_v4.push(rule), + IpNetwork::V6(_) => cidr_rules_v6.push(rule), + } + } + + cidr_rules_v4.sort_by(|a, b| b.prefix_len.cmp(&a.prefix_len)); + cidr_rules_v6.sort_by(|a, b| b.prefix_len.cmp(&a.prefix_len)); + + self.user_scope + .policy_entries + .store(filtered_users.len() as u64, Ordering::Relaxed); + self.cidr_scope + .policy_entries + .store(cidr_rule_keys.len() as u64, Ordering::Relaxed); + + self.policy.store(Arc::new(PolicySnapshot { + user_limits: filtered_users, + cidr_rules_v4, + cidr_rules_v6, + cidr_rule_keys, + })); + + self.maybe_cleanup(); + } + + pub fn acquire_lease( + self: &Arc, + user: &str, + client_ip: IpAddr, + ) -> Option> { + let policy = self.policy.load_full(); + let mut user_bucket = None; + if let Some(limit) = policy.user_limits.get(user).copied() { + let bucket = self + .user_buckets + .get_or_insert_with(user, || UserBucket::new(limit)); + bucket.set_rates(limit); + bucket.active_leases.fetch_add(1, Ordering::Relaxed); + self.user_scope + .active_leases + .fetch_add(1, Ordering::Relaxed); + user_bucket = Some(bucket); + } + + let mut cidr_bucket = None; + let mut cidr_user_key = None; + let mut cidr_user_share = None; + if let Some(rule) = policy.match_cidr(client_ip) { + let bucket = self + .cidr_buckets + .get_or_insert_with(rule.key.as_str(), || CidrBucket::new(rule.limits)); + bucket.set_rates(rule.limits); + bucket.active_leases.fetch_add(1, Ordering::Relaxed); + self.cidr_scope + .active_leases + .fetch_add(1, Ordering::Relaxed); + let share = bucket.acquire_user_share(user); + cidr_user_key = Some(user.to_string()); + cidr_user_share = Some(share); + cidr_bucket = Some(bucket); + } + + if user_bucket.is_none() && cidr_bucket.is_none() { + return None; + } + + 
self.maybe_cleanup(); + Some(Arc::new(TrafficLease { + limiter: Arc::clone(self), + user_bucket, + cidr_bucket, + cidr_user_key, + cidr_user_share, + })) + } + + pub fn metrics_snapshot(&self) -> TrafficLimiterMetricsSnapshot { + TrafficLimiterMetricsSnapshot { + user_throttle_up_total: self.user_scope.throttle_up_total.load(Ordering::Relaxed), + user_throttle_down_total: self.user_scope.throttle_down_total.load(Ordering::Relaxed), + cidr_throttle_up_total: self.cidr_scope.throttle_up_total.load(Ordering::Relaxed), + cidr_throttle_down_total: self.cidr_scope.throttle_down_total.load(Ordering::Relaxed), + user_wait_up_ms_total: self.user_scope.wait_up_ms_total.load(Ordering::Relaxed), + user_wait_down_ms_total: self.user_scope.wait_down_ms_total.load(Ordering::Relaxed), + cidr_wait_up_ms_total: self.cidr_scope.wait_up_ms_total.load(Ordering::Relaxed), + cidr_wait_down_ms_total: self.cidr_scope.wait_down_ms_total.load(Ordering::Relaxed), + user_active_leases: self.user_scope.active_leases.load(Ordering::Relaxed), + cidr_active_leases: self.cidr_scope.active_leases.load(Ordering::Relaxed), + user_policy_entries: self.user_scope.policy_entries.load(Ordering::Relaxed), + cidr_policy_entries: self.cidr_scope.policy_entries.load(Ordering::Relaxed), + } + } + + fn observe_throttle(&self, direction: RateDirection, blocked_user: bool, blocked_cidr: bool) { + if blocked_user { + self.user_scope.throttle(direction); + } + if blocked_cidr { + self.cidr_scope.throttle(direction); + } + } + + fn observe_wait( + &self, + direction: RateDirection, + blocked_user: bool, + blocked_cidr: bool, + wait_ms: u64, + ) { + if blocked_user { + self.user_scope.wait_ms(direction, wait_ms); + } + if blocked_cidr { + self.cidr_scope.wait_ms(direction, wait_ms); + } + } + + fn maybe_cleanup(&self) { + let now_epoch_secs = now_epoch_secs(); + let last = self.last_cleanup_epoch_secs.load(Ordering::Relaxed); + if now_epoch_secs.saturating_sub(last) < CLEANUP_INTERVAL_SECS { + return; + } + if self + 
.last_cleanup_epoch_secs + .compare_exchange(last, now_epoch_secs, Ordering::Relaxed, Ordering::Relaxed) + .is_err() + { + return; + } + + let policy = self.policy.load_full(); + self.user_buckets.retain(|user, bucket| { + bucket.active_leases.load(Ordering::Relaxed) > 0 + || policy.user_limits.contains_key(user) + }); + self.cidr_buckets.retain(|cidr_key, bucket| { + bucket.cleanup_idle_users(); + bucket.active_leases.load(Ordering::Relaxed) > 0 + || policy.cidr_rule_keys.contains(cidr_key) + }); + } +} + +pub fn next_refill_delay() -> Duration { + let start = limiter_epoch_start(); + let elapsed_ms = start.elapsed().as_millis() as u64; + let epoch_pos = elapsed_ms % FAIR_EPOCH_MS; + let wait_ms = FAIR_EPOCH_MS.saturating_sub(epoch_pos).max(1); + Duration::from_millis(wait_ms) +} + +fn decrement_atomic_saturating(counter: &AtomicU64, by: u64) { + if by == 0 { + return; + } + let mut current = counter.load(Ordering::Relaxed); + loop { + if current == 0 { + return; + } + let next = current.saturating_sub(by); + match counter.compare_exchange_weak(current, next, Ordering::Relaxed, Ordering::Relaxed) { + Ok(_) => return, + Err(actual) => current = actual, + } + } +} + +fn now_epoch_secs() -> u64 { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs() +} + +fn bytes_per_epoch(bps: u64) -> u64 { + if bps == 0 { + return 0; + } + let numerator = bps.saturating_mul(FAIR_EPOCH_MS); + let bytes = numerator.saturating_div(8_000); + bytes.max(1) +} + +fn current_epoch() -> u64 { + let start = limiter_epoch_start(); + let elapsed_ms = start.elapsed().as_millis() as u64; + elapsed_ms / FAIR_EPOCH_MS +} + +fn limiter_epoch_start() -> &'static Instant { + static START: OnceLock = OnceLock::new(); + START.get_or_init(Instant::now) +} diff --git a/src/stats/mod.rs b/src/stats/mod.rs index 38b22bb..9609e19 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -175,6 +175,18 @@ pub struct Stats { me_route_drop_queue_full: AtomicU64, 
me_route_drop_queue_full_base: AtomicU64, me_route_drop_queue_full_high: AtomicU64, + me_fair_pressure_state_gauge: AtomicU64, + me_fair_active_flows_gauge: AtomicU64, + me_fair_queued_bytes_gauge: AtomicU64, + me_fair_standing_flows_gauge: AtomicU64, + me_fair_backpressured_flows_gauge: AtomicU64, + me_fair_scheduler_rounds_total: AtomicU64, + me_fair_deficit_grants_total: AtomicU64, + me_fair_deficit_skips_total: AtomicU64, + me_fair_enqueue_rejects_total: AtomicU64, + me_fair_shed_drops_total: AtomicU64, + me_fair_penalties_total: AtomicU64, + me_fair_downstream_stalls_total: AtomicU64, me_d2c_batches_total: AtomicU64, me_d2c_batch_frames_total: AtomicU64, me_d2c_batch_bytes_total: AtomicU64, @@ -856,6 +868,78 @@ impl Stats { .fetch_add(1, Ordering::Relaxed); } } + pub fn set_me_fair_pressure_state_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_fair_pressure_state_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_fair_active_flows_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_fair_active_flows_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_fair_queued_bytes_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_fair_queued_bytes_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_fair_standing_flows_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_fair_standing_flows_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn set_me_fair_backpressured_flows_gauge(&self, value: u64) { + if self.telemetry_me_allows_normal() { + self.me_fair_backpressured_flows_gauge + .store(value, Ordering::Relaxed); + } + } + pub fn add_me_fair_scheduler_rounds_total(&self, value: u64) { + if self.telemetry_me_allows_normal() && value > 0 { + self.me_fair_scheduler_rounds_total + .fetch_add(value, Ordering::Relaxed); + } + } + pub fn add_me_fair_deficit_grants_total(&self, value: u64) { + if 
self.telemetry_me_allows_normal() && value > 0 { + self.me_fair_deficit_grants_total + .fetch_add(value, Ordering::Relaxed); + } + } + pub fn add_me_fair_deficit_skips_total(&self, value: u64) { + if self.telemetry_me_allows_normal() && value > 0 { + self.me_fair_deficit_skips_total + .fetch_add(value, Ordering::Relaxed); + } + } + pub fn add_me_fair_enqueue_rejects_total(&self, value: u64) { + if self.telemetry_me_allows_normal() && value > 0 { + self.me_fair_enqueue_rejects_total + .fetch_add(value, Ordering::Relaxed); + } + } + pub fn add_me_fair_shed_drops_total(&self, value: u64) { + if self.telemetry_me_allows_normal() && value > 0 { + self.me_fair_shed_drops_total + .fetch_add(value, Ordering::Relaxed); + } + } + pub fn add_me_fair_penalties_total(&self, value: u64) { + if self.telemetry_me_allows_normal() && value > 0 { + self.me_fair_penalties_total + .fetch_add(value, Ordering::Relaxed); + } + } + pub fn add_me_fair_downstream_stalls_total(&self, value: u64) { + if self.telemetry_me_allows_normal() && value > 0 { + self.me_fair_downstream_stalls_total + .fetch_add(value, Ordering::Relaxed); + } + } pub fn increment_me_d2c_batches_total(&self) { if self.telemetry_me_allows_normal() { self.me_d2c_batches_total.fetch_add(1, Ordering::Relaxed); @@ -1806,6 +1890,43 @@ impl Stats { pub fn get_me_route_drop_queue_full_high(&self) -> u64 { self.me_route_drop_queue_full_high.load(Ordering::Relaxed) } + pub fn get_me_fair_pressure_state_gauge(&self) -> u64 { + self.me_fair_pressure_state_gauge.load(Ordering::Relaxed) + } + pub fn get_me_fair_active_flows_gauge(&self) -> u64 { + self.me_fair_active_flows_gauge.load(Ordering::Relaxed) + } + pub fn get_me_fair_queued_bytes_gauge(&self) -> u64 { + self.me_fair_queued_bytes_gauge.load(Ordering::Relaxed) + } + pub fn get_me_fair_standing_flows_gauge(&self) -> u64 { + self.me_fair_standing_flows_gauge.load(Ordering::Relaxed) + } + pub fn get_me_fair_backpressured_flows_gauge(&self) -> u64 { + 
self.me_fair_backpressured_flows_gauge + .load(Ordering::Relaxed) + } + pub fn get_me_fair_scheduler_rounds_total(&self) -> u64 { + self.me_fair_scheduler_rounds_total.load(Ordering::Relaxed) + } + pub fn get_me_fair_deficit_grants_total(&self) -> u64 { + self.me_fair_deficit_grants_total.load(Ordering::Relaxed) + } + pub fn get_me_fair_deficit_skips_total(&self) -> u64 { + self.me_fair_deficit_skips_total.load(Ordering::Relaxed) + } + pub fn get_me_fair_enqueue_rejects_total(&self) -> u64 { + self.me_fair_enqueue_rejects_total.load(Ordering::Relaxed) + } + pub fn get_me_fair_shed_drops_total(&self) -> u64 { + self.me_fair_shed_drops_total.load(Ordering::Relaxed) + } + pub fn get_me_fair_penalties_total(&self) -> u64 { + self.me_fair_penalties_total.load(Ordering::Relaxed) + } + pub fn get_me_fair_downstream_stalls_total(&self) -> u64 { + self.me_fair_downstream_stalls_total.load(Ordering::Relaxed) + } pub fn get_me_d2c_batches_total(&self) -> u64 { self.me_d2c_batches_total.load(Ordering::Relaxed) } diff --git a/src/tls_front/emulator.rs b/src/tls_front/emulator.rs index 290e203..a23373d 100644 --- a/src/tls_front/emulator.rs +++ b/src/tls_front/emulator.rs @@ -11,6 +11,7 @@ use crc32fast::Hasher; const MIN_APP_DATA: usize = 64; const MAX_APP_DATA: usize = MAX_TLS_CIPHERTEXT_SIZE; +const MAX_TICKET_RECORDS: usize = 4; fn jitter_and_clamp_sizes(sizes: &[usize], rng: &SecureRandom) -> Vec { sizes @@ -62,6 +63,53 @@ fn ensure_payload_capacity(mut sizes: Vec, payload_len: usize) -> Vec Vec { + match cached.behavior_profile.source { + TlsProfileSource::Raw | TlsProfileSource::Merged => { + if !cached.behavior_profile.app_data_record_sizes.is_empty() { + return cached.behavior_profile.app_data_record_sizes.clone(); + } + } + TlsProfileSource::Default | TlsProfileSource::Rustls => {} + } + + let mut sizes = cached.app_data_records_sizes.clone(); + if sizes.is_empty() { + sizes.push(cached.total_app_data_len.max(1024)); + } + sizes +} + +fn 
emulated_change_cipher_spec_count(cached: &CachedTlsData) -> usize { + usize::from(cached.behavior_profile.change_cipher_spec_count.max(1)) +} + +fn emulated_ticket_record_sizes( + cached: &CachedTlsData, + new_session_tickets: u8, + rng: &SecureRandom, +) -> Vec { + let mut sizes = match cached.behavior_profile.source { + TlsProfileSource::Raw | TlsProfileSource::Merged => { + cached.behavior_profile.ticket_record_sizes.clone() + } + TlsProfileSource::Default | TlsProfileSource::Rustls => Vec::new(), + }; + + let target_count = sizes + .len() + .max(usize::from( + new_session_tickets.min(MAX_TICKET_RECORDS as u8), + )) + .min(MAX_TICKET_RECORDS); + + while sizes.len() < target_count { + sizes.push(rng.range(48) + 48); + } + + sizes +} + fn build_compact_cert_info_payload(cert_info: &ParsedCertificateInfo) -> Option> { let mut fields = Vec::new(); @@ -180,39 +228,21 @@ pub fn build_emulated_server_hello( server_hello.extend_from_slice(&message); // --- ChangeCipherSpec --- - let change_cipher_spec = [ - TLS_RECORD_CHANGE_CIPHER, - TLS_VERSION[0], - TLS_VERSION[1], - 0x00, - 0x01, - 0x01, - ]; + let change_cipher_spec_count = emulated_change_cipher_spec_count(cached); + let mut change_cipher_spec = Vec::with_capacity(change_cipher_spec_count * 6); + for _ in 0..change_cipher_spec_count { + change_cipher_spec.extend_from_slice(&[ + TLS_RECORD_CHANGE_CIPHER, + TLS_VERSION[0], + TLS_VERSION[1], + 0x00, + 0x01, + 0x01, + ]); + } // --- ApplicationData (fake encrypted records) --- - let sizes = match cached.behavior_profile.source { - TlsProfileSource::Raw | TlsProfileSource::Merged => cached - .app_data_records_sizes - .first() - .copied() - .or_else(|| { - cached - .behavior_profile - .app_data_record_sizes - .first() - .copied() - }) - .map(|size| vec![size]) - .unwrap_or_else(|| vec![cached.total_app_data_len.max(1024)]), - _ => { - let mut sizes = cached.app_data_records_sizes.clone(); - if sizes.is_empty() { - sizes.push(cached.total_app_data_len.max(1024)); - } - 
sizes - } - }; - let mut sizes = jitter_and_clamp_sizes(&sizes, rng); + let mut sizes = jitter_and_clamp_sizes(&emulated_app_data_sizes(cached), rng); let compact_payload = cached .cert_info .as_ref() @@ -299,17 +329,13 @@ pub fn build_emulated_server_hello( // --- Combine --- // Optional NewSessionTicket mimic records (opaque ApplicationData for fingerprint). let mut tickets = Vec::new(); - let ticket_count = new_session_tickets.min(4); - if ticket_count > 0 { - for _ in 0..ticket_count { - let ticket_len: usize = rng.range(48) + 48; - let mut rec = Vec::with_capacity(5 + ticket_len); - rec.push(TLS_RECORD_APPLICATION); - rec.extend_from_slice(&TLS_VERSION); - rec.extend_from_slice(&(ticket_len as u16).to_be_bytes()); - rec.extend_from_slice(&rng.bytes(ticket_len)); - tickets.extend_from_slice(&rec); - } + for ticket_len in emulated_ticket_record_sizes(cached, new_session_tickets, rng) { + let mut rec = Vec::with_capacity(5 + ticket_len); + rec.push(TLS_RECORD_APPLICATION); + rec.extend_from_slice(&TLS_VERSION); + rec.extend_from_slice(&(ticket_len as u16).to_be_bytes()); + rec.extend_from_slice(&rng.bytes(ticket_len)); + tickets.extend_from_slice(&rec); } let mut response = Vec::with_capacity( @@ -334,6 +360,10 @@ pub fn build_emulated_server_hello( #[path = "tests/emulator_security_tests.rs"] mod security_tests; +#[cfg(test)] +#[path = "tests/emulator_profile_fidelity_security_tests.rs"] +mod emulator_profile_fidelity_security_tests; + #[cfg(test)] mod tests { use std::time::SystemTime; @@ -478,7 +508,7 @@ mod tests { } #[test] - fn test_build_emulated_server_hello_ignores_tail_records_for_raw_profile() { + fn test_build_emulated_server_hello_replays_tail_records_for_profiled_tls() { let mut cached = make_cached(None); cached.app_data_records_sizes = vec![27, 3905, 537, 69]; cached.total_app_data_len = 4538; @@ -500,11 +530,19 @@ mod tests { let hello_len = u16::from_be_bytes([response[3], response[4]]) as usize; let ccs_start = 5 + hello_len; - let app_start = 
ccs_start + 6; - let app_len = - u16::from_be_bytes([response[app_start + 3], response[app_start + 4]]) as usize; + let mut pos = ccs_start + 6; + let mut app_lengths = Vec::new(); + while pos + 5 <= response.len() { + assert_eq!(response[pos], TLS_RECORD_APPLICATION); + let record_len = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize; + app_lengths.push(record_len); + pos += 5 + record_len; + } - assert_eq!(response[app_start], TLS_RECORD_APPLICATION); - assert_eq!(app_start + 5 + app_len, response.len()); + assert_eq!(app_lengths.len(), 4); + assert_eq!(app_lengths[0], 64); + assert_eq!(app_lengths[3], 69); + assert!(app_lengths[1] >= 64); + assert!(app_lengths[2] >= 64); } } diff --git a/src/tls_front/fetcher.rs b/src/tls_front/fetcher.rs index 45d56ce..aad956e 100644 --- a/src/tls_front/fetcher.rs +++ b/src/tls_front/fetcher.rs @@ -1,6 +1,7 @@ #![allow(clippy::too_many_arguments)] use dashmap::DashMap; +use std::net::SocketAddr; use std::sync::Arc; use std::sync::OnceLock; use std::time::{Duration, Instant}; @@ -793,6 +794,51 @@ async fn connect_tcp_with_upstream( )) } +fn socket_addrs_from_upstream_stream( + stream: &UpstreamStream, +) -> (Option, Option) { + match stream { + UpstreamStream::Tcp(tcp) => (tcp.local_addr().ok(), tcp.peer_addr().ok()), + UpstreamStream::Shadowsocks(_) => (None, None), + } +} + +fn build_tls_fetch_proxy_header( + proxy_protocol: u8, + src_addr: Option, + dst_addr: Option, +) -> Option> { + match proxy_protocol { + 0 => None, + 2 => { + let header = match (src_addr, dst_addr) { + (Some(src @ SocketAddr::V4(_)), Some(dst @ SocketAddr::V4(_))) + | (Some(src @ SocketAddr::V6(_)), Some(dst @ SocketAddr::V6(_))) => { + ProxyProtocolV2Builder::new().with_addrs(src, dst).build() + } + _ => ProxyProtocolV2Builder::new().build(), + }; + Some(header) + } + _ => { + let header = match (src_addr, dst_addr) { + (Some(SocketAddr::V4(src)), Some(SocketAddr::V4(dst))) => { + ProxyProtocolV1Builder::new() + .tcp4(src.into(), 
dst.into()) + .build() + } + (Some(SocketAddr::V6(src)), Some(SocketAddr::V6(dst))) => { + ProxyProtocolV1Builder::new() + .tcp6(src.into(), dst.into()) + .build() + } + _ => ProxyProtocolV1Builder::new().build(), + }; + Some(header) + } + } +} + fn encode_tls13_certificate_message(cert_chain_der: &[Vec]) -> Option> { if cert_chain_der.is_empty() { return None; @@ -824,7 +870,7 @@ async fn fetch_via_raw_tls_stream( mut stream: S, sni: &str, connect_timeout: Duration, - proxy_protocol: u8, + proxy_header: Option>, profile: TlsFetchProfile, grease_enabled: bool, deterministic: bool, @@ -835,11 +881,7 @@ where let rng = SecureRandom::new(); let client_hello = build_client_hello(sni, &rng, profile, grease_enabled, deterministic); timeout(connect_timeout, async { - if proxy_protocol > 0 { - let header = match proxy_protocol { - 2 => ProxyProtocolV2Builder::new().build(), - _ => ProxyProtocolV1Builder::new().build(), - }; + if let Some(header) = proxy_header.as_ref() { stream.write_all(&header).await?; } stream.write_all(&client_hello).await?; @@ -921,11 +963,12 @@ async fn fetch_via_raw_tls( sock = %sock_path, "Raw TLS fetch using mask unix socket" ); + let proxy_header = build_tls_fetch_proxy_header(proxy_protocol, None, None); return fetch_via_raw_tls_stream( stream, sni, connect_timeout, - proxy_protocol, + proxy_header, profile, grease_enabled, deterministic, @@ -956,11 +999,13 @@ async fn fetch_via_raw_tls( let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route) .await?; + let (src_addr, dst_addr) = socket_addrs_from_upstream_stream(&stream); + let proxy_header = build_tls_fetch_proxy_header(proxy_protocol, src_addr, dst_addr); fetch_via_raw_tls_stream( stream, sni, connect_timeout, - proxy_protocol, + proxy_header, profile, grease_enabled, deterministic, @@ -972,17 +1017,13 @@ async fn fetch_via_rustls_stream( mut stream: S, host: &str, sni: &str, - proxy_protocol: u8, + proxy_header: Option>, ) -> Result where S: 
AsyncRead + AsyncWrite + Unpin, { // rustls handshake path for certificate and basic negotiated metadata. - if proxy_protocol > 0 { - let header = match proxy_protocol { - 2 => ProxyProtocolV2Builder::new().build(), - _ => ProxyProtocolV1Builder::new().build(), - }; + if let Some(header) = proxy_header.as_ref() { stream.write_all(&header).await?; stream.flush().await?; } @@ -1082,7 +1123,8 @@ async fn fetch_via_rustls( sock = %sock_path, "Rustls fetch using mask unix socket" ); - return fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await; + let proxy_header = build_tls_fetch_proxy_header(proxy_protocol, None, None); + return fetch_via_rustls_stream(stream, host, sni, proxy_header).await; } Ok(Err(e)) => { warn!( @@ -1108,7 +1150,9 @@ async fn fetch_via_rustls( let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route) .await?; - fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await + let (src_addr, dst_addr) = socket_addrs_from_upstream_stream(&stream); + let proxy_header = build_tls_fetch_proxy_header(proxy_protocol, src_addr, dst_addr); + fetch_via_rustls_stream(stream, host, sni, proxy_header).await } /// Fetch real TLS metadata with an adaptive multi-profile strategy. 
@@ -1278,11 +1322,13 @@ pub async fn fetch_real_tls( #[cfg(test)] mod tests { + use std::net::SocketAddr; use std::time::{Duration, Instant}; use super::{ - ProfileCacheValue, TlsFetchStrategy, build_client_hello, derive_behavior_profile, - encode_tls13_certificate_message, order_profiles, profile_cache, profile_cache_key, + ProfileCacheValue, TlsFetchStrategy, build_client_hello, build_tls_fetch_proxy_header, + derive_behavior_profile, encode_tls13_certificate_message, order_profiles, profile_cache, + profile_cache_key, }; use crate::config::TlsFetchProfile; use crate::crypto::SecureRandom; @@ -1423,4 +1469,48 @@ mod tests { assert_eq!(first, second); } + + #[test] + fn test_build_tls_fetch_proxy_header_v2_with_tcp_addrs() { + let src: SocketAddr = "198.51.100.10:42000".parse().expect("valid src"); + let dst: SocketAddr = "203.0.113.20:443".parse().expect("valid dst"); + let header = build_tls_fetch_proxy_header(2, Some(src), Some(dst)).expect("header"); + + assert_eq!( + &header[..12], + &[ + 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a + ] + ); + assert_eq!(header[12], 0x21); + assert_eq!(header[13], 0x11); + assert_eq!(u16::from_be_bytes([header[14], header[15]]), 12); + assert_eq!(&header[16..20], &[198, 51, 100, 10]); + assert_eq!(&header[20..24], &[203, 0, 113, 20]); + assert_eq!(u16::from_be_bytes([header[24], header[25]]), 42000); + assert_eq!(u16::from_be_bytes([header[26], header[27]]), 443); + } + + #[test] + fn test_build_tls_fetch_proxy_header_v2_mixed_family_falls_back_to_local_command() { + let src: SocketAddr = "198.51.100.10:42000".parse().expect("valid src"); + let dst: SocketAddr = "[2001:db8::20]:443".parse().expect("valid dst"); + let header = build_tls_fetch_proxy_header(2, Some(src), Some(dst)).expect("header"); + + assert_eq!(header[12], 0x20); + assert_eq!(header[13], 0x00); + assert_eq!(u16::from_be_bytes([header[14], header[15]]), 0); + } + + #[test] + fn test_build_tls_fetch_proxy_header_v1_with_tcp_addrs() { + 
let src: SocketAddr = "198.51.100.10:42000".parse().expect("valid src"); + let dst: SocketAddr = "203.0.113.20:443".parse().expect("valid dst"); + let header = build_tls_fetch_proxy_header(1, Some(src), Some(dst)).expect("header"); + + assert_eq!( + header, + b"PROXY TCP4 198.51.100.10 203.0.113.20 42000 443\r\n" + ); + } } diff --git a/src/tls_front/tests/emulator_profile_fidelity_security_tests.rs b/src/tls_front/tests/emulator_profile_fidelity_security_tests.rs new file mode 100644 index 0000000..694fd76 --- /dev/null +++ b/src/tls_front/tests/emulator_profile_fidelity_security_tests.rs @@ -0,0 +1,95 @@ +use std::time::SystemTime; + +use crate::crypto::SecureRandom; +use crate::protocol::constants::{ + TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, +}; +use crate::tls_front::emulator::build_emulated_server_hello; +use crate::tls_front::types::{ + CachedTlsData, ParsedServerHello, TlsBehaviorProfile, TlsProfileSource, +}; + +fn make_cached() -> CachedTlsData { + CachedTlsData { + server_hello_template: ParsedServerHello { + version: [0x03, 0x03], + random: [0u8; 32], + session_id: Vec::new(), + cipher_suite: [0x13, 0x01], + compression: 0, + extensions: Vec::new(), + }, + cert_info: None, + cert_payload: None, + app_data_records_sizes: vec![1200, 900, 220, 180], + total_app_data_len: 2500, + behavior_profile: TlsBehaviorProfile { + change_cipher_spec_count: 2, + app_data_record_sizes: vec![1200, 900], + ticket_record_sizes: vec![220, 180], + source: TlsProfileSource::Merged, + }, + fetched_at: SystemTime::now(), + domain: "example.com".to_string(), + } +} + +fn record_lengths_by_type(response: &[u8], wanted_type: u8) -> Vec { + let mut out = Vec::new(); + let mut pos = 0usize; + while pos + 5 <= response.len() { + let record_type = response[pos]; + let record_len = u16::from_be_bytes([response[pos + 3], response[pos + 4]]) as usize; + if pos + 5 + record_len > response.len() { + break; + } + if record_type == wanted_type { + 
out.push(record_len); + } + pos += 5 + record_len; + } + out +} + +#[test] +fn emulated_server_hello_replays_profile_change_cipher_spec_count() { + let cached = make_cached(); + let rng = SecureRandom::new(); + + let response = build_emulated_server_hello( + b"secret", + &[0x71; 32], + &[0x72; 16], + &cached, + false, + &rng, + None, + 0, + ); + + assert_eq!(response[0], TLS_RECORD_HANDSHAKE); + let ccs_records = record_lengths_by_type(&response, TLS_RECORD_CHANGE_CIPHER); + assert_eq!(ccs_records.len(), 2); + assert!(ccs_records.iter().all(|len| *len == 1)); +} + +#[test] +fn emulated_server_hello_replays_profile_ticket_tail_lengths() { + let cached = make_cached(); + let rng = SecureRandom::new(); + + let response = build_emulated_server_hello( + b"secret", + &[0x81; 32], + &[0x82; 16], + &cached, + false, + &rng, + None, + 0, + ); + + let app_records = record_lengths_by_type(&response, TLS_RECORD_APPLICATION); + assert!(app_records.len() >= 4); + assert_eq!(&app_records[app_records.len() - 2..], &[220, 180]); +} diff --git a/src/transport/middle_proxy/fairness/mod.rs b/src/transport/middle_proxy/fairness/mod.rs new file mode 100644 index 0000000..58eb890 --- /dev/null +++ b/src/transport/middle_proxy/fairness/mod.rs @@ -0,0 +1,13 @@ +//! Backpressure-driven fairness control for ME reader routing. +//! +//! This module keeps fairness decisions worker-local: +//! each reader loop owns one scheduler instance and mutates it without locks. 
+ +mod model; +mod pressure; +mod scheduler; + +#[cfg(test)] +pub(crate) use model::PressureState; +pub(crate) use model::{AdmissionDecision, DispatchAction, DispatchFeedback, SchedulerDecision}; +pub(crate) use scheduler::{WorkerFairnessConfig, WorkerFairnessSnapshot, WorkerFairnessState}; diff --git a/src/transport/middle_proxy/fairness/model.rs b/src/transport/middle_proxy/fairness/model.rs new file mode 100644 index 0000000..bdf4f9f --- /dev/null +++ b/src/transport/middle_proxy/fairness/model.rs @@ -0,0 +1,140 @@ +use std::time::Instant; + +use bytes::Bytes; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[repr(u8)] +pub(crate) enum PressureState { + Normal = 0, + Pressured = 1, + Shedding = 2, + Saturated = 3, +} + +impl PressureState { + pub(crate) fn as_u8(self) -> u8 { + self as u8 + } +} + +impl Default for PressureState { + fn default() -> Self { + Self::Normal + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum FlowPressureClass { + Healthy, + Bursty, + Backpressured, + Standing, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum StandingQueueState { + Transient, + Standing, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum FlowSchedulerState { + Idle, + Active, + Backpressured, + Penalized, + SheddingCandidate, +} + +#[derive(Debug, Clone)] +pub(crate) struct QueuedFrame { + pub(crate) conn_id: u64, + pub(crate) flags: u32, + pub(crate) data: Bytes, + pub(crate) enqueued_at: Instant, +} + +impl QueuedFrame { + #[inline] + pub(crate) fn queued_bytes(&self) -> u64 { + self.data.len() as u64 + } +} + +#[derive(Debug, Clone)] +pub(crate) struct FlowFairnessState { + pub(crate) _flow_id: u64, + pub(crate) _worker_id: u16, + pub(crate) pending_bytes: u64, + pub(crate) deficit_bytes: i64, + pub(crate) queue_started_at: Option, + pub(crate) last_drain_at: Option, + pub(crate) recent_drain_bytes: u64, + pub(crate) consecutive_stalls: u8, + pub(crate) consecutive_skips: u8, + pub(crate) 
penalty_score: u16, + pub(crate) pressure_class: FlowPressureClass, + pub(crate) standing_state: StandingQueueState, + pub(crate) scheduler_state: FlowSchedulerState, + pub(crate) bucket_id: usize, + pub(crate) in_active_ring: bool, +} + +impl FlowFairnessState { + pub(crate) fn new(flow_id: u64, worker_id: u16, bucket_id: usize) -> Self { + Self { + _flow_id: flow_id, + _worker_id: worker_id, + pending_bytes: 0, + deficit_bytes: 0, + queue_started_at: None, + last_drain_at: None, + recent_drain_bytes: 0, + consecutive_stalls: 0, + consecutive_skips: 0, + penalty_score: 0, + pressure_class: FlowPressureClass::Healthy, + standing_state: StandingQueueState::Transient, + scheduler_state: FlowSchedulerState::Idle, + bucket_id, + in_active_ring: false, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum AdmissionDecision { + Admit, + RejectWorkerCap, + RejectFlowCap, + RejectBucketCap, + RejectSaturated, + RejectStandingFlow, +} + +#[derive(Debug, Clone)] +pub(crate) enum SchedulerDecision { + Idle, + Dispatch(DispatchCandidate), +} + +#[derive(Debug, Clone)] +pub(crate) struct DispatchCandidate { + pub(crate) frame: QueuedFrame, + pub(crate) pressure_state: PressureState, + pub(crate) flow_class: FlowPressureClass, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum DispatchFeedback { + Routed, + QueueFull, + ChannelClosed, + NoConn, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum DispatchAction { + Continue, + CloseFlow, +} diff --git a/src/transport/middle_proxy/fairness/pressure.rs b/src/transport/middle_proxy/fairness/pressure.rs new file mode 100644 index 0000000..02a5942 --- /dev/null +++ b/src/transport/middle_proxy/fairness/pressure.rs @@ -0,0 +1,214 @@ +use std::time::{Duration, Instant}; + +use super::model::PressureState; + +#[derive(Debug, Clone, Copy)] +pub(crate) struct PressureSignals { + pub(crate) active_flows: usize, + pub(crate) total_queued_bytes: u64, + pub(crate) standing_flows: usize, + 
pub(crate) backpressured_flows: usize, +} + +#[derive(Debug, Clone)] +pub(crate) struct PressureConfig { + pub(crate) evaluate_every_rounds: u32, + pub(crate) transition_hysteresis_rounds: u8, + pub(crate) standing_ratio_pressured_pct: u8, + pub(crate) standing_ratio_shedding_pct: u8, + pub(crate) standing_ratio_saturated_pct: u8, + pub(crate) queue_ratio_pressured_pct: u8, + pub(crate) queue_ratio_shedding_pct: u8, + pub(crate) queue_ratio_saturated_pct: u8, + pub(crate) reject_window: Duration, + pub(crate) rejects_pressured: u32, + pub(crate) rejects_shedding: u32, + pub(crate) rejects_saturated: u32, + pub(crate) stalls_pressured: u32, + pub(crate) stalls_shedding: u32, + pub(crate) stalls_saturated: u32, +} + +impl Default for PressureConfig { + fn default() -> Self { + Self { + evaluate_every_rounds: 8, + transition_hysteresis_rounds: 3, + standing_ratio_pressured_pct: 20, + standing_ratio_shedding_pct: 35, + standing_ratio_saturated_pct: 50, + queue_ratio_pressured_pct: 65, + queue_ratio_shedding_pct: 82, + queue_ratio_saturated_pct: 94, + reject_window: Duration::from_secs(2), + rejects_pressured: 32, + rejects_shedding: 96, + rejects_saturated: 256, + stalls_pressured: 32, + stalls_shedding: 96, + stalls_saturated: 256, + } + } +} + +#[derive(Debug)] +pub(crate) struct PressureEvaluator { + state: PressureState, + candidate_state: PressureState, + candidate_hits: u8, + rounds_since_eval: u32, + window_started_at: Instant, + admission_rejects_window: u32, + route_stalls_window: u32, +} + +impl PressureEvaluator { + pub(crate) fn new(now: Instant) -> Self { + Self { + state: PressureState::Normal, + candidate_state: PressureState::Normal, + candidate_hits: 0, + rounds_since_eval: 0, + window_started_at: now, + admission_rejects_window: 0, + route_stalls_window: 0, + } + } + + #[inline] + pub(crate) fn state(&self) -> PressureState { + self.state + } + + pub(crate) fn note_admission_reject(&mut self, now: Instant, cfg: &PressureConfig) { + 
self.rotate_window_if_needed(now, cfg); + self.admission_rejects_window = self.admission_rejects_window.saturating_add(1); + } + + pub(crate) fn note_route_stall(&mut self, now: Instant, cfg: &PressureConfig) { + self.rotate_window_if_needed(now, cfg); + self.route_stalls_window = self.route_stalls_window.saturating_add(1); + } + + pub(crate) fn maybe_evaluate( + &mut self, + now: Instant, + cfg: &PressureConfig, + max_total_queued_bytes: u64, + signals: PressureSignals, + force: bool, + ) -> PressureState { + self.rotate_window_if_needed(now, cfg); + self.rounds_since_eval = self.rounds_since_eval.saturating_add(1); + if !force && self.rounds_since_eval < cfg.evaluate_every_rounds.max(1) { + return self.state; + } + self.rounds_since_eval = 0; + + let target = self.derive_target_state(cfg, max_total_queued_bytes, signals); + if target == self.state { + self.candidate_state = target; + self.candidate_hits = 0; + return self.state; + } + + if self.candidate_state == target { + self.candidate_hits = self.candidate_hits.saturating_add(1); + } else { + self.candidate_state = target; + self.candidate_hits = 1; + } + + if self.candidate_hits >= cfg.transition_hysteresis_rounds.max(1) { + self.state = target; + self.candidate_hits = 0; + } + + self.state + } + + fn derive_target_state( + &self, + cfg: &PressureConfig, + max_total_queued_bytes: u64, + signals: PressureSignals, + ) -> PressureState { + let queue_ratio_pct = if max_total_queued_bytes == 0 { + 100 + } else { + ((signals.total_queued_bytes.saturating_mul(100)) / max_total_queued_bytes).min(100) + as u8 + }; + + let standing_ratio_pct = if signals.active_flows == 0 { + 0 + } else { + ((signals.standing_flows.saturating_mul(100)) / signals.active_flows).min(100) as u8 + }; + + let mut pressure_score = 0u8; + + if queue_ratio_pct >= cfg.queue_ratio_pressured_pct { + pressure_score = pressure_score.max(1); + } + if queue_ratio_pct >= cfg.queue_ratio_shedding_pct { + pressure_score = pressure_score.max(2); + } + if 
queue_ratio_pct >= cfg.queue_ratio_saturated_pct { + pressure_score = pressure_score.max(3); + } + + if standing_ratio_pct >= cfg.standing_ratio_pressured_pct { + pressure_score = pressure_score.max(1); + } + if standing_ratio_pct >= cfg.standing_ratio_shedding_pct { + pressure_score = pressure_score.max(2); + } + if standing_ratio_pct >= cfg.standing_ratio_saturated_pct { + pressure_score = pressure_score.max(3); + } + + if self.admission_rejects_window >= cfg.rejects_pressured { + pressure_score = pressure_score.max(1); + } + if self.admission_rejects_window >= cfg.rejects_shedding { + pressure_score = pressure_score.max(2); + } + if self.admission_rejects_window >= cfg.rejects_saturated { + pressure_score = pressure_score.max(3); + } + + if self.route_stalls_window >= cfg.stalls_pressured { + pressure_score = pressure_score.max(1); + } + if self.route_stalls_window >= cfg.stalls_shedding { + pressure_score = pressure_score.max(2); + } + if self.route_stalls_window >= cfg.stalls_saturated { + pressure_score = pressure_score.max(3); + } + + if signals.backpressured_flows > signals.active_flows.saturating_div(2) + && signals.active_flows > 0 + { + pressure_score = pressure_score.max(2); + } + + match pressure_score { + 0 => PressureState::Normal, + 1 => PressureState::Pressured, + 2 => PressureState::Shedding, + _ => PressureState::Saturated, + } + } + + fn rotate_window_if_needed(&mut self, now: Instant, cfg: &PressureConfig) { + if now.saturating_duration_since(self.window_started_at) < cfg.reject_window { + return; + } + + self.window_started_at = now; + self.admission_rejects_window = 0; + self.route_stalls_window = 0; + } +} diff --git a/src/transport/middle_proxy/fairness/scheduler.rs b/src/transport/middle_proxy/fairness/scheduler.rs new file mode 100644 index 0000000..8da3636 --- /dev/null +++ b/src/transport/middle_proxy/fairness/scheduler.rs @@ -0,0 +1,556 @@ +use std::collections::{HashMap, VecDeque}; +use std::time::{Duration, Instant}; + +use 
bytes::Bytes; + +use super::model::{ + AdmissionDecision, DispatchAction, DispatchCandidate, DispatchFeedback, FlowFairnessState, + FlowPressureClass, FlowSchedulerState, PressureState, QueuedFrame, SchedulerDecision, + StandingQueueState, +}; +use super::pressure::{PressureConfig, PressureEvaluator, PressureSignals}; + +#[derive(Debug, Clone)] +pub(crate) struct WorkerFairnessConfig { + pub(crate) worker_id: u16, + pub(crate) max_active_flows: usize, + pub(crate) max_total_queued_bytes: u64, + pub(crate) max_flow_queued_bytes: u64, + pub(crate) base_quantum_bytes: u32, + pub(crate) pressured_quantum_bytes: u32, + pub(crate) penalized_quantum_bytes: u32, + pub(crate) standing_queue_min_age: Duration, + pub(crate) standing_queue_min_backlog_bytes: u64, + pub(crate) standing_stall_threshold: u8, + pub(crate) max_consecutive_stalls_before_shed: u8, + pub(crate) max_consecutive_stalls_before_close: u8, + pub(crate) soft_bucket_count: usize, + pub(crate) soft_bucket_share_pct: u8, + pub(crate) pressure: PressureConfig, +} + +impl Default for WorkerFairnessConfig { + fn default() -> Self { + Self { + worker_id: 0, + max_active_flows: 4096, + max_total_queued_bytes: 16 * 1024 * 1024, + max_flow_queued_bytes: 512 * 1024, + base_quantum_bytes: 32 * 1024, + pressured_quantum_bytes: 16 * 1024, + penalized_quantum_bytes: 8 * 1024, + standing_queue_min_age: Duration::from_millis(250), + standing_queue_min_backlog_bytes: 64 * 1024, + standing_stall_threshold: 3, + max_consecutive_stalls_before_shed: 4, + max_consecutive_stalls_before_close: 16, + soft_bucket_count: 64, + soft_bucket_share_pct: 25, + pressure: PressureConfig::default(), + } + } +} + +struct FlowEntry { + fairness: FlowFairnessState, + queue: VecDeque, +} + +impl FlowEntry { + fn new(flow_id: u64, worker_id: u16, bucket_id: usize) -> Self { + Self { + fairness: FlowFairnessState::new(flow_id, worker_id, bucket_id), + queue: VecDeque::new(), + } + } +} + +#[derive(Debug, Clone, Copy, Default)] +pub(crate) struct 
WorkerFairnessSnapshot { + pub(crate) pressure_state: PressureState, + pub(crate) active_flows: usize, + pub(crate) total_queued_bytes: u64, + pub(crate) standing_flows: usize, + pub(crate) backpressured_flows: usize, + pub(crate) scheduler_rounds: u64, + pub(crate) deficit_grants: u64, + pub(crate) deficit_skips: u64, + pub(crate) enqueue_rejects: u64, + pub(crate) shed_drops: u64, + pub(crate) fairness_penalties: u64, + pub(crate) downstream_stalls: u64, +} + +pub(crate) struct WorkerFairnessState { + config: WorkerFairnessConfig, + pressure: PressureEvaluator, + flows: HashMap, + active_ring: VecDeque, + total_queued_bytes: u64, + bucket_queued_bytes: Vec, + bucket_active_flows: Vec, + standing_flow_count: usize, + backpressured_flow_count: usize, + scheduler_rounds: u64, + deficit_grants: u64, + deficit_skips: u64, + enqueue_rejects: u64, + shed_drops: u64, + fairness_penalties: u64, + downstream_stalls: u64, +} + +impl WorkerFairnessState { + pub(crate) fn new(config: WorkerFairnessConfig, now: Instant) -> Self { + let bucket_count = config.soft_bucket_count.max(1); + Self { + config, + pressure: PressureEvaluator::new(now), + flows: HashMap::new(), + active_ring: VecDeque::new(), + total_queued_bytes: 0, + bucket_queued_bytes: vec![0; bucket_count], + bucket_active_flows: vec![0; bucket_count], + standing_flow_count: 0, + backpressured_flow_count: 0, + scheduler_rounds: 0, + deficit_grants: 0, + deficit_skips: 0, + enqueue_rejects: 0, + shed_drops: 0, + fairness_penalties: 0, + downstream_stalls: 0, + } + } + + pub(crate) fn pressure_state(&self) -> PressureState { + self.pressure.state() + } + + pub(crate) fn snapshot(&self) -> WorkerFairnessSnapshot { + WorkerFairnessSnapshot { + pressure_state: self.pressure.state(), + active_flows: self.flows.len(), + total_queued_bytes: self.total_queued_bytes, + standing_flows: self.standing_flow_count, + backpressured_flows: self.backpressured_flow_count, + scheduler_rounds: self.scheduler_rounds, + deficit_grants: 
self.deficit_grants, + deficit_skips: self.deficit_skips, + enqueue_rejects: self.enqueue_rejects, + shed_drops: self.shed_drops, + fairness_penalties: self.fairness_penalties, + downstream_stalls: self.downstream_stalls, + } + } + + /* Admission control for one data frame. Checks run in this order: saturated pressure state, worker-wide byte cap, flow-count cap (new flows only), soft-bucket byte cap, then the per-flow checks that follow. Every rejection notes the reject on the pressure evaluator and bumps enqueue_rejects. */ pub(crate) fn enqueue_data( + &mut self, + conn_id: u64, + flags: u32, + data: Bytes, + now: Instant, + ) -> AdmissionDecision { + let frame = QueuedFrame { + conn_id, + flags, + data, + enqueued_at: now, + }; + let frame_bytes = frame.queued_bytes(); + + /* Saturated: reject immediately, without re-evaluating pressure. */ if self.pressure.state() == PressureState::Saturated { + self.pressure + .note_admission_reject(now, &self.config.pressure); + self.enqueue_rejects = self.enqueue_rejects.saturating_add(1); + return AdmissionDecision::RejectSaturated; + } + + /* Worker-wide byte budget. */ if self.total_queued_bytes.saturating_add(frame_bytes) > self.config.max_total_queued_bytes + { + self.pressure + .note_admission_reject(now, &self.config.pressure); + self.enqueue_rejects = self.enqueue_rejects.saturating_add(1); + self.evaluate_pressure(now, true); + return AdmissionDecision::RejectWorkerCap; + } + + /* Flow-count budget: applies only to connection ids not tracked yet; reported as RejectWorkerCap as well. */ if !self.flows.contains_key(&conn_id) && self.flows.len() >= self.config.max_active_flows { + self.pressure + .note_admission_reject(now, &self.config.pressure); + self.enqueue_rejects = self.enqueue_rejects.saturating_add(1); + self.evaluate_pressure(now, true); + return AdmissionDecision::RejectWorkerCap; + } + + let bucket_id = self.bucket_for(conn_id); + /* Bucket cap is soft_bucket_share_pct percent of the worker cap, but never below the per-flow cap. */ let bucket_cap = self + .config + .max_total_queued_bytes + .saturating_mul(self.config.soft_bucket_share_pct.max(1) as u64) + .saturating_div(100) + .max(self.config.max_flow_queued_bytes); + if self.bucket_queued_bytes[bucket_id].saturating_add(frame_bytes) > bucket_cap { + self.pressure + .note_admission_reject(now, &self.config.pressure); + self.enqueue_rejects = self.enqueue_rejects.saturating_add(1); + self.evaluate_pressure(now, true); + return AdmissionDecision::RejectBucketCap; + } + + /* Look up the flow entry, creating it on first use. */ let entry = if let Some(flow) = self.flows.get_mut(&conn_id) { + flow + } else { + 
self.bucket_active_flows[bucket_id] = + self.bucket_active_flows[bucket_id].saturating_add(1); + self.flows.insert( + conn_id, + FlowEntry::new(conn_id, self.config.worker_id, bucket_id), + ); + self.flows + .get_mut(&conn_id) + .expect("flow inserted must be retrievable") + }; + + /* Per-flow byte cap. NOTE(review): the entry was inserted above before this check, so a brand-new flow whose first frame is rejected here stays in `flows` with an empty queue until remove_flow is called. */ if entry.fairness.pending_bytes.saturating_add(frame_bytes) + > self.config.max_flow_queued_bytes + { + self.pressure + .note_admission_reject(now, &self.config.pressure); + self.enqueue_rejects = self.enqueue_rejects.saturating_add(1); + self.evaluate_pressure(now, true); + return AdmissionDecision::RejectFlowCap; + } + + /* At Shedding pressure or above, flows already classified as Standing get no new frames. */ if self.pressure.state() >= PressureState::Shedding + && entry.fairness.standing_state == StandingQueueState::Standing + { + self.pressure + .note_admission_reject(now, &self.config.pressure); + self.enqueue_rejects = self.enqueue_rejects.saturating_add(1); + self.evaluate_pressure(now, true); + return AdmissionDecision::RejectStandingFlow; + } + + /* Admitted: account the bytes on the flow, the worker and the bucket, and put the flow on the active ring if it is not there yet. */ entry.fairness.pending_bytes = entry.fairness.pending_bytes.saturating_add(frame_bytes); + if entry.fairness.queue_started_at.is_none() { + entry.fairness.queue_started_at = Some(now); + } + entry.queue.push_back(frame); + + self.total_queued_bytes = self.total_queued_bytes.saturating_add(frame_bytes); + self.bucket_queued_bytes[bucket_id] = + self.bucket_queued_bytes[bucket_id].saturating_add(frame_bytes); + + if !entry.fairness.in_active_ring { + entry.fairness.in_active_ring = true; + self.active_ring.push_back(conn_id); + } + + self.evaluate_pressure(now, true); + AdmissionDecision::Admit + } + + /* Runs one scheduling round: visits each id that was on the active ring at entry at most once and returns the first frame whose flow has enough deficit; returns Idle when no flow yields a frame. Each call counts as one scheduler round. */ pub(crate) fn next_decision(&mut self, now: Instant) -> SchedulerDecision { + self.scheduler_rounds = self.scheduler_rounds.saturating_add(1); + self.evaluate_pressure(now, false); + + let active_len = self.active_ring.len(); + for _ in 0..active_len { + let Some(conn_id) = self.active_ring.pop_front() else { + break; + }; + + let mut candidate = None; + let mut requeue_active = false; + let mut drained_bytes = 0u64; + let mut 
bucket_id = 0usize; + let pressure_state = self.pressure.state(); + + /* Ids whose flow has been removed fall through this block and are dropped from the ring. */ if let Some(flow) = self.flows.get_mut(&conn_id) { + bucket_id = flow.fairness.bucket_id; + + /* Empty queue: park the flow off the ring and discard any banked deficit so credit does not carry into a later backlog. */ if flow.queue.is_empty() { + flow.fairness.in_active_ring = false; + flow.fairness.scheduler_state = FlowSchedulerState::Idle; + flow.fairness.pending_bytes = 0; + flow.fairness.queue_started_at = None; + flow.fairness.deficit_bytes = 0; + continue; + } + + Self::classify_flow(&self.config, pressure_state, now, &mut flow.fairness); + + /* Grant this visit's quantum; its size depends on the pressure state and the flow's scheduler state. */ let quantum = + Self::effective_quantum_bytes(&self.config, pressure_state, &flow.fairness); + flow.fairness.deficit_bytes = flow + .fairness + .deficit_bytes + .saturating_add(i64::from(quantum)); + self.deficit_grants = self.deficit_grants.saturating_add(1); + + let front_len = flow.queue.front().map_or(0, |front| front.queued_bytes()); + /* Not enough credit for the head frame yet: skip the flow this round and keep it on the ring. */ if flow.fairness.deficit_bytes < front_len as i64 { + flow.fairness.consecutive_skips = + flow.fairness.consecutive_skips.saturating_add(1); + self.deficit_skips = self.deficit_skips.saturating_add(1); + requeue_active = true; + } else if let Some(frame) = flow.queue.pop_front() { + /* At most one frame is dequeued per visit; its bytes are charged against the deficit. */ drained_bytes = frame.queued_bytes(); + flow.fairness.pending_bytes = + flow.fairness.pending_bytes.saturating_sub(drained_bytes); + flow.fairness.deficit_bytes = flow + .fairness + .deficit_bytes + .saturating_sub(drained_bytes as i64); + flow.fairness.consecutive_skips = 0; + flow.fairness.queue_started_at = + flow.queue.front().map(|front| front.enqueued_at); + requeue_active = !flow.queue.is_empty(); + if !requeue_active { + /* Queue drained: the flow goes idle, and as in classic deficit round robin the unused deficit is reset instead of being banked. */ flow.fairness.scheduler_state = FlowSchedulerState::Idle; + flow.fairness.in_active_ring = false; + flow.fairness.deficit_bytes = 0; + } + candidate = Some(DispatchCandidate { + pressure_state, + flow_class: flow.fairness.pressure_class, + frame, + }); + } + } + + /* Release the dequeued bytes from the worker and bucket totals. */ if drained_bytes > 0 { + self.total_queued_bytes = self.total_queued_bytes.saturating_sub(drained_bytes); + self.bucket_queued_bytes[bucket_id] = + self.bucket_queued_bytes[bucket_id].saturating_sub(drained_bytes); + } + + if requeue_active { 
+ /* Flow still has frames (or was skipped): move it to the back of the ring. */ if let Some(flow) = self.flows.get_mut(&conn_id) { + flow.fairness.in_active_ring = true; + } + self.active_ring.push_back(conn_id); + } + + if let Some(candidate) = candidate { + return SchedulerDecision::Dispatch(candidate); + } + } + + SchedulerDecision::Idle + } + + /* Reports the outcome of routing a frame previously returned by next_decision. Routed updates drain statistics; QueueFull records a stall and either sheds the frame or puts it back at the head of the flow's queue; ChannelClosed and NoConn remove the flow. The returned action tells the caller whether to close the connection. */ pub(crate) fn apply_dispatch_feedback( + &mut self, + conn_id: u64, + candidate: DispatchCandidate, + feedback: DispatchFeedback, + now: Instant, + ) -> DispatchAction { + match feedback { + DispatchFeedback::Routed => { + if let Some(flow) = self.flows.get_mut(&conn_id) { + flow.fairness.last_drain_at = Some(now); + flow.fairness.recent_drain_bytes = flow + .fairness + .recent_drain_bytes + .saturating_add(candidate.frame.queued_bytes()); + flow.fairness.consecutive_stalls = 0; + /* An idle flow (queue already drained) stays idle. */ if flow.fairness.scheduler_state != FlowSchedulerState::Idle { + flow.fairness.scheduler_state = FlowSchedulerState::Active; + } + } + self.evaluate_pressure(now, false); + DispatchAction::Continue + } + DispatchFeedback::QueueFull => { + self.pressure.note_route_stall(now, &self.config.pressure); + self.downstream_stalls = self.downstream_stalls.saturating_add(1); + /* Flow already removed: the frame is dropped with the candidate. */ let Some(flow) = self.flows.get_mut(&conn_id) else { + self.evaluate_pressure(now, true); + return DispatchAction::Continue; + }; + + flow.fairness.consecutive_stalls = + flow.fairness.consecutive_stalls.saturating_add(1); + flow.fairness.scheduler_state = FlowSchedulerState::Backpressured; + flow.fairness.pressure_class = FlowPressureClass::Backpressured; + + let state = self.pressure.state(); + /* Shed when Saturated, or when Shedding and this flow is Standing with at least max_consecutive_stalls_before_shed stalls. */ let should_shed_frame = matches!(state, PressureState::Saturated) + || (matches!(state, PressureState::Shedding) + && flow.fairness.standing_state == StandingQueueState::Standing + && flow.fairness.consecutive_stalls + >= self.config.max_consecutive_stalls_before_shed); + + if should_shed_frame { + /* The frame's bytes were already released in next_decision, so shedding only updates counters. */ self.shed_drops = self.shed_drops.saturating_add(1); + self.fairness_penalties = self.fairness_penalties.saturating_add(1); + } else { + let frame_bytes = 
candidate.frame.queued_bytes(); + /* Not shed: put the frame back at the head of the queue and restore the byte accounting released in next_decision. */ flow.queue.push_front(candidate.frame); + flow.fairness.pending_bytes = + flow.fairness.pending_bytes.saturating_add(frame_bytes); + flow.fairness.queue_started_at = + flow.queue.front().map(|front| front.enqueued_at); + self.total_queued_bytes = self.total_queued_bytes.saturating_add(frame_bytes); + self.bucket_queued_bytes[flow.fairness.bucket_id] = self.bucket_queued_bytes + [flow.fairness.bucket_id] + .saturating_add(frame_bytes); + if !flow.fairness.in_active_ring { + flow.fairness.in_active_ring = true; + self.active_ring.push_back(conn_id); + } + } + + /* Close the flow only when it reached max_consecutive_stalls_before_close while the worker is Saturated. */ if flow.fairness.consecutive_stalls + >= self.config.max_consecutive_stalls_before_close + && self.pressure.state() == PressureState::Saturated + { + self.remove_flow(conn_id); + self.evaluate_pressure(now, true); + return DispatchAction::CloseFlow; + } + + self.evaluate_pressure(now, true); + DispatchAction::Continue + } + DispatchFeedback::ChannelClosed | DispatchFeedback::NoConn => { + self.remove_flow(conn_id); + self.evaluate_pressure(now, true); + DispatchAction::CloseFlow + } + } + } + + /* Forgets a flow: drops its queued frames, releases their bytes from the worker and bucket totals, decrements the bucket's flow count and removes the id from the active ring. Unknown ids are a no-op. */ pub(crate) fn remove_flow(&mut self, conn_id: u64) { + let Some(entry) = self.flows.remove(&conn_id) else { + return; + }; + + self.bucket_active_flows[entry.fairness.bucket_id] = + self.bucket_active_flows[entry.fairness.bucket_id].saturating_sub(1); + + /* Purge the ring slot as well: otherwise a flow re-created under the same id would be pushed a second time and visited twice per round. */ if entry.fairness.in_active_ring { + self.active_ring.retain(|queued| *queued != conn_id); + } + + let mut reclaimed = 0u64; + for frame in entry.queue { + reclaimed = reclaimed.saturating_add(frame.queued_bytes()); + } + self.total_queued_bytes = self.total_queued_bytes.saturating_sub(reclaimed); + self.bucket_queued_bytes[entry.fairness.bucket_id] = + self.bucket_queued_bytes[entry.fairness.bucket_id].saturating_sub(reclaimed); + } + + /* Reclassifies every tracked flow, refreshes the standing and backpressured counts, and feeds the aggregate signals to the pressure evaluator. */ fn evaluate_pressure(&mut self, now: Instant, force: bool) { + let mut standing = 0usize; + let mut backpressured = 0usize; + + for flow in self.flows.values_mut() { + Self::classify_flow(&self.config, self.pressure.state(), now, &mut flow.fairness); + if flow.fairness.standing_state == 
StandingQueueState::Standing { + standing = standing.saturating_add(1); + } + /* Backpressured, Penalized and SheddingCandidate flows all count as backpressured in the signals. */ if matches!( + flow.fairness.scheduler_state, + FlowSchedulerState::Backpressured + | FlowSchedulerState::Penalized + | FlowSchedulerState::SheddingCandidate + ) { + backpressured = backpressured.saturating_add(1); + } + } + + self.standing_flow_count = standing; + self.backpressured_flow_count = backpressured; + + /* The evaluator's return value is ignored; `force` is passed through unchanged. */ let _ = self.pressure.maybe_evaluate( + now, + &self.config.pressure, + self.config.max_total_queued_bytes, + PressureSignals { + active_flows: self.flows.len(), + total_queued_bytes: self.total_queued_bytes, + standing_flows: standing, + backpressured_flows: backpressured, + }, + force, + ); + } + + /* Recomputes a flow's pressure class, standing state and scheduler state from its backlog, queue age, stall count and last drain time. penalty_score goes up by one for a standing flow and down by one (saturating) otherwise. */ fn classify_flow( + config: &WorkerFairnessConfig, + pressure_state: PressureState, + now: Instant, + fairness: &mut FlowFairnessState, + ) { + /* No backlog: healthy and idle. */ if fairness.pending_bytes == 0 { + fairness.pressure_class = FlowPressureClass::Healthy; + fairness.standing_state = StandingQueueState::Transient; + fairness.scheduler_state = FlowSchedulerState::Idle; + fairness.penalty_score = fairness.penalty_score.saturating_sub(1); + return; + } + + let queue_age = fairness + .queue_started_at + .map(|ts| now.saturating_duration_since(ts)) + .unwrap_or_default(); + /* A flow that has never drained counts as drain-stalled. */ let drain_stalled = fairness + .last_drain_at + .map(|ts| now.saturating_duration_since(ts) >= config.standing_queue_min_age) + .unwrap_or(true); + + /* Standing means: backlog and queue age are both over their thresholds, and the flow is either stalling repeatedly or has not drained for standing_queue_min_age. */ let standing = fairness.pending_bytes >= config.standing_queue_min_backlog_bytes + && queue_age >= config.standing_queue_min_age + && (fairness.consecutive_stalls >= config.standing_stall_threshold || drain_stalled); + + if standing { + fairness.standing_state = StandingQueueState::Standing; + fairness.pressure_class = FlowPressureClass::Standing; + fairness.penalty_score = fairness.penalty_score.saturating_add(1); + fairness.scheduler_state = if pressure_state >= PressureState::Shedding { + FlowSchedulerState::SheddingCandidate + } else { + FlowSchedulerState::Penalized + }; + return; + } + + 
fairness.standing_state = StandingQueueState::Transient; + /* Not standing: Backpressured if any stall is recorded, Bursty if the backlog is at or over the standing threshold, otherwise Healthy. */ if fairness.consecutive_stalls > 0 { + fairness.pressure_class = FlowPressureClass::Backpressured; + fairness.scheduler_state = FlowSchedulerState::Backpressured; + } else if fairness.pending_bytes >= config.standing_queue_min_backlog_bytes { + fairness.pressure_class = FlowPressureClass::Bursty; + fairness.scheduler_state = FlowSchedulerState::Active; + } else { + fairness.pressure_class = FlowPressureClass::Healthy; + fairness.scheduler_state = FlowSchedulerState::Active; + } + fairness.penalty_score = fairness.penalty_score.saturating_sub(1); + } + + /* Picks the quantum for one visit. Penalized and SheddingCandidate flows always get the penalized quantum; otherwise the quantum follows the pressure state. Every result is at least 1. */ fn effective_quantum_bytes( + config: &WorkerFairnessConfig, + pressure_state: PressureState, + fairness: &FlowFairnessState, + ) -> u32 { + let penalized = matches!( + fairness.scheduler_state, + FlowSchedulerState::Penalized | FlowSchedulerState::SheddingCandidate + ); + + if penalized { + return config.penalized_quantum_bytes.max(1); + } + + match pressure_state { + PressureState::Normal => config.base_quantum_bytes.max(1), + PressureState::Pressured => config.pressured_quantum_bytes.max(1), + PressureState::Shedding => config.pressured_quantum_bytes.max(1), + PressureState::Saturated => config.penalized_quantum_bytes.max(1), + } + } + + /* Maps a connection id to a soft bucket by taking it modulo the bucket count. */ fn bucket_for(&self, conn_id: u64) -> usize { + (conn_id as usize) % self.bucket_queued_bytes.len().max(1) + } +} diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index 6dfbee6..992fec3 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -2,6 +2,10 @@ mod codec; mod config_updater; +mod fairness; +#[cfg(test)] +#[path = "tests/fairness_security_tests.rs"] +mod fairness_security_tests; mod handshake; mod health; #[cfg(test)] diff --git a/src/transport/middle_proxy/reader.rs b/src/transport/middle_proxy/reader.rs index dbfd9d7..8041185 100644 --- a/src/transport/middle_proxy/reader.rs +++ b/src/transport/middle_proxy/reader.rs @@ -20,11 +20,15 @@ use 
crate::protocol::constants::*; use crate::stats::Stats; use super::codec::{RpcChecksumMode, WriterCommand, rpc_crc}; +use super::fairness::{ + AdmissionDecision, DispatchAction, DispatchFeedback, SchedulerDecision, WorkerFairnessConfig, + WorkerFairnessSnapshot, WorkerFairnessState, +}; use super::registry::RouteResult; use super::{ConnRegistry, MeResponse}; - const DATA_ROUTE_MAX_ATTEMPTS: usize = 3; const DATA_ROUTE_QUEUE_FULL_STARVATION_THRESHOLD: u8 = 3; +/* Upper bound on scheduler decisions handled by one drain_fairness_scheduler call. */ const FAIRNESS_DRAIN_BUDGET_PER_LOOP: usize = 128; fn should_close_on_route_result_for_data(result: RouteResult) -> bool { matches!(result, RouteResult::NoConn | RouteResult::ChannelClosed) @@ -77,6 +81,118 @@ async fn route_data_with_retry( } } +/* Translates a registry routing result into scheduler feedback; both queue-full variants map to QueueFull. */ #[inline] +fn route_feedback(result: RouteResult) -> DispatchFeedback { + match result { + RouteResult::Routed => DispatchFeedback::Routed, + RouteResult::NoConn => DispatchFeedback::NoConn, + RouteResult::ChannelClosed => DispatchFeedback::ChannelClosed, + RouteResult::QueueFullBase | RouteResult::QueueFullHigh => DispatchFeedback::QueueFull, + } +} + +/* Bumps the route-drop counters matching a non-Routed result; Routed is a no-op. */ fn report_route_drop(result: RouteResult, stats: &Stats) { + match result { + RouteResult::NoConn => stats.increment_me_route_drop_no_conn(), + RouteResult::ChannelClosed => stats.increment_me_route_drop_channel_closed(), + RouteResult::QueueFullBase => { + stats.increment_me_route_drop_queue_full(); + stats.increment_me_route_drop_queue_full_base(); + } + RouteResult::QueueFullHigh => { + stats.increment_me_route_drop_queue_full(); + stats.increment_me_route_drop_queue_full_high(); + } + RouteResult::Routed => {} + } +} + +/* Publishes fairness metrics: gauges are set from `current`, counters are advanced by the saturating difference between `current` and `prev`, then `prev` is replaced with `current`. */ fn apply_fairness_metrics_delta( + stats: &Stats, + prev: &mut WorkerFairnessSnapshot, + current: WorkerFairnessSnapshot, +) { + stats.set_me_fair_active_flows_gauge(current.active_flows as u64); + stats.set_me_fair_queued_bytes_gauge(current.total_queued_bytes); + stats.set_me_fair_standing_flows_gauge(current.standing_flows as u64); + 
stats.set_me_fair_backpressured_flows_gauge(current.backpressured_flows as u64); + stats.set_me_fair_pressure_state_gauge(current.pressure_state.as_u8() as u64); + /* Counters: add only what accumulated since the previous snapshot. */ stats.add_me_fair_scheduler_rounds_total( + current + .scheduler_rounds + .saturating_sub(prev.scheduler_rounds), + ); + stats.add_me_fair_deficit_grants_total( + current.deficit_grants.saturating_sub(prev.deficit_grants), + ); + stats.add_me_fair_deficit_skips_total(current.deficit_skips.saturating_sub(prev.deficit_skips)); + stats.add_me_fair_enqueue_rejects_total( + current.enqueue_rejects.saturating_sub(prev.enqueue_rejects), + ); + stats.add_me_fair_shed_drops_total(current.shed_drops.saturating_sub(prev.shed_drops)); + stats.add_me_fair_penalties_total( + current + .fairness_penalties + .saturating_sub(prev.fairness_penalties), + ); + stats.add_me_fair_downstream_stalls_total( + current + .downstream_stalls + .saturating_sub(prev.downstream_stalls), + ); + *prev = current; +} + +/* Pulls up to FAIRNESS_DRAIN_BUDGET_PER_LOOP frames from the scheduler, routes each through route_data_with_retry, and reports the result back to the scheduler. Stops early when the scheduler has nothing to dispatch. NOTE(review): the Sender payload type and the streak map's key/value types were missing in this text; WriterCommand (imported from super::codec), u64 connection ids and u8 streak counters are inferred from usage. Confirm against the upstream file. */ async fn drain_fairness_scheduler( + fairness: &mut WorkerFairnessState, + reg: &ConnRegistry, + tx: &mpsc::Sender<WriterCommand>, + data_route_queue_full_streak: &mut HashMap<u64, u8>, + route_wait_ms: u64, + stats: &Stats, +) { + for _ in 0..FAIRNESS_DRAIN_BUDGET_PER_LOOP { + let now = Instant::now(); + let SchedulerDecision::Dispatch(candidate) = fairness.next_decision(now) else { + break; + }; + let cid = candidate.frame.conn_id; + let _pressure_state = candidate.pressure_state; + let _flow_class = candidate.flow_class; + /* The payload is cloned because `candidate` itself is handed back to the scheduler below. */ let routed = route_data_with_retry( + reg, + cid, + candidate.frame.flags, + candidate.frame.data.clone(), + route_wait_ms, + ) + .await; + if matches!(routed, RouteResult::Routed) { + data_route_queue_full_streak.remove(&cid); + } else { + report_route_drop(routed, stats); + } + let action = fairness.apply_dispatch_feedback(cid, candidate, route_feedback(routed), now); + /* Count consecutive queue-full results per connection. */ if is_data_route_queue_full(routed) { + let streak = data_route_queue_full_streak.entry(cid).or_insert(0); + *streak = streak.saturating_add(1); + if 
should_close_on_queue_full_streak(*streak) { + /* Streak limit reached: drop the flow, forget the streak, unregister the connection and send a close for it. */ fairness.remove_flow(cid); + data_route_queue_full_streak.remove(&cid); + reg.unregister(cid).await; + send_close_conn(tx, cid).await; + continue; + } + } + /* The scheduler asked for a close, or the route result says the connection is gone. */ if action == DispatchAction::CloseFlow || should_close_on_route_result_for_data(routed) { + fairness.remove_flow(cid); + data_route_queue_full_streak.remove(&cid); + reg.unregister(cid).await; + send_close_conn(tx, cid).await; + } + } +} + pub(crate) async fn reader_loop( mut rd: tokio::io::ReadHalf, /* NOTE(review): the generic argument of ReadHalf is missing in this text and cannot be recovered from the visible code. Restore it from the upstream file. */ dk: [u8; 32], @@ -98,7 +214,21 @@ pub(crate) async fn reader_loop( let mut raw = enc_leftover; let mut expected_seq: i32 = 0; let mut data_route_queue_full_streak = HashMap::<u64, u8>::new(); - + /* Per-reader fairness scheduler; flow and byte caps are derived from the registry's route channel capacity, everything else uses the defaults. */ let mut fairness = WorkerFairnessState::new( + WorkerFairnessConfig { + worker_id: (writer_id as u16).saturating_add(1), + max_active_flows: reg.route_channel_capacity().saturating_mul(4).max(256), + max_total_queued_bytes: (reg.route_channel_capacity() as u64) + .saturating_mul(16 * 1024) + .max(4 * 1024 * 1024), + max_flow_queued_bytes: (reg.route_channel_capacity() as u64) + .saturating_mul(2 * 1024) + .clamp(64 * 1024, 2 * 1024 * 1024), + ..WorkerFairnessConfig::default() + }, + Instant::now(), + ); + /* Baseline for the counter deltas published after each drain. */ let mut fairness_snapshot = fairness.snapshot(); loop { let mut tmp = [0u8; 65_536]; let n = tokio::select! 
{ @@ -181,36 +311,20 @@ pub(crate) async fn reader_loop( let data = body.slice(12..); trace!(cid, flags, len = data.len(), "RPC_PROXY_ANS"); - let route_wait_ms = reader_route_data_wait_ms.load(Ordering::Relaxed); - let routed = - route_data_with_retry(reg.as_ref(), cid, flags, data, route_wait_ms).await; - if matches!(routed, RouteResult::Routed) { - data_route_queue_full_streak.remove(&cid); - continue; - } - match routed { - RouteResult::NoConn => stats.increment_me_route_drop_no_conn(), - RouteResult::ChannelClosed => stats.increment_me_route_drop_channel_closed(), - RouteResult::QueueFullBase => { - stats.increment_me_route_drop_queue_full(); - stats.increment_me_route_drop_queue_full_base(); - } - RouteResult::QueueFullHigh => { - stats.increment_me_route_drop_queue_full(); - stats.increment_me_route_drop_queue_full_high(); - } - RouteResult::Routed => {} - } - if should_close_on_route_result_for_data(routed) { - data_route_queue_full_streak.remove(&cid); - reg.unregister(cid).await; - send_close_conn(&tx, cid).await; - continue; - } - if is_data_route_queue_full(routed) { + /* Data frames are queued in the fairness scheduler here instead of being routed inline. */ let admission = fairness.enqueue_data(cid, flags, data, Instant::now()); + /* Every rejection, whatever its reason, is counted under the queue-full and queue-full-high drop counters and extends the connection's streak. */ if !matches!(admission, AdmissionDecision::Admit) { + stats.increment_me_route_drop_queue_full(); + stats.increment_me_route_drop_queue_full_high(); let streak = data_route_queue_full_streak.entry(cid).or_insert(0); *streak = streak.saturating_add(1); - if should_close_on_queue_full_streak(*streak) { + /* Close on the streak limit, or at once for RejectSaturated and RejectStandingFlow. */ if should_close_on_queue_full_streak(*streak) + || matches!( + admission, + AdmissionDecision::RejectSaturated + | AdmissionDecision::RejectStandingFlow + ) + { + fairness.remove_flow(cid); data_route_queue_full_streak.remove(&cid); reg.unregister(cid).await; send_close_conn(&tx, cid).await; @@ -249,12 +363,14 @@ pub(crate) async fn reader_loop( let _ = reg.route_nowait(cid, MeResponse::Close).await; reg.unregister(cid).await; data_route_queue_full_streak.remove(&cid); + fairness.remove_flow(cid); } else if pt == 
RPC_CLOSE_CONN_U32 && body.len() >= 8 { let cid = u64::from_le_bytes(body[0..8].try_into().unwrap()); debug!(cid, "RPC_CLOSE_CONN from ME"); let _ = reg.route_nowait(cid, MeResponse::Close).await; reg.unregister(cid).await; data_route_queue_full_streak.remove(&cid); + /* Also drop any frames still queued for the closed connection. */ fairness.remove_flow(cid); } else if pt == RPC_PING_U32 && body.len() >= 8 { let ping_id = i64::from_le_bytes(body[0..8].try_into().unwrap()); trace!(ping_id, "RPC_PING -> RPC_PONG"); @@ -310,6 +426,19 @@ pub(crate) async fn reader_loop( "Unknown RPC" ); } + + /* After handling a frame, drain a bounded batch from the scheduler and publish the metric deltas. */ let route_wait_ms = reader_route_data_wait_ms.load(Ordering::Relaxed); + drain_fairness_scheduler( + &mut fairness, + reg.as_ref(), + &tx, + &mut data_route_queue_full_streak, + route_wait_ms, + stats.as_ref(), + ) + .await; + let current_snapshot = fairness.snapshot(); + apply_fairness_metrics_delta(stats.as_ref(), &mut fairness_snapshot, current_snapshot); } } } diff --git a/src/transport/middle_proxy/registry.rs b/src/transport/middle_proxy/registry.rs index d8625f2..0c7a0a9 100644 --- a/src/transport/middle_proxy/registry.rs +++ b/src/transport/middle_proxy/registry.rs @@ -140,6 +140,10 @@ impl ConnRegistry { } } + /* Returns the stored route_channel_capacity field. */ pub fn route_channel_capacity(&self) -> usize { + self.route_channel_capacity + } + #[cfg(test)] pub fn new() -> Self { Self::with_route_channel_capacity(4096) diff --git a/src/transport/middle_proxy/tests/fairness_security_tests.rs b/src/transport/middle_proxy/tests/fairness_security_tests.rs new file mode 100644 index 0000000..41a8d86 --- /dev/null +++ b/src/transport/middle_proxy/tests/fairness_security_tests.rs @@ -0,0 +1,185 @@ +use std::time::{Duration, Instant}; + +use bytes::Bytes; + +use crate::transport::middle_proxy::fairness::{ + AdmissionDecision, DispatchAction, DispatchFeedback, PressureState, SchedulerDecision, + WorkerFairnessConfig, WorkerFairnessState, +}; + +/* Builds a payload of `size` bytes, each 0xAB. */ fn enqueue_payload(size: usize) -> Bytes { + Bytes::from(vec![0xAB; size]) +} + +/* With a 1024-byte worker cap, a 700-byte frame is admitted and a following 400-byte frame on another flow is rejected with RejectWorkerCap. */ #[test] +fn fairness_rejects_when_worker_budget_is_exhausted() { + let now 
= Instant::now(); + let mut fairness = WorkerFairnessState::new( + WorkerFairnessConfig { + max_total_queued_bytes: 1024, + max_flow_queued_bytes: 1024, + ..WorkerFairnessConfig::default() + }, + now, + ); + + assert_eq!( + fairness.enqueue_data(1, 0, enqueue_payload(700), now), + AdmissionDecision::Admit + ); + assert_eq!( + fairness.enqueue_data(2, 0, enqueue_payload(400), now), + AdmissionDecision::RejectWorkerCap + ); + + let snapshot = fairness.snapshot(); + assert!(snapshot.total_queued_bytes <= 1024); + assert_eq!(snapshot.enqueue_rejects, 1); +} + +/* A 512-byte backlog that is 100 ms old (thresholds: 256 bytes, 50 ms) and hits a QueueFull result must be reported as one standing flow and at least one backpressured flow, and the flow must not be closed. */ #[test] +fn fairness_marks_standing_queue_after_stall_and_age_threshold() { + let mut now = Instant::now(); + let mut fairness = WorkerFairnessState::new( + WorkerFairnessConfig { + standing_queue_min_age: Duration::from_millis(50), + standing_queue_min_backlog_bytes: 256, + standing_stall_threshold: 1, + max_flow_queued_bytes: 4096, + max_total_queued_bytes: 4096, + ..WorkerFairnessConfig::default() + }, + now, + ); + + assert_eq!( + fairness.enqueue_data(11, 0, enqueue_payload(512), now), + AdmissionDecision::Admit + ); + + now += Duration::from_millis(100); + let SchedulerDecision::Dispatch(candidate) = fairness.next_decision(now) else { + panic!("expected dispatch candidate"); + }; + + let action = fairness.apply_dispatch_feedback(11, candidate, DispatchFeedback::QueueFull, now); + assert!(matches!(action, DispatchAction::Continue)); + + let snapshot = fairness.snapshot(); + assert_eq!(snapshot.standing_flows, 1); + assert!(snapshot.backpressured_flows >= 1); +} + +/* Two flows with equal backlogs; flow 2 always reports QueueFull. Flow 1 must still get frames routed and the queued total must stay within the worker cap. */ #[test] +fn fairness_keeps_fast_flow_progress_under_slow_neighbor() { + let mut now = Instant::now(); + let mut fairness = WorkerFairnessState::new( + WorkerFairnessConfig { + max_total_queued_bytes: 64 * 1024, + max_flow_queued_bytes: 32 * 1024, + ..WorkerFairnessConfig::default() + }, + now, + ); + + for _ in 0..16 { + assert_eq!( + fairness.enqueue_data(1, 0, enqueue_payload(512), now), + AdmissionDecision::Admit + ); + assert_eq!( + 
fairness.enqueue_data(2, 0, enqueue_payload(512), now), + AdmissionDecision::Admit + ); + } + + let mut fast_routed = 0u64; + /* Up to 128 decisions: flow 2 always answers QueueFull, every other flow answers Routed. */ for _ in 0..128 { + now += Duration::from_millis(5); + let SchedulerDecision::Dispatch(candidate) = fairness.next_decision(now) else { + break; + }; + let cid = candidate.frame.conn_id; + let feedback = if cid == 2 { + DispatchFeedback::QueueFull + } else { + fast_routed = fast_routed.saturating_add(1); + DispatchFeedback::Routed + }; + let _ = fairness.apply_dispatch_feedback(cid, candidate, feedback, now); + } + + let snapshot = fairness.snapshot(); + assert!(fast_routed > 0, "fast flow must continue making progress"); + assert!(snapshot.total_queued_bytes <= 64 * 1024); +} + +/* With transition_hysteresis_rounds set to 3, four admitted 900-byte frames followed by two scheduler rounds must leave the pressure state at Normal. */ #[test] +fn fairness_pressure_hysteresis_prevents_instant_flapping() { + let mut now = Instant::now(); + let mut cfg = WorkerFairnessConfig::default(); + cfg.max_total_queued_bytes = 4096; + cfg.max_flow_queued_bytes = 4096; + cfg.pressure.evaluate_every_rounds = 1; + cfg.pressure.transition_hysteresis_rounds = 3; + cfg.pressure.queue_ratio_pressured_pct = 40; + cfg.pressure.queue_ratio_shedding_pct = 60; + cfg.pressure.queue_ratio_saturated_pct = 80; + + let mut fairness = WorkerFairnessState::new(cfg, now); + + for _ in 0..4 { + assert_eq!( + fairness.enqueue_data(9, 0, enqueue_payload(900), now), + AdmissionDecision::Admit + ); + } + + for _ in 0..2 { + now += Duration::from_millis(1); + let _ = fairness.next_decision(now); + } + + assert_eq!( + fairness.pressure_state(), + PressureState::Normal, + "state must not flip before hysteresis confirmations" + ); +} + +/* Drives 4096 pseudo-random enqueue and dispatch steps from a fixed seed and checks after every step that the queued total never exceeds the 32 KiB worker cap. */ #[test] +fn fairness_randomized_sequence_preserves_memory_bounds() { + let mut now = Instant::now(); + let mut fairness = WorkerFairnessState::new( + WorkerFairnessConfig { + max_total_queued_bytes: 32 * 1024, + max_flow_queued_bytes: 4 * 1024, + ..WorkerFairnessConfig::default() + }, + now, + ); + + let mut seed = 0xC0FFEE_u64; + for _ in 0..4096 { + /* Shift-and-xor step that advances the seed deterministically. */ seed ^= seed << 7; + seed ^= seed >> 9; + seed ^= seed 
<< 8; + /* Flow ids fall in 1..=32 and payload sizes in 64..=575 bytes; the admission result is deliberately ignored. */ let flow = (seed % 32) + 1; + let size = ((seed >> 8) % 512 + 64) as usize; + let _ = fairness.enqueue_data(flow, 0, enqueue_payload(size), now); + + now += Duration::from_millis(1); + /* The seed's low bit decides whether the dispatched frame is reported as Routed or QueueFull. */ if let SchedulerDecision::Dispatch(candidate) = fairness.next_decision(now) { + let feedback = if seed & 0x1 == 0 { + DispatchFeedback::Routed + } else { + DispatchFeedback::QueueFull + }; + let _ = + fairness.apply_dispatch_feedback(candidate.frame.conn_id, candidate, feedback, now); + } + + /* The bound must hold after every single step, not only at the end. */ let snapshot = fairness.snapshot(); + assert!(snapshot.total_queued_bytes <= 32 * 1024); + } +}