diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index f42638c..4a71e37 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -121,6 +121,8 @@ pub struct HotFields { pub user_max_tcp_conns_global_each: usize, pub user_expirations: std::collections::HashMap>, pub user_data_quota: std::collections::HashMap, + pub user_rate_limits: std::collections::HashMap, + pub cidr_rate_limits: std::collections::HashMap, pub user_max_unique_ips: std::collections::HashMap, pub user_max_unique_ips_global_each: usize, pub user_max_unique_ips_mode: crate::config::UserMaxUniqueIpsMode, @@ -245,6 +247,8 @@ impl HotFields { user_max_tcp_conns_global_each: cfg.access.user_max_tcp_conns_global_each, user_expirations: cfg.access.user_expirations.clone(), user_data_quota: cfg.access.user_data_quota.clone(), + user_rate_limits: cfg.access.user_rate_limits.clone(), + cidr_rate_limits: cfg.access.cidr_rate_limits.clone(), user_max_unique_ips: cfg.access.user_max_unique_ips.clone(), user_max_unique_ips_global_each: cfg.access.user_max_unique_ips_global_each, user_max_unique_ips_mode: cfg.access.user_max_unique_ips_mode, @@ -545,6 +549,8 @@ fn overlay_hot_fields(old: &ProxyConfig, new: &ProxyConfig) -> ProxyConfig { cfg.access.user_max_tcp_conns_global_each = new.access.user_max_tcp_conns_global_each; cfg.access.user_expirations = new.access.user_expirations.clone(); cfg.access.user_data_quota = new.access.user_data_quota.clone(); + cfg.access.user_rate_limits = new.access.user_rate_limits.clone(); + cfg.access.cidr_rate_limits = new.access.cidr_rate_limits.clone(); cfg.access.user_max_unique_ips = new.access.user_max_unique_ips.clone(); cfg.access.user_max_unique_ips_global_each = new.access.user_max_unique_ips_global_each; cfg.access.user_max_unique_ips_mode = new.access.user_max_unique_ips_mode; @@ -1183,6 +1189,18 @@ fn log_changes( new_hot.user_data_quota.len() ); } + if old_hot.user_rate_limits != new_hot.user_rate_limits { + info!( + "config reload: 
user_rate_limits updated ({} entries)", + new_hot.user_rate_limits.len() + ); + } + if old_hot.cidr_rate_limits != new_hot.cidr_rate_limits { + info!( + "config reload: cidr_rate_limits updated ({} entries)", + new_hot.cidr_rate_limits.len() + ); + } if old_hot.user_max_unique_ips != new_hot.user_max_unique_ips { info!( "config reload: user_max_unique_ips updated ({} entries)", diff --git a/src/config/load.rs b/src/config/load.rs index e5c8202..d15773c 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -861,6 +861,22 @@ impl ProxyConfig { )); } + for (user, limit) in &config.access.user_rate_limits { + if limit.up_bps == 0 && limit.down_bps == 0 { + return Err(ProxyError::Config(format!( + "access.user_rate_limits.{user} must set at least one non-zero direction" + ))); + } + } + + for (cidr, limit) in &config.access.cidr_rate_limits { + if limit.up_bps == 0 && limit.down_bps == 0 { + return Err(ProxyError::Config(format!( + "access.cidr_rate_limits.{cidr} must set at least one non-zero direction" + ))); + } + } + if config.general.me_reinit_every_secs == 0 { return Err(ProxyError::Config( "general.me_reinit_every_secs must be > 0".to_string(), diff --git a/src/config/types.rs b/src/config/types.rs index 35b8d46..9f7e0f4 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -1826,6 +1826,21 @@ pub struct AccessConfig { #[serde(default)] pub user_data_quota: HashMap, + /// Per-user transport rate limits in bits-per-second. + /// + /// Each entry supports independent upload (`up_bps`) and download + /// (`down_bps`) ceilings. A value of `0` in one direction means + /// "unlimited" for that direction. + #[serde(default)] + pub user_rate_limits: HashMap, + + /// Per-CIDR aggregate transport rate limits in bits-per-second. + /// + /// Matching uses longest-prefix-wins semantics. A value of `0` in one + /// direction means "unlimited" for that direction. 
+ #[serde(default)] + pub cidr_rate_limits: HashMap, + #[serde(default)] pub user_max_unique_ips: HashMap, @@ -1859,6 +1874,8 @@ impl Default for AccessConfig { user_max_tcp_conns_global_each: default_user_max_tcp_conns_global_each(), user_expirations: HashMap::new(), user_data_quota: HashMap::new(), + user_rate_limits: HashMap::new(), + cidr_rate_limits: HashMap::new(), user_max_unique_ips: HashMap::new(), user_max_unique_ips_global_each: default_user_max_unique_ips_global_each(), user_max_unique_ips_mode: UserMaxUniqueIpsMode::default(), @@ -1870,6 +1887,14 @@ impl Default for AccessConfig { } } +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct RateLimitBps { + #[serde(default)] + pub up_bps: u64, + #[serde(default)] + pub down_bps: u64, +} + // ============= Aux Structures ============= #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] diff --git a/src/maestro/mod.rs b/src/maestro/mod.rs index f141331..4c5b98a 100644 --- a/src/maestro/mod.rs +++ b/src/maestro/mod.rs @@ -664,6 +664,11 @@ async fn run_telemt_core( )); let buffer_pool = Arc::new(BufferPool::with_config(64 * 1024, 4096)); + let shared_state = ProxySharedState::new(); + shared_state.traffic_limiter.apply_policy( + config.access.user_rate_limits.clone(), + config.access.cidr_rate_limits.clone(), + ); connectivity::run_startup_connectivity( &config, @@ -695,6 +700,7 @@ async fn run_telemt_core( beobachten.clone(), api_config_tx.clone(), me_pool.clone(), + shared_state.clone(), ) .await; let config_rx = runtime_watches.config_rx; @@ -711,7 +717,6 @@ async fn run_telemt_core( ) .await; let _admission_tx_hold = admission_tx; - let shared_state = ProxySharedState::new(); conntrack_control::spawn_conntrack_controller( config_rx.clone(), stats.clone(), diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index 5b3f2e0..b48060c 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -51,6 +51,7 @@ pub(crate) async fn 
spawn_runtime_tasks( beobachten: Arc, api_config_tx: watch::Sender>, me_pool_for_policy: Option>, + shared_state: Arc, ) -> RuntimeWatches { let um_clone = upstream_manager.clone(); let dc_overrides_for_health = config.dc_overrides.clone(); @@ -182,6 +183,33 @@ pub(crate) async fn spawn_runtime_tasks( } }); + let limiter = shared_state.traffic_limiter.clone(); + limiter.apply_policy( + config.access.user_rate_limits.clone(), + config.access.cidr_rate_limits.clone(), + ); + let mut config_rx_rate_limits = config_rx.clone(); + tokio::spawn(async move { + let mut prev_user_limits = config_rx_rate_limits.borrow().access.user_rate_limits.clone(); + let mut prev_cidr_limits = config_rx_rate_limits.borrow().access.cidr_rate_limits.clone(); + loop { + if config_rx_rate_limits.changed().await.is_err() { + break; + } + let cfg = config_rx_rate_limits.borrow_and_update().clone(); + if prev_user_limits != cfg.access.user_rate_limits + || prev_cidr_limits != cfg.access.cidr_rate_limits + { + limiter.apply_policy( + cfg.access.user_rate_limits.clone(), + cfg.access.cidr_rate_limits.clone(), + ); + prev_user_limits = cfg.access.user_rate_limits.clone(); + prev_cidr_limits = cfg.access.cidr_rate_limits.clone(); + } + } + }); + let beobachten_writer = beobachten.clone(); let config_rx_beobachten = config_rx.clone(); tokio::spawn(async move { diff --git a/src/metrics.rs b/src/metrics.rs index 1b920a8..34c0cac 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -575,6 +575,139 @@ async fn render_metrics( } ); + let limiter_metrics = shared_state.traffic_limiter.metrics_snapshot(); + let _ = writeln!( + out, + "# HELP telemt_rate_limiter_throttle_total Traffic limiter throttle events by scope and direction" + ); + let _ = writeln!(out, "# TYPE telemt_rate_limiter_throttle_total counter"); + let _ = writeln!( + out, + "telemt_rate_limiter_throttle_total{{scope=\"user\",direction=\"up\"}} {}", + if core_enabled { + limiter_metrics.user_throttle_up_total + } else { + 0 + } + ); + let _ = 
writeln!( + out, + "telemt_rate_limiter_throttle_total{{scope=\"user\",direction=\"down\"}} {}", + if core_enabled { + limiter_metrics.user_throttle_down_total + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_rate_limiter_throttle_total{{scope=\"cidr\",direction=\"up\"}} {}", + if core_enabled { + limiter_metrics.cidr_throttle_up_total + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_rate_limiter_throttle_total{{scope=\"cidr\",direction=\"down\"}} {}", + if core_enabled { + limiter_metrics.cidr_throttle_down_total + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_rate_limiter_wait_ms_total Traffic limiter accumulated wait time in milliseconds by scope and direction" + ); + let _ = writeln!(out, "# TYPE telemt_rate_limiter_wait_ms_total counter"); + let _ = writeln!( + out, + "telemt_rate_limiter_wait_ms_total{{scope=\"user\",direction=\"up\"}} {}", + if core_enabled { + limiter_metrics.user_wait_up_ms_total + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_rate_limiter_wait_ms_total{{scope=\"user\",direction=\"down\"}} {}", + if core_enabled { + limiter_metrics.user_wait_down_ms_total + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_rate_limiter_wait_ms_total{{scope=\"cidr\",direction=\"up\"}} {}", + if core_enabled { + limiter_metrics.cidr_wait_up_ms_total + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_rate_limiter_wait_ms_total{{scope=\"cidr\",direction=\"down\"}} {}", + if core_enabled { + limiter_metrics.cidr_wait_down_ms_total + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_rate_limiter_active_leases Active relay leases under rate limiting by scope" + ); + let _ = writeln!(out, "# TYPE telemt_rate_limiter_active_leases gauge"); + let _ = writeln!( + out, + "telemt_rate_limiter_active_leases{{scope=\"user\"}} {}", + if core_enabled { + limiter_metrics.user_active_leases + } else { + 0 + } + ); + let _ = writeln!( + out, + 
"telemt_rate_limiter_active_leases{{scope=\"cidr\"}} {}", + if core_enabled { + limiter_metrics.cidr_active_leases + } else { + 0 + } + ); + + let _ = writeln!( + out, + "# HELP telemt_rate_limiter_policy_entries Active rate-limit policy entries by scope" + ); + let _ = writeln!(out, "# TYPE telemt_rate_limiter_policy_entries gauge"); + let _ = writeln!( + out, + "telemt_rate_limiter_policy_entries{{scope=\"user\"}} {}", + if core_enabled { + limiter_metrics.user_policy_entries + } else { + 0 + } + ); + let _ = writeln!( + out, + "telemt_rate_limiter_policy_entries{{scope=\"cidr\"}} {}", + if core_enabled { + limiter_metrics.cidr_policy_entries + } else { + 0 + } + ); + let _ = writeln!( out, "# HELP telemt_upstream_connect_attempt_total Upstream connect attempts across all requests" diff --git a/src/proxy/direct_relay.rs b/src/proxy/direct_relay.rs index 2c4fe45..8b6189b 100644 --- a/src/proxy/direct_relay.rs +++ b/src/proxy/direct_relay.rs @@ -316,6 +316,7 @@ where stats.increment_user_connects(user); let _direct_connection_lease = stats.acquire_direct_connection_lease(); + let traffic_lease = shared.traffic_limiter.acquire_lease(user, success.peer.ip()); let buffer_pool_trim = Arc::clone(&buffer_pool); let relay_activity_timeout = if shared.conntrack_pressure_active() { @@ -329,7 +330,7 @@ where } else { Duration::from_secs(1800) }; - let relay_result = crate::proxy::relay::relay_bidirectional_with_activity_timeout( + let relay_result = crate::proxy::relay::relay_bidirectional_with_activity_timeout_and_lease( client_reader, client_writer, tg_reader, @@ -340,6 +341,7 @@ where Arc::clone(&stats), config.access.user_data_quota.get(user).copied(), buffer_pool, + traffic_lease, relay_activity_timeout, ); tokio::pin!(relay_result); diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index eb68f83..4b42725 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -28,6 +28,7 @@ use crate::proxy::route_mode::{ use 
crate::proxy::shared_state::{ ConntrackCloseEvent, ConntrackClosePublishResult, ConntrackCloseReason, ProxySharedState, }; +use crate::proxy::traffic_limiter::{RateDirection, TrafficLease, next_refill_delay}; use crate::stats::{ MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, QuotaReserveError, Stats, UserStats, }; @@ -595,6 +596,41 @@ async fn reserve_user_quota_with_yield( } } +async fn wait_for_traffic_budget( + lease: Option<&Arc>, + direction: RateDirection, + bytes: u64, +) { + if bytes == 0 { + return; + } + let Some(lease) = lease else { + return; + }; + + let mut remaining = bytes; + while remaining > 0 { + let consume = lease.try_consume(direction, remaining); + if consume.granted > 0 { + remaining = remaining.saturating_sub(consume.granted); + continue; + } + + let wait_started_at = Instant::now(); + tokio::time::sleep(next_refill_delay()).await; + let wait_ms = wait_started_at + .elapsed() + .as_millis() + .min(u128::from(u64::MAX)) as u64; + lease.observe_wait_ms( + direction, + consume.blocked_user, + consume.blocked_cidr, + wait_ms, + ); + } +} + fn classify_me_d2c_flush_reason( flush_immediately: bool, batch_frames: usize, @@ -985,6 +1021,7 @@ where let quota_limit = config.access.user_data_quota.get(&user).copied(); let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user)); let peer = success.peer; + let traffic_lease = shared.traffic_limiter.acquire_lease(&user, peer.ip()); let proto_tag = success.proto_tag; let pool_generation = me_pool.current_generation(); @@ -1120,6 +1157,7 @@ where let rng_clone = rng.clone(); let user_clone = user.clone(); let quota_user_stats_me_writer = quota_user_stats.clone(); + let traffic_lease_me_writer = traffic_lease.clone(); let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone(); let bytes_me2c_clone = bytes_me2c.clone(); let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config); @@ -1153,7 +1191,7 @@ where let first_is_downstream_activity = 
matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response( + match process_me_writer_response_with_traffic_lease( first, &mut writer, proto_tag, @@ -1164,6 +1202,7 @@ where quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, + traffic_lease_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1213,7 +1252,7 @@ where let next_is_downstream_activity = matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response( + match process_me_writer_response_with_traffic_lease( next, &mut writer, proto_tag, @@ -1224,6 +1263,7 @@ where quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, + traffic_lease_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1276,7 +1316,7 @@ where Ok(Some(next)) => { let next_is_downstream_activity = matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response( + match process_me_writer_response_with_traffic_lease( next, &mut writer, proto_tag, @@ -1287,6 +1327,7 @@ where quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, + traffic_lease_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1341,7 +1382,7 @@ where let extra_is_downstream_activity = matches!(&extra, MeResponse::Data { .. 
} | MeResponse::Ack(_)); - match process_me_writer_response( + match process_me_writer_response_with_traffic_lease( extra, &mut writer, proto_tag, @@ -1352,6 +1393,7 @@ where quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, + traffic_lease_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1542,6 +1584,12 @@ where match payload_result { Ok(Some((payload, quickack))) => { trace!(conn_id, bytes = payload.len(), "C->ME frame"); + wait_for_traffic_budget( + traffic_lease.as_ref(), + RateDirection::Up, + payload.len() as u64, + ) + .await; forensics.bytes_c2me = forensics .bytes_c2me .saturating_add(payload.len() as u64); @@ -2160,6 +2208,46 @@ async fn process_me_writer_response( ack_flush_immediate: bool, batched: bool, ) -> Result +where + W: AsyncWrite + Unpin + Send + 'static, +{ + process_me_writer_response_with_traffic_lease( + response, + client_writer, + proto_tag, + rng, + frame_buf, + stats, + user, + quota_user_stats, + quota_limit, + quota_soft_overshoot_bytes, + None, + bytes_me2c, + conn_id, + ack_flush_immediate, + batched, + ) + .await +} + +async fn process_me_writer_response_with_traffic_lease( + response: MeResponse, + client_writer: &mut CryptoWriter, + proto_tag: ProtoTag, + rng: &SecureRandom, + frame_buf: &mut Vec, + stats: &Stats, + user: &str, + quota_user_stats: Option<&UserStats>, + quota_limit: Option, + quota_soft_overshoot_bytes: u64, + traffic_lease: Option<&Arc>, + bytes_me2c: &AtomicU64, + conn_id: u64, + ack_flush_immediate: bool, + batched: bool, +) -> Result where W: AsyncWrite + Unpin + Send + 'static, { @@ -2183,6 +2271,7 @@ where }); } } + wait_for_traffic_budget(traffic_lease, RateDirection::Down, data_len).await; let write_mode = match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) @@ -2220,6 +2309,7 @@ where } else { trace!(conn_id, confirm, "ME->C quickack"); } + wait_for_traffic_budget(traffic_lease, 
RateDirection::Down, 4).await; write_client_ack(client_writer, proto_tag, confirm).await?; stats.increment_me_d2c_ack_frames_total(); diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index c4ce09c..4e1827e 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -68,6 +68,7 @@ pub mod relay; pub mod route_mode; pub mod session_eviction; pub mod shared_state; +pub mod traffic_limiter; pub use client::ClientHandler; #[allow(unused_imports)] diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index f612cb1..a59ca5b 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -52,6 +52,7 @@ //! - `SharedCounters` (atomics) let the watchdog read stats without locking use crate::error::{ProxyError, Result}; +use crate::proxy::traffic_limiter::{RateDirection, TrafficLease, next_refill_delay}; use crate::stats::{Stats, UserStats}; use crate::stream::BufferPool; use std::io; @@ -61,7 +62,7 @@ use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::task::{Context, Poll}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes}; -use tokio::time::Instant; +use tokio::time::{Instant, Sleep}; use tracing::{debug, trace, warn}; // ============= Constants ============= @@ -210,12 +211,24 @@ struct StatsIo { stats: Arc, user: String, user_stats: Arc, + traffic_lease: Option>, + c2s_rate_debt_bytes: u64, + c2s_wait: RateWaitState, + s2c_wait: RateWaitState, quota_limit: Option, quota_exceeded: Arc, quota_bytes_since_check: u64, epoch: Instant, } +#[derive(Default)] +struct RateWaitState { + sleep: Option>>, + started_at: Option, + blocked_user: bool, + blocked_cidr: bool, +} + impl StatsIo { fn new( inner: S, @@ -225,6 +238,28 @@ impl StatsIo { quota_limit: Option, quota_exceeded: Arc, epoch: Instant, + ) -> Self { + Self::new_with_traffic_lease( + inner, + counters, + stats, + user, + None, + quota_limit, + quota_exceeded, + epoch, + ) + } + + fn new_with_traffic_lease( + inner: S, + counters: Arc, + stats: 
Arc, + user: String, + traffic_lease: Option>, + quota_limit: Option, + quota_exceeded: Arc, + epoch: Instant, ) -> Self { // Mark initial activity so the watchdog doesn't fire before data flows counters.touch(Instant::now(), epoch); @@ -235,12 +270,97 @@ impl StatsIo { stats, user, user_stats, + traffic_lease, + c2s_rate_debt_bytes: 0, + c2s_wait: RateWaitState::default(), + s2c_wait: RateWaitState::default(), quota_limit, quota_exceeded, quota_bytes_since_check: 0, epoch, } } + + fn record_wait( + wait: &mut RateWaitState, + lease: Option<&Arc>, + direction: RateDirection, + ) { + let Some(started_at) = wait.started_at.take() else { + return; + }; + let wait_ms = started_at + .elapsed() + .as_millis() + .min(u128::from(u64::MAX)) as u64; + if let Some(lease) = lease { + lease.observe_wait_ms( + direction, + wait.blocked_user, + wait.blocked_cidr, + wait_ms, + ); + } + wait.blocked_user = false; + wait.blocked_cidr = false; + } + + fn arm_wait(wait: &mut RateWaitState, blocked_user: bool, blocked_cidr: bool) { + if wait.sleep.is_none() { + wait.sleep = Some(Box::pin(tokio::time::sleep(next_refill_delay()))); + wait.started_at = Some(Instant::now()); + } + wait.blocked_user |= blocked_user; + wait.blocked_cidr |= blocked_cidr; + } + + fn poll_wait( + wait: &mut RateWaitState, + cx: &mut Context<'_>, + lease: Option<&Arc>, + direction: RateDirection, + ) -> Poll<()> { + let Some(sleep) = wait.sleep.as_mut() else { + return Poll::Ready(()); + }; + if sleep.as_mut().poll(cx).is_pending() { + return Poll::Pending; + } + wait.sleep = None; + Self::record_wait(wait, lease, direction); + Poll::Ready(()) + } + + fn settle_c2s_rate_debt(&mut self, cx: &mut Context<'_>) -> Poll<()> { + let Some(lease) = self.traffic_lease.as_ref() else { + self.c2s_rate_debt_bytes = 0; + return Poll::Ready(()); + }; + + while self.c2s_rate_debt_bytes > 0 { + let consume = lease.try_consume(RateDirection::Up, self.c2s_rate_debt_bytes); + if consume.granted > 0 { + self.c2s_rate_debt_bytes = + 
self.c2s_rate_debt_bytes.saturating_sub(consume.granted); + continue; + } + Self::arm_wait( + &mut self.c2s_wait, + consume.blocked_user, + consume.blocked_cidr, + ); + if Self::poll_wait(&mut self.c2s_wait, cx, Some(lease), RateDirection::Up).is_pending() + { + return Poll::Pending; + } + } + + if Self::poll_wait(&mut self.c2s_wait, cx, Some(lease), RateDirection::Up).is_pending() { + return Poll::Pending; + } + + Poll::Ready(()) + } } #[derive(Debug)] @@ -286,6 +406,25 @@ fn should_immediate_quota_check(remaining_before: u64, charge_bytes: u64) -> boo remaining_before <= QUOTA_NEAR_LIMIT_BYTES || charge_bytes >= QUOTA_LARGE_CHARGE_BYTES } +fn refund_reserved_quota_bytes(user_stats: &UserStats, reserved_bytes: u64) { + if reserved_bytes == 0 { + return; + } + let mut current = user_stats.quota_used.load(Ordering::Relaxed); + loop { + let next = current.saturating_sub(reserved_bytes); + match user_stats.quota_used.compare_exchange_weak( + current, + next, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => return, + Err(observed) => current = observed, + } + } +} + impl AsyncRead for StatsIo { fn poll_read( self: Pin<&mut Self>, @@ -296,6 +435,9 @@ impl AsyncRead for StatsIo { if this.quota_exceeded.load(Ordering::Acquire) { return Poll::Ready(Err(quota_io_error())); } + if this.settle_c2s_rate_debt(cx).is_pending() { + return Poll::Pending; + } let mut remaining_before = None; if let Some(limit) = this.quota_limit { @@ -377,6 +519,11 @@ impl AsyncRead for StatsIo { .add_user_octets_from_handle(this.user_stats.as_ref(), n_to_charge); this.stats .increment_user_msgs_from_handle(this.user_stats.as_ref()); + if this.traffic_lease.is_some() { + this.c2s_rate_debt_bytes = + this.c2s_rate_debt_bytes.saturating_add(n_to_charge); + let _ = this.settle_c2s_rate_debt(cx); + } trace!(user = %this.user, bytes = n, "C->S"); } @@ -398,28 +545,66 @@ impl AsyncWrite for StatsIo { return Poll::Ready(Err(quota_io_error())); } + let mut shaper_reserved_bytes = 0u64; + let mut 
write_buf = buf; + if let Some(lease) = this.traffic_lease.as_ref() { + if !buf.is_empty() { + loop { + let consume = lease.try_consume(RateDirection::Down, buf.len() as u64); + if consume.granted > 0 { + shaper_reserved_bytes = consume.granted; + if consume.granted < buf.len() as u64 { + write_buf = &buf[..consume.granted as usize]; + } + let _ = Self::poll_wait( + &mut this.s2c_wait, + cx, + Some(lease), + RateDirection::Down, + ); + break; + } + + Self::arm_wait( + &mut this.s2c_wait, + consume.blocked_user, + consume.blocked_cidr, + ); + if Self::poll_wait(&mut this.s2c_wait, cx, Some(lease), RateDirection::Down) + .is_pending() + { + return Poll::Pending; + } + } + } else { + let _ = Self::poll_wait(&mut this.s2c_wait, cx, Some(lease), RateDirection::Down); + } + } + let mut remaining_before = None; let mut reserved_bytes = 0u64; - let mut write_buf = buf; if let Some(limit) = this.quota_limit { - if !buf.is_empty() { + if !write_buf.is_empty() { let mut reserve_rounds = 0usize; while reserved_bytes == 0 { let used_before = this.user_stats.quota_used(); let remaining = limit.saturating_sub(used_before); if remaining == 0 { + if let Some(lease) = this.traffic_lease.as_ref() { + lease.refund(RateDirection::Down, shaper_reserved_bytes); + } this.quota_exceeded.store(true, Ordering::Release); return Poll::Ready(Err(quota_io_error())); } remaining_before = Some(remaining); - let desired = remaining.min(buf.len() as u64); + let desired = remaining.min(write_buf.len() as u64); let mut saw_contention = false; for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { match this.user_stats.quota_try_reserve(desired, limit) { Ok(_) => { reserved_bytes = desired; - write_buf = &buf[..desired as usize]; + write_buf = &write_buf[..desired as usize]; break; } Err(crate::stats::QuotaReserveError::LimitExceeded) => { @@ -434,6 +619,9 @@ impl AsyncWrite for StatsIo { if reserved_bytes == 0 { reserve_rounds = reserve_rounds.saturating_add(1); if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS { + if 
let Some(lease) = this.traffic_lease.as_ref() { + lease.refund(RateDirection::Down, shaper_reserved_bytes); + } this.quota_exceeded.store(true, Ordering::Release); return Poll::Ready(Err(quota_io_error())); } @@ -446,6 +634,9 @@ impl AsyncWrite for StatsIo { let used_before = this.user_stats.quota_used(); let remaining = limit.saturating_sub(used_before); if remaining == 0 { + if let Some(lease) = this.traffic_lease.as_ref() { + lease.refund(RateDirection::Down, shaper_reserved_bytes); + } this.quota_exceeded.store(true, Ordering::Release); return Poll::Ready(Err(quota_io_error())); } @@ -456,23 +647,17 @@ impl AsyncWrite for StatsIo { match Pin::new(&mut this.inner).poll_write(cx, write_buf) { Poll::Ready(Ok(n)) => { if reserved_bytes > n as u64 { - let refund = reserved_bytes - n as u64; - let mut current = this.user_stats.quota_used.load(Ordering::Relaxed); - loop { - let next = current.saturating_sub(refund); - match this.user_stats.quota_used.compare_exchange_weak( - current, - next, - Ordering::Relaxed, - Ordering::Relaxed, - ) { - Ok(_) => break, - Err(observed) => current = observed, - } - } + refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes - n as u64); + } + if shaper_reserved_bytes > n as u64 + && let Some(lease) = this.traffic_lease.as_ref() + { + lease.refund(RateDirection::Down, shaper_reserved_bytes - n as u64); } - if n > 0 { + if let Some(lease) = this.traffic_lease.as_ref() { + Self::record_wait(&mut this.s2c_wait, Some(lease), RateDirection::Down); + } let n_to_charge = n as u64; // S→C: data written to client @@ -512,37 +697,23 @@ impl AsyncWrite for StatsIo { } Poll::Ready(Err(err)) => { if reserved_bytes > 0 { - let mut current = this.user_stats.quota_used.load(Ordering::Relaxed); - loop { - let next = current.saturating_sub(reserved_bytes); - match this.user_stats.quota_used.compare_exchange_weak( - current, - next, - Ordering::Relaxed, - Ordering::Relaxed, - ) { - Ok(_) => break, - Err(observed) => current = observed, - } 
- } + refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes); + } + if shaper_reserved_bytes > 0 + && let Some(lease) = this.traffic_lease.as_ref() + { + lease.refund(RateDirection::Down, shaper_reserved_bytes); } Poll::Ready(Err(err)) } Poll::Pending => { if reserved_bytes > 0 { - let mut current = this.user_stats.quota_used.load(Ordering::Relaxed); - loop { - let next = current.saturating_sub(reserved_bytes); - match this.user_stats.quota_used.compare_exchange_weak( - current, - next, - Ordering::Relaxed, - Ordering::Relaxed, - ) { - Ok(_) => break, - Err(observed) => current = observed, - } - } + refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes); + } + if shaper_reserved_bytes > 0 + && let Some(lease) = this.traffic_lease.as_ref() + { + lease.refund(RateDirection::Down, shaper_reserved_bytes); } Poll::Pending } @@ -627,6 +798,43 @@ pub async fn relay_bidirectional_with_activity_timeout( _buffer_pool: Arc, activity_timeout: Duration, ) -> Result<()> +where + CR: AsyncRead + Unpin + Send + 'static, + CW: AsyncWrite + Unpin + Send + 'static, + SR: AsyncRead + Unpin + Send + 'static, + SW: AsyncWrite + Unpin + Send + 'static, +{ + relay_bidirectional_with_activity_timeout_and_lease( + client_reader, + client_writer, + server_reader, + server_writer, + c2s_buf_size, + s2c_buf_size, + user, + stats, + quota_limit, + _buffer_pool, + None, + activity_timeout, + ) + .await +} + +pub async fn relay_bidirectional_with_activity_timeout_and_lease( + client_reader: CR, + client_writer: CW, + server_reader: SR, + server_writer: SW, + c2s_buf_size: usize, + s2c_buf_size: usize, + user: &str, + stats: Arc, + quota_limit: Option, + _buffer_pool: Arc, + traffic_lease: Option>, + activity_timeout: Duration, +) -> Result<()> where CR: AsyncRead + Unpin + Send + 'static, CW: AsyncWrite + Unpin + Send + 'static, @@ -644,11 +852,12 @@ where let mut server = CombinedStream::new(server_reader, server_writer); // Wrap client with stats/activity tracking 
-            let mut client = StatsIo::new(
+            let mut client = StatsIo::new_with_traffic_lease(
                 client_combined,
                 Arc::clone(&counters),
                 Arc::clone(&stats),
                 user_owned.clone(),
+                traffic_lease,
                 quota_limit,
                 Arc::clone(&quota_exceeded),
                 epoch,
diff --git a/src/proxy/shared_state.rs b/src/proxy/shared_state.rs
index 4fef497..e204890 100644
--- a/src/proxy/shared_state.rs
+++ b/src/proxy/shared_state.rs
@@ -10,6 +10,7 @@ use tokio::sync::mpsc;
 
 use crate::proxy::handshake::{AuthProbeSaturationState, AuthProbeState};
 use crate::proxy::middle_relay::{DesyncDedupRotationState, RelayIdleCandidateRegistry};
+use crate::proxy::traffic_limiter::TrafficLimiter;
 
 const HANDSHAKE_RECENT_USER_RING_LEN: usize = 64;
 
@@ -65,6 +66,7 @@ pub(crate) struct MiddleRelaySharedState {
 pub(crate) struct ProxySharedState {
     pub(crate) handshake: HandshakeSharedState,
     pub(crate) middle_relay: MiddleRelaySharedState,
+    pub(crate) traffic_limiter: Arc<TrafficLimiter>,
     pub(crate) conntrack_pressure_active: AtomicBool,
     pub(crate) conntrack_close_tx: Mutex>>,
 }
@@ -98,6 +100,7 @@ impl ProxySharedState {
                 relay_idle_registry: Mutex::new(RelayIdleCandidateRegistry::default()),
                 relay_idle_mark_seq: AtomicU64::new(0),
             },
+            traffic_limiter: TrafficLimiter::new(),
             conntrack_pressure_active: AtomicBool::new(false),
             conntrack_close_tx: Mutex::new(None),
         })
diff --git a/src/proxy/traffic_limiter.rs b/src/proxy/traffic_limiter.rs
new file mode 100644
index 0000000..5c93944
--- /dev/null
+++ b/src/proxy/traffic_limiter.rs
@@ -0,0 +1,847 @@
+//! Per-user and per-CIDR transport rate limiting.
+//!
+//! Time is sliced into fixed `FAIR_EPOCH_MS` epochs; each bucket tracks the
+//! bytes consumed in the current epoch against a cap derived from the
+//! configured bits-per-second limit. A cap of `0` means "unlimited".
+
+use std::collections::{HashMap, HashSet};
+use std::hash::{Hash, Hasher};
+use std::net::IpAddr;
+use std::sync::Arc;
+use std::sync::OnceLock;
+use std::sync::atomic::{AtomicU64, Ordering};
+use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
+
+use arc_swap::ArcSwap;
+use dashmap::DashMap;
+use ipnetwork::IpNetwork;
+
+use crate::config::RateLimitBps;
+
+const REGISTRY_SHARDS: usize = 64;
+/// Length of one fairness epoch; bucket budgets reset every epoch.
+const FAIR_EPOCH_MS: u64 = 20;
+/// Cap on how much a user may "borrow" past its fair share per grant.
+const MAX_BORROW_CHUNK_BYTES: u64 = 32 * 1024;
+const CLEANUP_INTERVAL_SECS: u64 = 60;
+
+/// Direction of traffic being accounted (client upload vs download).
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum RateDirection {
+    Up,
+    Down,
+}
+
+/// Outcome of a consume attempt: how many bytes were granted and, when
+/// `granted == 0`, which scope blocked the request.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct TrafficConsumeResult {
+    pub granted: u64,
+    pub blocked_user: bool,
+    pub blocked_cidr: bool,
+}
+
+/// Point-in-time snapshot of limiter counters for metrics export.
+#[derive(Debug, Clone, Copy)]
+pub struct TrafficLimiterMetricsSnapshot {
+    pub user_throttle_up_total: u64,
+    pub user_throttle_down_total: u64,
+    pub cidr_throttle_up_total: u64,
+    pub cidr_throttle_down_total: u64,
+    pub user_wait_up_ms_total: u64,
+    pub user_wait_down_ms_total: u64,
+    pub cidr_wait_up_ms_total: u64,
+    pub cidr_wait_down_ms_total: u64,
+    pub user_active_leases: u64,
+    pub cidr_active_leases: u64,
+    pub user_policy_entries: u64,
+    pub cidr_policy_entries: u64,
+}
+
+/// Counters for one scope (user or CIDR).
+#[derive(Default)]
+struct ScopeMetrics {
+    throttle_up_total: AtomicU64,
+    throttle_down_total: AtomicU64,
+    wait_up_ms_total: AtomicU64,
+    wait_down_ms_total: AtomicU64,
+    active_leases: AtomicU64,
+    policy_entries: AtomicU64,
+}
+
+impl ScopeMetrics {
+    fn throttle(&self, direction: RateDirection) {
+        match direction {
+            RateDirection::Up => {
+                self.throttle_up_total.fetch_add(1, Ordering::Relaxed);
+            }
+            RateDirection::Down => {
+                self.throttle_down_total.fetch_add(1, Ordering::Relaxed);
+            }
+        }
+    }
+
+    fn wait_ms(&self, direction: RateDirection, wait_ms: u64) {
+        match direction {
+            RateDirection::Up => {
+                self.wait_up_ms_total.fetch_add(wait_ms, Ordering::Relaxed);
+            }
+            RateDirection::Down => {
+                self.wait_down_ms_total.fetch_add(wait_ms, Ordering::Relaxed);
+            }
+        }
+    }
+}
+
+/// Up/down bps caps, updatable in place on config reload.
+#[derive(Default)]
+struct AtomicRatePair {
+    up_bps: AtomicU64,
+    down_bps: AtomicU64,
+}
+
+impl AtomicRatePair {
+    fn set(&self, limits: RateLimitBps) {
+        self.up_bps.store(limits.up_bps, Ordering::Relaxed);
+        self.down_bps.store(limits.down_bps, Ordering::Relaxed);
+    }
+
+    fn get(&self, direction: RateDirection) -> u64 {
+        match direction {
+            RateDirection::Up => self.up_bps.load(Ordering::Relaxed),
+            RateDirection::Down => self.down_bps.load(Ordering::Relaxed),
+        }
+    }
+}
+
+/// One direction of a per-user token bucket: bytes used in the current epoch.
+#[derive(Default)]
+struct DirectionBucket {
+    epoch: AtomicU64,
+    used: AtomicU64,
+}
+
+impl DirectionBucket {
+    /// Roll the bucket forward to `epoch`; the winning CAS also resets `used`.
+    fn sync_epoch(&self, epoch: u64) {
+        let current = self.epoch.load(Ordering::Relaxed);
+        if current == epoch {
+            return;
+        }
+        if current < epoch
+            && self
+                .epoch
+                .compare_exchange(current, epoch, Ordering::Relaxed, Ordering::Relaxed)
+                .is_ok()
+        {
+            self.used.store(0, Ordering::Relaxed);
+        }
+    }
+
+    /// Grant up to `requested` bytes against the epoch budget for `cap_bps`.
+    /// `cap_bps == 0` means unlimited; returns 0 when the budget is exhausted.
+    fn try_consume(&self, cap_bps: u64, requested: u64) -> u64 {
+        if requested == 0 {
+            return 0;
+        }
+        if cap_bps == 0 {
+            return requested;
+        }
+
+        let epoch = current_epoch();
+        self.sync_epoch(epoch);
+        let cap_epoch = bytes_per_epoch(cap_bps);
+
+        loop {
+            let used = self.used.load(Ordering::Relaxed);
+            if used >= cap_epoch {
+                return 0;
+            }
+            let remaining = cap_epoch.saturating_sub(used);
+            let grant = requested.min(remaining);
+            if grant == 0 {
+                return 0;
+            }
+            let next = used.saturating_add(grant);
+            if self
+                .used
+                .compare_exchange_weak(used, next, Ordering::Relaxed, Ordering::Relaxed)
+                .is_ok()
+            {
+                return grant;
+            }
+        }
+    }
+
+    fn refund(&self, bytes: u64) {
+        if bytes == 0 {
+            return;
+        }
+        decrement_atomic_saturating(&self.used, bytes);
+    }
+}
+
+/// Per-user limiter state: live rate caps plus one bucket per direction.
+struct UserBucket {
+    rates: AtomicRatePair,
+    up: DirectionBucket,
+    down: DirectionBucket,
+    active_leases: AtomicU64,
+}
+
+impl UserBucket {
+    fn new(limits: RateLimitBps) -> Self {
+        let rates = AtomicRatePair::default();
+        rates.set(limits);
+        Self {
+            rates,
+            up: DirectionBucket::default(),
+            down: DirectionBucket::default(),
+            active_leases: AtomicU64::new(0),
+        }
+    }
+
+    fn set_rates(&self, limits: RateLimitBps) {
+        self.rates.set(limits);
+    }
+
+    fn try_consume(&self, direction: RateDirection, requested: u64) -> u64 {
+        let cap_bps = self.rates.get(direction);
+        match direction {
+            RateDirection::Up => self.up.try_consume(cap_bps, requested),
+            RateDirection::Down => self.down.try_consume(cap_bps, requested),
+        }
+    }
+
+    fn refund(&self, direction: RateDirection, bytes: u64) {
+        match direction {
+            RateDirection::Up => self.up.refund(bytes),
+            RateDirection::Down => self.down.refund(bytes),
+        }
+    }
+}
+
+/// One direction of a CIDR aggregate bucket; also counts users active in the
+/// current epoch so the cap can be split fairly between them.
+#[derive(Default)]
+struct CidrDirectionBucket {
+    epoch: AtomicU64,
+    used: AtomicU64,
+    active_users: AtomicU64,
+}
+
+impl CidrDirectionBucket {
+    fn sync_epoch(&self, epoch: u64) {
+        let current = self.epoch.load(Ordering::Relaxed);
+        if current == epoch {
+            return;
+        }
+        if current < epoch
+            && self
+                .epoch
+                .compare_exchange(current, epoch, Ordering::Relaxed, Ordering::Relaxed)
+                .is_ok()
+        {
+            self.used.store(0, Ordering::Relaxed);
+            self.active_users.store(0, Ordering::Relaxed);
+        }
+    }
+
+    /// Grant bytes for one user against the aggregate cap. Each user first
+    /// draws from its fair share (`cap / active_users`); once that is spent it
+    /// may borrow leftover capacity in chunks of `MAX_BORROW_CHUNK_BYTES`.
+    fn try_consume(
+        &self,
+        user_state: &CidrUserDirectionState,
+        cap_epoch: u64,
+        requested: u64,
+    ) -> u64 {
+        if requested == 0 || cap_epoch == 0 {
+            return 0;
+        }
+
+        let epoch = current_epoch();
+        self.sync_epoch(epoch);
+        user_state.sync_epoch_and_mark_active(epoch, &self.active_users);
+        let active_users = self.active_users.load(Ordering::Relaxed).max(1);
+        let fair_share = cap_epoch.saturating_div(active_users).max(1);
+
+        loop {
+            let total_used = self.used.load(Ordering::Relaxed);
+            if total_used >= cap_epoch {
+                return 0;
+            }
+            let total_remaining = cap_epoch.saturating_sub(total_used);
+            let user_used = user_state.used.load(Ordering::Relaxed);
+            let guaranteed_remaining = fair_share.saturating_sub(user_used);
+
+            let grant = if guaranteed_remaining > 0 {
+                requested.min(guaranteed_remaining).min(total_remaining)
+            } else {
+                requested
+                    .min(total_remaining)
+                    .min(MAX_BORROW_CHUNK_BYTES)
+            };
+
+            if grant == 0 {
+                return 0;
+            }
+
+            let next_total = total_used.saturating_add(grant);
+            if self
+                .used
+                .compare_exchange_weak(
+                    total_used,
+                    next_total,
+                    Ordering::Relaxed,
+                    Ordering::Relaxed,
+                )
+                .is_ok()
+            {
+                user_state.used.fetch_add(grant, Ordering::Relaxed);
+                return grant;
+            }
+        }
+    }
+
+    fn refund(&self, bytes: u64) {
+        if bytes == 0 {
+            return;
+        }
+        decrement_atomic_saturating(&self.used, bytes);
+    }
+}
+
+/// One user's per-epoch usage within a CIDR bucket, per direction.
+#[derive(Default)]
+struct CidrUserDirectionState {
+    epoch: AtomicU64,
+    used: AtomicU64,
+}
+
+impl CidrUserDirectionState {
+    /// Roll forward to `epoch` and, on the transition, register this user as
+    /// active for the epoch (feeds the fair-share divisor).
+    fn sync_epoch_and_mark_active(&self, epoch: u64, active_users: &AtomicU64) {
+        let current = self.epoch.load(Ordering::Relaxed);
+        if current == epoch {
+            return;
+        }
+        if current < epoch
+            && self
+                .epoch
+                .compare_exchange(current, epoch, Ordering::Relaxed, Ordering::Relaxed)
+                .is_ok()
+        {
+            self.used.store(0, Ordering::Relaxed);
+            active_users.fetch_add(1, Ordering::Relaxed);
+        }
+    }
+
+    fn refund(&self, bytes: u64) {
+        if bytes == 0 {
+            return;
+        }
+        decrement_atomic_saturating(&self.used, bytes);
+    }
+}
+
+/// A user's membership record inside a CIDR bucket; refcounted by connection.
+struct CidrUserShare {
+    active_conns: AtomicU64,
+    up: CidrUserDirectionState,
+    down: CidrUserDirectionState,
+}
+
+impl CidrUserShare {
+    fn new() -> Self {
+        Self {
+            active_conns: AtomicU64::new(0),
+            up: CidrUserDirectionState::default(),
+            down: CidrUserDirectionState::default(),
+        }
+    }
+}
+
+/// Aggregate limiter state for one CIDR rule.
+struct CidrBucket {
+    rates: AtomicRatePair,
+    up: CidrDirectionBucket,
+    down: CidrDirectionBucket,
+    users: ShardedRegistry<CidrUserShare>,
+    active_leases: AtomicU64,
+}
+
+impl CidrBucket {
+    fn new(limits: RateLimitBps) -> Self {
+        let rates = AtomicRatePair::default();
+        rates.set(limits);
+        Self {
+            rates,
+            up: CidrDirectionBucket::default(),
+            down: CidrDirectionBucket::default(),
+            users: ShardedRegistry::new(REGISTRY_SHARDS),
+            active_leases: AtomicU64::new(0),
+        }
+    }
+
+    fn set_rates(&self, limits: RateLimitBps) {
+        self.rates.set(limits);
+    }
+
+    fn acquire_user_share(&self, user: &str) -> Arc<CidrUserShare> {
+        let share = self.users.get_or_insert_with(user, CidrUserShare::new);
+        share.active_conns.fetch_add(1, Ordering::Relaxed);
+        share
+    }
+
+    /// Drop one connection ref; remove the entry only if it is still the same
+    /// share object and no connections remain (guards against ABA races).
+    fn release_user_share(&self, user: &str, share: &Arc<CidrUserShare>) {
+        decrement_atomic_saturating(&share.active_conns, 1);
+        let share_for_remove = Arc::clone(share);
+        let _ = self.users.remove_if(user, |candidate| {
+            Arc::ptr_eq(candidate, &share_for_remove)
+                && candidate.active_conns.load(Ordering::Relaxed) == 0
+        });
+    }
+
+    fn try_consume_for_user(
+        &self,
+        direction: RateDirection,
+        share: &CidrUserShare,
+        requested: u64,
+    ) -> u64 {
+        let cap_bps = self.rates.get(direction);
+        if cap_bps == 0 {
+            return requested;
+        }
+        let cap_epoch = bytes_per_epoch(cap_bps);
+        match direction {
+            RateDirection::Up => self.up.try_consume(&share.up, cap_epoch, requested),
+            RateDirection::Down => self.down.try_consume(&share.down, cap_epoch, requested),
+        }
+    }
+
+    fn refund_for_user(&self, direction: RateDirection, share: &CidrUserShare, bytes: u64) {
+        match direction {
+            RateDirection::Up => {
+                self.up.refund(bytes);
+                share.up.refund(bytes);
+            }
+            RateDirection::Down => {
+                self.down.refund(bytes);
+                share.down.refund(bytes);
+            }
+        }
+    }
+
+    fn cleanup_idle_users(&self) {
+        self.users
+            .retain(|_, share| share.active_conns.load(Ordering::Relaxed) > 0);
+    }
+}
+
+/// A compiled CIDR limit rule from config.
+#[derive(Clone)]
+struct CidrRule {
+    key: String,
+    cidr: IpNetwork,
+    limits: RateLimitBps,
+    prefix_len: u8,
+}
+
+/// Immutable view of the current limit policy, swapped atomically on reload.
+#[derive(Default)]
+struct PolicySnapshot {
+    user_limits: HashMap<String, RateLimitBps>,
+    cidr_rules_v4: Vec<CidrRule>,
+    cidr_rules_v6: Vec<CidrRule>,
+    cidr_rule_keys: HashSet<String>,
+}
+
+impl PolicySnapshot {
+    /// Longest-prefix match: rule vectors are sorted by descending prefix
+    /// length, so the first containing rule wins.
+    fn match_cidr(&self, ip: IpAddr) -> Option<&CidrRule> {
+        match ip {
+            IpAddr::V4(_) => self.cidr_rules_v4.iter().find(|rule| rule.cidr.contains(ip)),
+            IpAddr::V6(_) => self.cidr_rules_v6.iter().find(|rule| rule.cidr.contains(ip)),
+        }
+    }
+}
+
+/// String-keyed concurrent registry sharded to reduce map contention.
+struct ShardedRegistry<T> {
+    shards: Box<[DashMap<String, Arc<T>>]>,
+    mask: usize,
+}
+
+impl<T> ShardedRegistry<T> {
+    fn new(shards: usize) -> Self {
+        let shard_count = shards.max(1).next_power_of_two();
+        let mut items = Vec::with_capacity(shard_count);
+        for _ in 0..shard_count {
+            items.push(DashMap::<String, Arc<T>>::new());
+        }
+        Self {
+            shards: items.into_boxed_slice(),
+            mask: shard_count.saturating_sub(1),
+        }
+    }
+
+    fn shard_index(&self, key: &str) -> usize {
+        let mut hasher = std::collections::hash_map::DefaultHasher::new();
+        key.hash(&mut hasher);
+        (hasher.finish() as usize) & self.mask
+    }
+
+    fn get_or_insert_with<F>(&self, key: &str, make: F) -> Arc<T>
+    where
+        F: FnOnce() -> T,
+    {
+        let shard = &self.shards[self.shard_index(key)];
+        match shard.entry(key.to_string()) {
+            dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
+            dashmap::mapref::entry::Entry::Vacant(slot) => {
+                let value = Arc::new(make());
+                slot.insert(Arc::clone(&value));
+                value
+            }
+        }
+    }
+
+    fn retain<F>(&self, predicate: F)
+    where
+        F: Fn(&String, &Arc<T>) -> bool + Copy,
+    {
+        for shard in &*self.shards {
+            shard.retain(|key, value| predicate(key, value));
+        }
+    }
+
+    fn remove_if<F>(&self, key: &str, predicate: F) -> bool
+    where
+        F: Fn(&Arc<T>) -> bool,
+    {
+        let shard = &self.shards[self.shard_index(key)];
+        let should_remove = match shard.get(key) {
+            Some(entry) => predicate(entry.value()),
+            None => false,
+        };
+        if !should_remove {
+            return false;
+        }
+        shard.remove(key).is_some()
+    }
+}
+
+/// Handle tying one connection to its user/CIDR buckets; releases bucket
+/// references on drop.
+pub struct TrafficLease {
+    limiter: Arc<TrafficLimiter>,
+    user_bucket: Option<Arc<UserBucket>>,
+    cidr_bucket: Option<Arc<CidrBucket>>,
+    cidr_user_key: Option<String>,
+    cidr_user_share: Option<Arc<CidrUserShare>>,
+}
+
+impl TrafficLease {
+    /// Try to consume `requested` bytes. The user bucket is charged first,
+    /// then the CIDR bucket; any user-bucket overcharge is refunded when the
+    /// CIDR bucket grants less.
+    pub fn try_consume(&self, direction: RateDirection, requested: u64) -> TrafficConsumeResult {
+        if requested == 0 {
+            return TrafficConsumeResult {
+                granted: 0,
+                blocked_user: false,
+                blocked_cidr: false,
+            };
+        }
+
+        let mut granted = requested;
+        if let Some(user_bucket) = self.user_bucket.as_ref() {
+            let user_granted = user_bucket.try_consume(direction, granted);
+            if user_granted == 0 {
+                self.limiter.observe_throttle(direction, true, false);
+                return TrafficConsumeResult {
+                    granted: 0,
+                    blocked_user: true,
+                    blocked_cidr: false,
+                };
+            }
+            granted = user_granted;
+        }
+
+        if let (Some(cidr_bucket), Some(cidr_user_share)) =
+            (self.cidr_bucket.as_ref(), self.cidr_user_share.as_ref())
+        {
+            let cidr_granted = cidr_bucket.try_consume_for_user(direction, cidr_user_share, granted);
+            if cidr_granted < granted
+                && let Some(user_bucket) = self.user_bucket.as_ref()
+            {
+                user_bucket.refund(direction, granted.saturating_sub(cidr_granted));
+            }
+            if cidr_granted == 0 {
+                self.limiter.observe_throttle(direction, false, true);
+                return TrafficConsumeResult {
+                    granted: 0,
+                    blocked_user: false,
+                    blocked_cidr: true,
+                };
+            }
+            granted = cidr_granted;
+        }
+
+        TrafficConsumeResult {
+            granted,
+            blocked_user: false,
+            blocked_cidr: false,
+        }
+    }
+
+    /// Return unused bytes to both buckets (e.g. on a short write).
+    pub fn refund(&self, direction: RateDirection, bytes: u64) {
+        if bytes == 0 {
+            return;
+        }
+
+        if let Some(user_bucket) = self.user_bucket.as_ref() {
+            user_bucket.refund(direction, bytes);
+        }
+        if let (Some(cidr_bucket), Some(cidr_user_share)) =
+            (self.cidr_bucket.as_ref(), self.cidr_user_share.as_ref())
+        {
+            cidr_bucket.refund_for_user(direction, cidr_user_share, bytes);
+        }
+    }
+
+    /// Record time spent waiting on a throttle, attributed per scope.
+    pub fn observe_wait_ms(
+        &self,
+        direction: RateDirection,
+        blocked_user: bool,
+        blocked_cidr: bool,
+        wait_ms: u64,
+    ) {
+        if wait_ms == 0 {
+            return;
+        }
+        self.limiter
+            .observe_wait(direction, blocked_user, blocked_cidr, wait_ms);
+    }
+}
+
+impl Drop for TrafficLease {
+    fn drop(&mut self) {
+        if let Some(bucket) = self.user_bucket.as_ref() {
+            decrement_atomic_saturating(&bucket.active_leases, 1);
+            decrement_atomic_saturating(&self.limiter.user_scope.active_leases, 1);
+        }
+
+        if let Some(bucket) = self.cidr_bucket.as_ref() {
+            if let (Some(user_key), Some(share)) =
+                (self.cidr_user_key.as_ref(), self.cidr_user_share.as_ref())
+            {
+                bucket.release_user_share(user_key, share);
+            }
+            decrement_atomic_saturating(&bucket.active_leases, 1);
+            decrement_atomic_saturating(&self.limiter.cidr_scope.active_leases, 1);
+        }
+    }
+}
+
+/// Process-wide rate limiter: hot-swappable policy plus live bucket registries.
+pub struct TrafficLimiter {
+    policy: ArcSwap<PolicySnapshot>,
+    user_buckets: ShardedRegistry<UserBucket>,
+    cidr_buckets: ShardedRegistry<CidrBucket>,
+    user_scope: ScopeMetrics,
+    cidr_scope: ScopeMetrics,
+    last_cleanup_epoch_secs: AtomicU64,
+}
+
+impl TrafficLimiter {
+    pub fn new() -> Arc<Self> {
+        Arc::new(Self {
+            policy: ArcSwap::from_pointee(PolicySnapshot::default()),
+            user_buckets: ShardedRegistry::new(REGISTRY_SHARDS),
+            cidr_buckets: ShardedRegistry::new(REGISTRY_SHARDS),
+            user_scope: ScopeMetrics::default(),
+            cidr_scope: ScopeMetrics::default(),
+            last_cleanup_epoch_secs: AtomicU64::new(0),
+        })
+    }
+
+    /// Install a new limit policy (config reload). Entries with both
+    /// directions zero are dropped; CIDR rules are sorted longest-prefix-first
+    /// so `match_cidr` can take the first hit.
+    pub fn apply_policy(
+        &self,
+        user_limits: HashMap<String, RateLimitBps>,
+        cidr_limits: HashMap<IpNetwork, RateLimitBps>,
+    ) {
+        let filtered_users = user_limits
+            .into_iter()
+            .filter(|(_, limit)| limit.up_bps > 0 || limit.down_bps > 0)
+            .collect::<HashMap<_, _>>();
+
+        let mut cidr_rules_v4 = Vec::new();
+        let mut cidr_rules_v6 = Vec::new();
+        let mut cidr_rule_keys = HashSet::new();
+        for (cidr, limits) in cidr_limits {
+            if limits.up_bps == 0 && limits.down_bps == 0 {
+                continue;
+            }
+            let key = cidr.to_string();
+            let rule = CidrRule {
+                key: key.clone(),
+                cidr,
+                limits,
+                prefix_len: cidr.prefix(),
+            };
+            cidr_rule_keys.insert(key);
+            match rule.cidr {
+                IpNetwork::V4(_) => cidr_rules_v4.push(rule),
+                IpNetwork::V6(_) => cidr_rules_v6.push(rule),
+            }
+        }
+
+        cidr_rules_v4.sort_by(|a, b| b.prefix_len.cmp(&a.prefix_len));
+        cidr_rules_v6.sort_by(|a, b| b.prefix_len.cmp(&a.prefix_len));
+
+        self.user_scope
+            .policy_entries
+            .store(filtered_users.len() as u64, Ordering::Relaxed);
+        self.cidr_scope
+            .policy_entries
+            .store(cidr_rule_keys.len() as u64, Ordering::Relaxed);
+
+        self.policy.store(Arc::new(PolicySnapshot {
+            user_limits: filtered_users,
+            cidr_rules_v4,
+            cidr_rules_v6,
+            cidr_rule_keys,
+        }));
+
+        self.maybe_cleanup();
+    }
+
+    /// Bind a connection to its applicable buckets. Returns `None` when no
+    /// user or CIDR limit applies, so unlimited traffic pays no per-byte cost.
+    pub fn acquire_lease(
+        self: &Arc<Self>,
+        user: &str,
+        client_ip: IpAddr,
+    ) -> Option<Arc<TrafficLease>> {
+        let policy = self.policy.load_full();
+        let mut user_bucket = None;
+        if let Some(limit) = policy.user_limits.get(user).copied() {
+            let bucket = self
+                .user_buckets
+                .get_or_insert_with(user, || UserBucket::new(limit));
+            bucket.set_rates(limit);
+            bucket.active_leases.fetch_add(1, Ordering::Relaxed);
+            self.user_scope.active_leases.fetch_add(1, Ordering::Relaxed);
+            user_bucket = Some(bucket);
+        }
+
+        let mut cidr_bucket = None;
+        let mut cidr_user_key = None;
+        let mut cidr_user_share = None;
+        if let Some(rule) = policy.match_cidr(client_ip) {
+            let bucket = self
+                .cidr_buckets
+                .get_or_insert_with(rule.key.as_str(), || CidrBucket::new(rule.limits));
+            bucket.set_rates(rule.limits);
+            bucket.active_leases.fetch_add(1, Ordering::Relaxed);
+            self.cidr_scope.active_leases.fetch_add(1, Ordering::Relaxed);
+            let share = bucket.acquire_user_share(user);
+            cidr_user_key = Some(user.to_string());
+            cidr_user_share = Some(share);
+            cidr_bucket = Some(bucket);
+        }
+
+        if user_bucket.is_none() && cidr_bucket.is_none() {
+            return None;
+        }
+
+        self.maybe_cleanup();
+        Some(Arc::new(TrafficLease {
+            limiter: Arc::clone(self),
+            user_bucket,
+            cidr_bucket,
+            cidr_user_key,
+            cidr_user_share,
+        }))
+    }
+
+    pub fn metrics_snapshot(&self) -> TrafficLimiterMetricsSnapshot {
+        TrafficLimiterMetricsSnapshot {
+            user_throttle_up_total: self.user_scope.throttle_up_total.load(Ordering::Relaxed),
+            user_throttle_down_total: self.user_scope.throttle_down_total.load(Ordering::Relaxed),
+            cidr_throttle_up_total: self.cidr_scope.throttle_up_total.load(Ordering::Relaxed),
+            cidr_throttle_down_total: self.cidr_scope.throttle_down_total.load(Ordering::Relaxed),
+            user_wait_up_ms_total: self.user_scope.wait_up_ms_total.load(Ordering::Relaxed),
+            user_wait_down_ms_total: self.user_scope.wait_down_ms_total.load(Ordering::Relaxed),
+            cidr_wait_up_ms_total: self.cidr_scope.wait_up_ms_total.load(Ordering::Relaxed),
+            cidr_wait_down_ms_total: self.cidr_scope.wait_down_ms_total.load(Ordering::Relaxed),
+            user_active_leases: self.user_scope.active_leases.load(Ordering::Relaxed),
+            cidr_active_leases: self.cidr_scope.active_leases.load(Ordering::Relaxed),
+            user_policy_entries: self.user_scope.policy_entries.load(Ordering::Relaxed),
+            cidr_policy_entries: self.cidr_scope.policy_entries.load(Ordering::Relaxed),
+        }
+    }
+
+    fn observe_throttle(&self, direction: RateDirection, blocked_user: bool, blocked_cidr: bool) {
+        if blocked_user {
+            self.user_scope.throttle(direction);
+        }
+        if blocked_cidr {
+            self.cidr_scope.throttle(direction);
+        }
+    }
+
+    fn observe_wait(
+        &self,
+        direction: RateDirection,
+        blocked_user: bool,
+        blocked_cidr: bool,
+        wait_ms: u64,
+    ) {
+        if blocked_user {
+            self.user_scope.wait_ms(direction, wait_ms);
+        }
+        if blocked_cidr {
+            self.cidr_scope.wait_ms(direction, wait_ms);
+        }
+    }
+
+    /// Opportunistic, rate-limited GC: drop buckets with no leases that are no
+    /// longer named in the active policy. Only one caller per interval wins
+    /// the CAS and does the work.
+    fn maybe_cleanup(&self) {
+        let now_epoch_secs = now_epoch_secs();
+        let last = self.last_cleanup_epoch_secs.load(Ordering::Relaxed);
+        if now_epoch_secs.saturating_sub(last) < CLEANUP_INTERVAL_SECS {
+            return;
+        }
+        if self
+            .last_cleanup_epoch_secs
+            .compare_exchange(last, now_epoch_secs, Ordering::Relaxed, Ordering::Relaxed)
+            .is_err()
+        {
+            return;
+        }
+
+        let policy = self.policy.load_full();
+        self.user_buckets.retain(|user, bucket| {
+            bucket.active_leases.load(Ordering::Relaxed) > 0 || policy.user_limits.contains_key(user)
+        });
+        self.cidr_buckets.retain(|cidr_key, bucket| {
+            bucket.cleanup_idle_users();
+            bucket.active_leases.load(Ordering::Relaxed) > 0
+                || policy.cidr_rule_keys.contains(cidr_key)
+        });
+    }
+}
+
+/// Time until the next epoch boundary — how long a throttled writer should
+/// sleep before retrying (always at least 1ms).
+pub fn next_refill_delay() -> Duration {
+    let start = limiter_epoch_start();
+    let elapsed_ms = start.elapsed().as_millis() as u64;
+    let epoch_pos = elapsed_ms % FAIR_EPOCH_MS;
+    let wait_ms = FAIR_EPOCH_MS.saturating_sub(epoch_pos).max(1);
+    Duration::from_millis(wait_ms)
+}
+
+/// Lock-free saturating decrement (never wraps below zero).
+fn decrement_atomic_saturating(counter: &AtomicU64, by: u64) {
+    if by == 0 {
+        return;
+    }
+    let mut current = counter.load(Ordering::Relaxed);
+    loop {
+        if current == 0 {
+            return;
+        }
+        let next = current.saturating_sub(by);
+        match counter.compare_exchange_weak(current, next, Ordering::Relaxed, Ordering::Relaxed) {
+            Ok(_) => return,
+            Err(actual) => current = actual,
+        }
+    }
+}
+
+fn now_epoch_secs() -> u64 {
+    SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .unwrap_or_default()
+        .as_secs()
+}
+
+/// Convert a bits-per-second cap into the byte budget of one epoch;
+/// rounds up to at least 1 byte so tiny caps still make progress.
+fn bytes_per_epoch(bps: u64) -> u64 {
+    if bps == 0 {
+        return 0;
+    }
+    let numerator = bps.saturating_mul(FAIR_EPOCH_MS);
+    let bytes = numerator.saturating_div(8_000);
+    bytes.max(1)
+}
+
+fn current_epoch() -> u64 {
+    let start = limiter_epoch_start();
+    let elapsed_ms = start.elapsed().as_millis() as u64;
+    elapsed_ms / FAIR_EPOCH_MS
+}
+
+fn limiter_epoch_start() -> &'static Instant {
+    static START: OnceLock<Instant> = OnceLock::new();
+    START.get_or_init(Instant::now)
+}