Drafting Traffic Control

Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
This commit is contained in:
Alexey
2026-04-15 13:14:45 +03:00
parent 32d5cee01c
commit 21ca1014ae
12 changed files with 1430 additions and 53 deletions

View File

@@ -52,6 +52,7 @@
//! - `SharedCounters` (atomics) let the watchdog read stats without locking
use crate::error::{ProxyError, Result};
use crate::proxy::traffic_limiter::{RateDirection, TrafficLease, next_refill_delay};
use crate::stats::{Stats, UserStats};
use crate::stream::BufferPool;
use std::io;
@@ -61,7 +62,7 @@ use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes};
use tokio::time::Instant;
use tokio::time::{Instant, Sleep};
use tracing::{debug, trace, warn};
// ============= Constants =============
@@ -210,12 +211,24 @@ struct StatsIo<S> {
stats: Arc<Stats>,
user: String,
user_stats: Arc<UserStats>,
traffic_lease: Option<Arc<TrafficLease>>,
c2s_rate_debt_bytes: u64,
c2s_wait: RateWaitState,
s2c_wait: RateWaitState,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
quota_bytes_since_check: u64,
epoch: Instant,
}
#[derive(Default)]
struct RateWaitState {
sleep: Option<Pin<Box<Sleep>>>,
started_at: Option<Instant>,
blocked_user: bool,
blocked_cidr: bool,
}
impl<S> StatsIo<S> {
fn new(
inner: S,
@@ -225,6 +238,28 @@ impl<S> StatsIo<S> {
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
epoch: Instant,
) -> Self {
Self::new_with_traffic_lease(
inner,
counters,
stats,
user,
None,
quota_limit,
quota_exceeded,
epoch,
)
}
fn new_with_traffic_lease(
inner: S,
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
traffic_lease: Option<Arc<TrafficLease>>,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
epoch: Instant,
) -> Self {
// Mark initial activity so the watchdog doesn't fire before data flows
counters.touch(Instant::now(), epoch);
@@ -235,12 +270,97 @@ impl<S> StatsIo<S> {
stats,
user,
user_stats,
traffic_lease,
c2s_rate_debt_bytes: 0,
c2s_wait: RateWaitState::default(),
s2c_wait: RateWaitState::default(),
quota_limit,
quota_exceeded,
quota_bytes_since_check: 0,
epoch,
}
}
fn record_wait(
wait: &mut RateWaitState,
lease: Option<&Arc<TrafficLease>>,
direction: RateDirection,
) {
let Some(started_at) = wait.started_at.take() else {
return;
};
let wait_ms = started_at
.elapsed()
.as_millis()
.min(u128::from(u64::MAX)) as u64;
if let Some(lease) = lease {
lease.observe_wait_ms(
direction,
wait.blocked_user,
wait.blocked_cidr,
wait_ms,
);
}
wait.blocked_user = false;
wait.blocked_cidr = false;
}
fn arm_wait(wait: &mut RateWaitState, blocked_user: bool, blocked_cidr: bool) {
if wait.sleep.is_none() {
wait.sleep = Some(Box::pin(tokio::time::sleep(next_refill_delay())));
wait.started_at = Some(Instant::now());
}
wait.blocked_user |= blocked_user;
wait.blocked_cidr |= blocked_cidr;
}
fn poll_wait(
wait: &mut RateWaitState,
cx: &mut Context<'_>,
lease: Option<&Arc<TrafficLease>>,
direction: RateDirection,
) -> Poll<()> {
let Some(sleep) = wait.sleep.as_mut() else {
return Poll::Ready(());
};
if sleep.as_mut().poll(cx).is_pending() {
return Poll::Pending;
}
wait.sleep = None;
Self::record_wait(wait, lease, direction);
Poll::Ready(())
}
fn settle_c2s_rate_debt(&mut self, cx: &mut Context<'_>) -> Poll<()> {
let Some(lease) = self.traffic_lease.as_ref() else {
self.c2s_rate_debt_bytes = 0;
return Poll::Ready(());
};
while self.c2s_rate_debt_bytes > 0 {
let consume = lease.try_consume(RateDirection::Up, self.c2s_rate_debt_bytes);
if consume.granted > 0 {
self.c2s_rate_debt_bytes =
self.c2s_rate_debt_bytes.saturating_sub(consume.granted);
continue;
}
Self::arm_wait(
&mut self.c2s_wait,
consume.blocked_user,
consume.blocked_cidr,
);
if Self::poll_wait(&mut self.c2s_wait, cx, Some(lease), RateDirection::Up).is_pending()
{
return Poll::Pending;
}
}
if Self::poll_wait(&mut self.c2s_wait, cx, Some(lease), RateDirection::Up).is_pending() {
return Poll::Pending;
}
Poll::Ready(())
}
}
#[derive(Debug)]
@@ -286,6 +406,25 @@ fn should_immediate_quota_check(remaining_before: u64, charge_bytes: u64) -> boo
remaining_before <= QUOTA_NEAR_LIMIT_BYTES || charge_bytes >= QUOTA_LARGE_CHARGE_BYTES
}
fn refund_reserved_quota_bytes(user_stats: &UserStats, reserved_bytes: u64) {
if reserved_bytes == 0 {
return;
}
let mut current = user_stats.quota_used.load(Ordering::Relaxed);
loop {
let next = current.saturating_sub(reserved_bytes);
match user_stats.quota_used.compare_exchange_weak(
current,
next,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => return,
Err(observed) => current = observed,
}
}
}
impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
fn poll_read(
self: Pin<&mut Self>,
@@ -296,6 +435,9 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
if this.quota_exceeded.load(Ordering::Acquire) {
return Poll::Ready(Err(quota_io_error()));
}
if this.settle_c2s_rate_debt(cx).is_pending() {
return Poll::Pending;
}
let mut remaining_before = None;
if let Some(limit) = this.quota_limit {
@@ -377,6 +519,11 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
.add_user_octets_from_handle(this.user_stats.as_ref(), n_to_charge);
this.stats
.increment_user_msgs_from_handle(this.user_stats.as_ref());
if this.traffic_lease.is_some() {
this.c2s_rate_debt_bytes =
this.c2s_rate_debt_bytes.saturating_add(n_to_charge);
let _ = this.settle_c2s_rate_debt(cx);
}
trace!(user = %this.user, bytes = n, "C->S");
}
@@ -398,28 +545,66 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
return Poll::Ready(Err(quota_io_error()));
}
let mut shaper_reserved_bytes = 0u64;
let mut write_buf = buf;
if let Some(lease) = this.traffic_lease.as_ref() {
if !buf.is_empty() {
loop {
let consume = lease.try_consume(RateDirection::Down, buf.len() as u64);
if consume.granted > 0 {
shaper_reserved_bytes = consume.granted;
if consume.granted < buf.len() as u64 {
write_buf = &buf[..consume.granted as usize];
}
let _ = Self::poll_wait(
&mut this.s2c_wait,
cx,
Some(lease),
RateDirection::Down,
);
break;
}
Self::arm_wait(
&mut this.s2c_wait,
consume.blocked_user,
consume.blocked_cidr,
);
if Self::poll_wait(&mut this.s2c_wait, cx, Some(lease), RateDirection::Down)
.is_pending()
{
return Poll::Pending;
}
}
} else {
let _ = Self::poll_wait(&mut this.s2c_wait, cx, Some(lease), RateDirection::Down);
}
}
let mut remaining_before = None;
let mut reserved_bytes = 0u64;
let mut write_buf = buf;
if let Some(limit) = this.quota_limit {
if !buf.is_empty() {
if !write_buf.is_empty() {
let mut reserve_rounds = 0usize;
while reserved_bytes == 0 {
let used_before = this.user_stats.quota_used();
let remaining = limit.saturating_sub(used_before);
if remaining == 0 {
if let Some(lease) = this.traffic_lease.as_ref() {
lease.refund(RateDirection::Down, shaper_reserved_bytes);
}
this.quota_exceeded.store(true, Ordering::Release);
return Poll::Ready(Err(quota_io_error()));
}
remaining_before = Some(remaining);
let desired = remaining.min(buf.len() as u64);
let desired = remaining.min(write_buf.len() as u64);
let mut saw_contention = false;
for _ in 0..QUOTA_RESERVE_SPIN_RETRIES {
match this.user_stats.quota_try_reserve(desired, limit) {
Ok(_) => {
reserved_bytes = desired;
write_buf = &buf[..desired as usize];
write_buf = &write_buf[..desired as usize];
break;
}
Err(crate::stats::QuotaReserveError::LimitExceeded) => {
@@ -434,6 +619,9 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
if reserved_bytes == 0 {
reserve_rounds = reserve_rounds.saturating_add(1);
if reserve_rounds >= QUOTA_RESERVE_MAX_ROUNDS {
if let Some(lease) = this.traffic_lease.as_ref() {
lease.refund(RateDirection::Down, shaper_reserved_bytes);
}
this.quota_exceeded.store(true, Ordering::Release);
return Poll::Ready(Err(quota_io_error()));
}
@@ -446,6 +634,9 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
let used_before = this.user_stats.quota_used();
let remaining = limit.saturating_sub(used_before);
if remaining == 0 {
if let Some(lease) = this.traffic_lease.as_ref() {
lease.refund(RateDirection::Down, shaper_reserved_bytes);
}
this.quota_exceeded.store(true, Ordering::Release);
return Poll::Ready(Err(quota_io_error()));
}
@@ -456,23 +647,17 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
match Pin::new(&mut this.inner).poll_write(cx, write_buf) {
Poll::Ready(Ok(n)) => {
if reserved_bytes > n as u64 {
let refund = reserved_bytes - n as u64;
let mut current = this.user_stats.quota_used.load(Ordering::Relaxed);
loop {
let next = current.saturating_sub(refund);
match this.user_stats.quota_used.compare_exchange_weak(
current,
next,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(observed) => current = observed,
}
}
refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes - n as u64);
}
if shaper_reserved_bytes > n as u64
&& let Some(lease) = this.traffic_lease.as_ref()
{
lease.refund(RateDirection::Down, shaper_reserved_bytes - n as u64);
}
if n > 0 {
if let Some(lease) = this.traffic_lease.as_ref() {
Self::record_wait(&mut this.s2c_wait, Some(lease), RateDirection::Down);
}
let n_to_charge = n as u64;
// S→C: data written to client
@@ -512,37 +697,23 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
}
Poll::Ready(Err(err)) => {
if reserved_bytes > 0 {
let mut current = this.user_stats.quota_used.load(Ordering::Relaxed);
loop {
let next = current.saturating_sub(reserved_bytes);
match this.user_stats.quota_used.compare_exchange_weak(
current,
next,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(observed) => current = observed,
}
}
refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes);
}
if shaper_reserved_bytes > 0
&& let Some(lease) = this.traffic_lease.as_ref()
{
lease.refund(RateDirection::Down, shaper_reserved_bytes);
}
Poll::Ready(Err(err))
}
Poll::Pending => {
if reserved_bytes > 0 {
let mut current = this.user_stats.quota_used.load(Ordering::Relaxed);
loop {
let next = current.saturating_sub(reserved_bytes);
match this.user_stats.quota_used.compare_exchange_weak(
current,
next,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(observed) => current = observed,
}
}
refund_reserved_quota_bytes(this.user_stats.as_ref(), reserved_bytes);
}
if shaper_reserved_bytes > 0
&& let Some(lease) = this.traffic_lease.as_ref()
{
lease.refund(RateDirection::Down, shaper_reserved_bytes);
}
Poll::Pending
}
@@ -627,6 +798,43 @@ pub async fn relay_bidirectional_with_activity_timeout<CR, CW, SR, SW>(
_buffer_pool: Arc<BufferPool>,
activity_timeout: Duration,
) -> Result<()>
where
CR: AsyncRead + Unpin + Send + 'static,
CW: AsyncWrite + Unpin + Send + 'static,
SR: AsyncRead + Unpin + Send + 'static,
SW: AsyncWrite + Unpin + Send + 'static,
{
relay_bidirectional_with_activity_timeout_and_lease(
client_reader,
client_writer,
server_reader,
server_writer,
c2s_buf_size,
s2c_buf_size,
user,
stats,
quota_limit,
_buffer_pool,
None,
activity_timeout,
)
.await
}
pub async fn relay_bidirectional_with_activity_timeout_and_lease<CR, CW, SR, SW>(
client_reader: CR,
client_writer: CW,
server_reader: SR,
server_writer: SW,
c2s_buf_size: usize,
s2c_buf_size: usize,
user: &str,
stats: Arc<Stats>,
quota_limit: Option<u64>,
_buffer_pool: Arc<BufferPool>,
traffic_lease: Option<Arc<TrafficLease>>,
activity_timeout: Duration,
) -> Result<()>
where
CR: AsyncRead + Unpin + Send + 'static,
CW: AsyncWrite + Unpin + Send + 'static,
@@ -644,11 +852,12 @@ where
let mut server = CombinedStream::new(server_reader, server_writer);
// Wrap client with stats/activity tracking
let mut client = StatsIo::new(
let mut client = StatsIo::new_with_traffic_lease(
client_combined,
Arc::clone(&counters),
Arc::clone(&stats),
user_owned.clone(),
traffic_lease,
quota_limit,
Arc::clone(&quota_exceeded),
epoch,