mirror of https://github.com/telemt/telemt.git
Merge upstream/flow into flow
This commit is contained in:
commit
1b9f483a08
|
|
@ -0,0 +1,39 @@
|
|||
name: Build
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "*" ]
|
||||
pull_request:
|
||||
branches: [ "*" ]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install latest stable Rust toolchain
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
|
||||
- name: Cache cargo registry & build artifacts
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
~/.cargo/git
|
||||
target
|
||||
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-cargo-
|
||||
|
||||
- name: Build Release
|
||||
run: cargo build --release --verbose
|
||||
|
|
@ -26,6 +26,9 @@ jobs:
|
|||
name: GNU ${{ matrix.target }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
container:
|
||||
image: rust:slim-bookworm
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
|
|
@ -47,8 +50,8 @@ jobs:
|
|||
|
||||
- name: Install deps
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y \
|
||||
apt-get update
|
||||
apt-get install -y \
|
||||
build-essential \
|
||||
clang \
|
||||
lld \
|
||||
|
|
@ -69,14 +72,10 @@ jobs:
|
|||
if [ "${{ matrix.target }}" = "aarch64-unknown-linux-gnu" ]; then
|
||||
export CC=aarch64-linux-gnu-gcc
|
||||
export CXX=aarch64-linux-gnu-g++
|
||||
export CC_aarch64_unknown_linux_gnu=aarch64-linux-gnu-gcc
|
||||
export CXX_aarch64_unknown_linux_gnu=aarch64-linux-gnu-g++
|
||||
export RUSTFLAGS="-C linker=aarch64-linux-gnu-gcc"
|
||||
else
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
export CC_x86_64_unknown_linux_gnu=clang
|
||||
export CXX_x86_64_unknown_linux_gnu=clang++
|
||||
export RUSTFLAGS="-C linker=clang -C link-arg=-fuse-ld=lld"
|
||||
fi
|
||||
|
||||
|
|
@ -85,20 +84,19 @@ jobs:
|
|||
- name: Package
|
||||
run: |
|
||||
mkdir -p dist
|
||||
BIN=target/${{ matrix.target }}/release/${{ env.BINARY_NAME }}
|
||||
|
||||
cp "$BIN" dist/${{ env.BINARY_NAME }}-${{ matrix.target }}
|
||||
cp target/${{ matrix.target }}/release/${{ env.BINARY_NAME }} dist/telemt
|
||||
|
||||
cd dist
|
||||
tar -czf ${{ matrix.asset }}.tar.gz ${{ env.BINARY_NAME }}-${{ matrix.target }}
|
||||
tar -czf ${{ matrix.asset }}.tar.gz \
|
||||
--owner=0 --group=0 --numeric-owner \
|
||||
telemt
|
||||
|
||||
sha256sum ${{ matrix.asset }}.tar.gz > ${{ matrix.asset }}.sha256
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ${{ matrix.asset }}
|
||||
path: |
|
||||
dist/${{ matrix.asset }}.tar.gz
|
||||
dist/${{ matrix.asset }}.sha256
|
||||
path: dist/*
|
||||
|
||||
# ==========================
|
||||
# MUSL
|
||||
|
|
@ -125,43 +123,7 @@ jobs:
|
|||
- name: Install deps
|
||||
run: |
|
||||
apt-get update
|
||||
apt-get install -y \
|
||||
musl-tools \
|
||||
pkg-config \
|
||||
curl
|
||||
|
||||
- uses: actions/cache@v4
|
||||
if: matrix.target == 'aarch64-unknown-linux-musl'
|
||||
with:
|
||||
path: ~/.musl-aarch64
|
||||
key: musl-toolchain-aarch64-v1
|
||||
|
||||
- name: Install aarch64 musl toolchain
|
||||
if: matrix.target == 'aarch64-unknown-linux-musl'
|
||||
run: |
|
||||
set -e
|
||||
|
||||
TOOLCHAIN_DIR="$HOME/.musl-aarch64"
|
||||
ARCHIVE="aarch64-linux-musl-cross.tgz"
|
||||
URL="https://github.com/telemt/telemt/releases/download/toolchains/$ARCHIVE"
|
||||
|
||||
if [ -x "$TOOLCHAIN_DIR/bin/aarch64-linux-musl-gcc" ]; then
|
||||
echo "✅ MUSL toolchain already installed"
|
||||
else
|
||||
echo "⬇️ Downloading musl toolchain from Telemt GitHub Releases..."
|
||||
|
||||
curl -fL \
|
||||
--retry 5 \
|
||||
--retry-delay 3 \
|
||||
--connect-timeout 10 \
|
||||
--max-time 120 \
|
||||
-o "$ARCHIVE" "$URL"
|
||||
|
||||
mkdir -p "$TOOLCHAIN_DIR"
|
||||
tar -xzf "$ARCHIVE" --strip-components=1 -C "$TOOLCHAIN_DIR"
|
||||
fi
|
||||
|
||||
echo "$TOOLCHAIN_DIR/bin" >> $GITHUB_PATH
|
||||
apt-get install -y musl-tools pkg-config curl
|
||||
|
||||
- name: Add rust target
|
||||
run: rustup target add ${{ matrix.target }}
|
||||
|
|
@ -178,11 +140,9 @@ jobs:
|
|||
run: |
|
||||
if [ "${{ matrix.target }}" = "aarch64-unknown-linux-musl" ]; then
|
||||
export CC=aarch64-linux-musl-gcc
|
||||
export CC_aarch64_unknown_linux_musl=aarch64-linux-musl-gcc
|
||||
export RUSTFLAGS="-C target-feature=+crt-static -C linker=aarch64-linux-musl-gcc"
|
||||
else
|
||||
export CC=musl-gcc
|
||||
export CC_x86_64_unknown_linux_musl=musl-gcc
|
||||
export RUSTFLAGS="-C target-feature=+crt-static"
|
||||
fi
|
||||
|
||||
|
|
@ -191,119 +151,93 @@ jobs:
|
|||
- name: Package
|
||||
run: |
|
||||
mkdir -p dist
|
||||
BIN=target/${{ matrix.target }}/release/${{ env.BINARY_NAME }}
|
||||
|
||||
cp "$BIN" dist/${{ env.BINARY_NAME }}-${{ matrix.target }}
|
||||
cp target/${{ matrix.target }}/release/${{ env.BINARY_NAME }} dist/telemt
|
||||
|
||||
cd dist
|
||||
tar -czf ${{ matrix.asset }}.tar.gz ${{ env.BINARY_NAME }}-${{ matrix.target }}
|
||||
tar -czf ${{ matrix.asset }}.tar.gz \
|
||||
--owner=0 --group=0 --numeric-owner \
|
||||
telemt
|
||||
|
||||
sha256sum ${{ matrix.asset }}.tar.gz > ${{ matrix.asset }}.sha256
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ${{ matrix.asset }}
|
||||
path: |
|
||||
dist/${{ matrix.asset }}.tar.gz
|
||||
dist/${{ matrix.asset }}.sha256
|
||||
path: dist/*
|
||||
|
||||
# ==========================
|
||||
# OpenBSD
|
||||
# Release
|
||||
# ==========================
|
||||
build-openbsd:
|
||||
name: OpenBSD ${{ matrix.arch }}
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- arch: x86_64
|
||||
asset: telemt-x86_64-openbsd
|
||||
rustflags: -C opt-level=3
|
||||
- arch: aarch64
|
||||
asset: telemt-aarch64-openbsd
|
||||
rustflags: -C opt-level=3
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Build in OpenBSD VM
|
||||
uses: vmactions/openbsd-vm@v1
|
||||
with:
|
||||
release: "7.8"
|
||||
arch: ${{ matrix.arch }}
|
||||
usesh: true
|
||||
sync: sshfs
|
||||
envs: RUSTFLAGS CARGO_TERM_COLOR
|
||||
prepare: |
|
||||
pkg_add rust
|
||||
run: |
|
||||
set -e
|
||||
RUSTC_VERSION=$(rustc --version | awk '{print $2}')
|
||||
RUSTC_MAJOR=$(echo "$RUSTC_VERSION" | cut -d. -f1)
|
||||
RUSTC_MINOR=$(echo "$RUSTC_VERSION" | cut -d. -f2)
|
||||
REQUIRED_MAJOR=1
|
||||
REQUIRED_MINOR=85
|
||||
if [ "$RUSTC_MAJOR" -lt "$REQUIRED_MAJOR" ] || { [ "$RUSTC_MAJOR" -eq "$REQUIRED_MAJOR" ] && [ "$RUSTC_MINOR" -lt "$REQUIRED_MINOR" ]; }; then
|
||||
echo "rustc ${REQUIRED_MAJOR}.${REQUIRED_MINOR}.0 or newer is required for this project (found ${RUSTC_VERSION})."
|
||||
exit 1
|
||||
fi
|
||||
cargo build --release --locked --verbose
|
||||
env:
|
||||
RUSTFLAGS: ${{ matrix.rustflags }}
|
||||
|
||||
- name: Package
|
||||
run: |
|
||||
mkdir -p dist
|
||||
cp target/release/${{ env.BINARY_NAME }} dist/${{ env.BINARY_NAME }}-${{ matrix.arch }}-unknown-openbsd
|
||||
cd dist
|
||||
tar -czf ${{ matrix.asset }}.tar.gz ${{ env.BINARY_NAME }}-${{ matrix.arch }}-unknown-openbsd
|
||||
sha256sum ${{ matrix.asset }}.tar.gz > ${{ matrix.asset }}.sha256
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ${{ matrix.asset }}
|
||||
path: |
|
||||
dist/${{ matrix.asset }}.tar.gz
|
||||
dist/${{ matrix.asset }}.sha256
|
||||
|
||||
# ==========================
|
||||
# Docker
|
||||
# ==========================
|
||||
docker:
|
||||
name: Docker
|
||||
release:
|
||||
name: Release
|
||||
runs-on: ubuntu-latest
|
||||
needs: [build-gnu, build-musl]
|
||||
continue-on-error: true
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: artifacts
|
||||
|
||||
- name: Extract binaries
|
||||
- name: Flatten
|
||||
run: |
|
||||
mkdir dist
|
||||
find artifacts -name "*.tar.gz" -exec tar -xzf {} -C dist \;
|
||||
find artifacts -type f -exec cp {} dist/ \;
|
||||
|
||||
cp dist/telemt-x86_64-unknown-linux-musl dist/telemt || true
|
||||
|
||||
- uses: docker/setup-qemu-action@v3
|
||||
- uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to GHCR
|
||||
uses: docker/login-action@v3
|
||||
- name: Create Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
files: dist/*
|
||||
generate_release_notes: true
|
||||
prerelease: ${{ contains(github.ref, '-') }}
|
||||
|
||||
# ==========================
|
||||
# Docker (FROM RELEASE)
|
||||
# ==========================
|
||||
docker:
|
||||
name: Docker (from release)
|
||||
runs-on: ubuntu-latest
|
||||
needs: release
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install gh
|
||||
run: apt-get update && apt-get install -y gh
|
||||
|
||||
- name: Extract version
|
||||
id: vars
|
||||
run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Download binary
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
mkdir dist
|
||||
|
||||
gh release download ${{ steps.vars.outputs.VERSION }} \
|
||||
--repo ${{ github.repository }} \
|
||||
--pattern "telemt-x86_64-linux-musl.tar.gz" \
|
||||
--dir dist
|
||||
|
||||
tar -xzf dist/telemt-x86_64-linux-musl.tar.gz -C dist
|
||||
chmod +x dist/telemt
|
||||
|
||||
- uses: docker/setup-qemu-action@v3
|
||||
- uses: docker/setup-buildx-action@v3
|
||||
|
||||
- uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Build & Push
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
|
|
@ -315,32 +249,3 @@ jobs:
|
|||
ghcr.io/${{ github.repository }}:latest
|
||||
build-args: |
|
||||
BINARY=dist/telemt
|
||||
|
||||
# ==========================
|
||||
# Release
|
||||
# ==========================
|
||||
release:
|
||||
name: Release
|
||||
runs-on: ubuntu-latest
|
||||
needs: [build-gnu, build-musl, build-openbsd]
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
steps:
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: artifacts
|
||||
|
||||
- name: Flatten artifacts
|
||||
run: |
|
||||
mkdir dist
|
||||
find artifacts -type f -exec cp {} dist/ \;
|
||||
|
||||
- name: Create Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
files: dist/*
|
||||
generate_release_notes: true
|
||||
draft: false
|
||||
prerelease: ${{ contains(github.ref, '-rc') || contains(github.ref, '-beta') || contains(github.ref, '-alpha') }}
|
||||
|
|
|
|||
|
|
@ -1,54 +0,0 @@
|
|||
name: Rust
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "*" ]
|
||||
pull_request:
|
||||
branches: [ "*" ]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Compile, Test, Lint
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
actions: write
|
||||
checks: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install latest stable Rust toolchain
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
components: rustfmt, clippy
|
||||
|
||||
- name: Cache cargo registry & build artifacts
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
~/.cargo/git
|
||||
target
|
||||
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-cargo-
|
||||
|
||||
- name: Compile (no tests)
|
||||
run: cargo check --workspace --all-features --bins --verbose
|
||||
|
||||
- name: Run tests (single pass)
|
||||
run: cargo test --workspace --all-features --verbose
|
||||
|
||||
# clippy dont fail on warnings because of active development of telemt
|
||||
# and many warnings
|
||||
- name: Run clippy
|
||||
run: cargo clippy -- --cap-lints warn
|
||||
|
||||
- name: Check for unused dependencies
|
||||
run: cargo udeps || true
|
||||
|
|
@ -1,57 +0,0 @@
|
|||
name: Stress Tests
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 2 * * *'
|
||||
pull_request:
|
||||
branches: ["*"]
|
||||
paths:
|
||||
- src/proxy/**
|
||||
- src/transport/**
|
||||
- src/stream/**
|
||||
- src/protocol/**
|
||||
- src/tls_front/**
|
||||
- Cargo.toml
|
||||
- Cargo.lock
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
jobs:
|
||||
quota-lock-stress:
|
||||
name: Quota-lock stress loop
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install latest stable Rust toolchain
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
|
||||
- name: Cache cargo registry and build artifacts
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
~/.cargo/git
|
||||
target
|
||||
key: ${{ runner.os }}-cargo-stress-${{ hashFiles('**/Cargo.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-cargo-stress-
|
||||
${{ runner.os }}-cargo-
|
||||
|
||||
- name: Run quota-lock stress suites
|
||||
env:
|
||||
RUST_TEST_THREADS: 16
|
||||
run: |
|
||||
set -euo pipefail
|
||||
for i in $(seq 1 12); do
|
||||
echo "[quota-lock-stress] iteration ${i}/12"
|
||||
cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16
|
||||
cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16
|
||||
done
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
name: Check
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "*" ]
|
||||
pull_request:
|
||||
branches: [ "*" ]
|
||||
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
|
||||
concurrency:
|
||||
group: test-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
# ==========================
|
||||
# Formatting
|
||||
# ==========================
|
||||
fmt:
|
||||
name: Fmt
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
components: rustfmt
|
||||
|
||||
- run: cargo fmt -- --check
|
||||
|
||||
# ==========================
|
||||
# Tests
|
||||
# ==========================
|
||||
test:
|
||||
name: Test
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
actions: write
|
||||
checks: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
|
||||
- name: Cache cargo
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
~/.cargo/git
|
||||
target
|
||||
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-cargo-
|
||||
|
||||
- run: cargo test --verbose
|
||||
|
||||
# ==========================
|
||||
# Clippy
|
||||
# ==========================
|
||||
clippy:
|
||||
name: Clippy
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
checks: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
components: clippy
|
||||
|
||||
- name: Cache cargo
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
~/.cargo/git
|
||||
target
|
||||
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-cargo-
|
||||
|
||||
- run: cargo clippy -- --cap-lints warn
|
||||
|
||||
# ==========================
|
||||
# Udeps
|
||||
# ==========================
|
||||
udeps:
|
||||
name: Udeps
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: dtolnay/rust-toolchain@stable
|
||||
|
||||
- name: Cache cargo
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
~/.cargo/git
|
||||
target
|
||||
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-cargo-
|
||||
|
||||
- name: Install cargo-udeps
|
||||
run: cargo install cargo-udeps || true
|
||||
|
||||
# тоже не валит билд
|
||||
- run: cargo udeps || true
|
||||
58
Dockerfile
58
Dockerfile
|
|
@ -1,29 +1,9 @@
|
|||
# syntax=docker/dockerfile:1
|
||||
|
||||
# ==========================
|
||||
# Stage 1: Build
|
||||
# ==========================
|
||||
FROM rust:1.88-slim-bookworm AS builder
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
pkg-config \
|
||||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
# Depcache
|
||||
COPY Cargo.toml Cargo.lock* ./
|
||||
RUN mkdir src && echo 'fn main() {}' > src/main.rs && \
|
||||
cargo build --release 2>/dev/null || true && \
|
||||
rm -rf src
|
||||
|
||||
# Build
|
||||
COPY . .
|
||||
RUN cargo build --release && strip target/release/telemt
|
||||
ARG BINARY
|
||||
|
||||
# ==========================
|
||||
# Stage 2: Compress (strip + UPX)
|
||||
# Stage: minimal
|
||||
# ==========================
|
||||
FROM debian:12-slim AS minimal
|
||||
|
||||
|
|
@ -33,7 +13,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||
ca-certificates \
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
\
|
||||
# install UPX from Telemt releases
|
||||
&& curl -fL \
|
||||
--retry 5 \
|
||||
--retry-delay 3 \
|
||||
|
|
@ -46,15 +25,15 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||
&& chmod +x /usr/local/bin/upx \
|
||||
&& rm -rf /tmp/upx*
|
||||
|
||||
COPY --from=builder /build/target/release/telemt /telemt
|
||||
COPY ${BINARY} /telemt
|
||||
|
||||
RUN strip /telemt || true
|
||||
RUN upx --best --lzma /telemt || true
|
||||
|
||||
# ==========================
|
||||
# Stage 3: Debug base
|
||||
# Debug image
|
||||
# ==========================
|
||||
FROM debian:12-slim AS debug-base
|
||||
FROM debian:12-slim AS debug
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ca-certificates \
|
||||
|
|
@ -64,48 +43,29 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||
busybox \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# ==========================
|
||||
# Stage 4: Debug image
|
||||
# ==========================
|
||||
FROM debug-base AS debug
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=minimal /telemt /app/telemt
|
||||
COPY config.toml /app/config.toml
|
||||
|
||||
USER root
|
||||
|
||||
EXPOSE 443
|
||||
EXPOSE 9090
|
||||
EXPOSE 9091
|
||||
EXPOSE 443 9090 9091
|
||||
|
||||
ENTRYPOINT ["/app/telemt"]
|
||||
CMD ["config.toml"]
|
||||
|
||||
# ==========================
|
||||
# Stage 5: Production (distroless)
|
||||
# Production (REAL distroless)
|
||||
# ==========================
|
||||
FROM gcr.io/distroless/base-debian12 AS prod
|
||||
FROM gcr.io/distroless/static-debian12 AS prod
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=minimal /telemt /app/telemt
|
||||
COPY config.toml /app/config.toml
|
||||
|
||||
# TLS + timezone + shell
|
||||
COPY --from=debug-base /etc/ssl/certs /etc/ssl/certs
|
||||
COPY --from=debug-base /usr/share/zoneinfo /usr/share/zoneinfo
|
||||
COPY --from=debug-base /bin/busybox /bin/busybox
|
||||
|
||||
RUN ["/bin/busybox", "--install", "-s", "/bin"]
|
||||
|
||||
# distroless user
|
||||
USER nonroot:nonroot
|
||||
|
||||
EXPOSE 443
|
||||
EXPOSE 9090
|
||||
EXPOSE 9091
|
||||
EXPOSE 443 9090 9091
|
||||
|
||||
ENTRYPOINT ["/app/telemt"]
|
||||
CMD ["config.toml"]
|
||||
|
|
@ -63,7 +63,7 @@ recommended range from 5 to 2147483647 inclusive
|
|||
|
||||
> [!IMPORTANT]
|
||||
> It is recommended to use your own, unique values.\
|
||||
> You can use the [generator](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/e8b269ff0089a27effd88f8d925179b78e5666c4/awg-gen.html) to select parameters.
|
||||
> You can use the [generator](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/13f5517ca473b47c412b9a99407066de973732bd/awg-gen.html) to select parameters.
|
||||
|
||||
#### Server B Configuration (Netherlands):
|
||||
|
||||
|
|
@ -84,6 +84,8 @@ Jmin = 8
|
|||
Jmax = 80
|
||||
S1 = 29
|
||||
S2 = 15
|
||||
S3 = 18
|
||||
S4 = 0
|
||||
H1 = 2087563914
|
||||
H2 = 188817757
|
||||
H3 = 101784570
|
||||
|
|
@ -121,6 +123,8 @@ Jmin = 8
|
|||
Jmax = 80
|
||||
S1 = 29
|
||||
S2 = 15
|
||||
S3 = 18
|
||||
S4 = 0
|
||||
H1 = 2087563914
|
||||
H2 = 188817757
|
||||
H3 = 101784570
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ awg genkey | tee private.key | awg pubkey > public.key
|
|||
|
||||
Параметры обфускации `S1`, `S2`, `H1`, `H2`, `H3`, `H4` должны быть строго идентичными на обоих серверах.\
|
||||
Параметры `Jc`, `Jmin` и `Jmax` могут отличатся.\
|
||||
Параметры `I1-I5` [(Custom Protocol Signature)](https://docs.amnezia.org/documentation/amnezia-wg/) нужно указывать на стороне _клиента_ (Сервер **А**).
|
||||
Параметры `I1-I5` ([Custom Protocol Signature](https://docs.amnezia.org/documentation/amnezia-wg/)) нужно указывать на стороне _клиента_ (Сервер **А**).
|
||||
|
||||
Рекомендации по выбору значений:
|
||||
```text
|
||||
|
|
@ -62,7 +62,7 @@ H1/H2/H3/H4 — должны быть уникальны и отличаться
|
|||
```
|
||||
> [!IMPORTANT]
|
||||
> Рекомендуется использовать собственные, уникальные значения.\
|
||||
> Для выбора параметров можете воспользоваться [генератором](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/e8b269ff0089a27effd88f8d925179b78e5666c4/awg-gen.html).
|
||||
> Для выбора параметров можете воспользоваться [генератором](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/13f5517ca473b47c412b9a99407066de973732bd/awg-gen.html).
|
||||
|
||||
#### Конфигурация Сервера B (_Нидерланды_):
|
||||
|
||||
|
|
@ -83,6 +83,8 @@ Jmin = 8
|
|||
Jmax = 80
|
||||
S1 = 29
|
||||
S2 = 15
|
||||
S3 = 18
|
||||
S4 = 0
|
||||
H1 = 2087563914
|
||||
H2 = 188817757
|
||||
H3 = 101784570
|
||||
|
|
@ -121,6 +123,8 @@ Jmin = 8
|
|||
Jmax = 80
|
||||
S1 = 29
|
||||
S2 = 15
|
||||
S3 = 18
|
||||
S4 = 0
|
||||
H1 = 2087563914
|
||||
H2 = 188817757
|
||||
H3 = 101784570
|
||||
|
|
@ -272,7 +276,7 @@ backend telemt_nodes
|
|||
|
||||
```
|
||||
>[!WARNING]
|
||||
>**Файл должен заканчиваться пустой строкой, иначе HAProxy не запуститься!**
|
||||
>**Файл должен заканчиваться пустой строкой, иначе HAProxy не запустится!**
|
||||
|
||||
#### Разрешаем порт 443\tcp в фаерволе (если включен)
|
||||
```bash
|
||||
|
|
|
|||
|
|
@ -71,6 +71,22 @@ pub(crate) fn default_tls_fetch_scope() -> String {
|
|||
String::new()
|
||||
}
|
||||
|
||||
pub(crate) fn default_tls_fetch_attempt_timeout_ms() -> u64 {
|
||||
5_000
|
||||
}
|
||||
|
||||
pub(crate) fn default_tls_fetch_total_budget_ms() -> u64 {
|
||||
15_000
|
||||
}
|
||||
|
||||
pub(crate) fn default_tls_fetch_strict_route() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
pub(crate) fn default_tls_fetch_profile_cache_ttl_secs() -> u64 {
|
||||
600
|
||||
}
|
||||
|
||||
pub(crate) fn default_mask_port() -> u16 {
|
||||
443
|
||||
}
|
||||
|
|
@ -185,6 +201,10 @@ pub(crate) fn default_proxy_protocol_header_timeout_ms() -> u64 {
|
|||
500
|
||||
}
|
||||
|
||||
pub(crate) fn default_proxy_protocol_trusted_cidrs() -> Vec<IpNetwork> {
|
||||
vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]
|
||||
}
|
||||
|
||||
pub(crate) fn default_server_max_connections() -> u32 {
|
||||
10_000
|
||||
}
|
||||
|
|
|
|||
|
|
@ -228,7 +228,9 @@ impl HotFields {
|
|||
me_d2c_flush_batch_max_delay_us: cfg.general.me_d2c_flush_batch_max_delay_us,
|
||||
me_d2c_ack_flush_immediate: cfg.general.me_d2c_ack_flush_immediate,
|
||||
me_quota_soft_overshoot_bytes: cfg.general.me_quota_soft_overshoot_bytes,
|
||||
me_d2c_frame_buf_shrink_threshold_bytes: cfg.general.me_d2c_frame_buf_shrink_threshold_bytes,
|
||||
me_d2c_frame_buf_shrink_threshold_bytes: cfg
|
||||
.general
|
||||
.me_d2c_frame_buf_shrink_threshold_bytes,
|
||||
direct_relay_copy_buf_c2s_bytes: cfg.general.direct_relay_copy_buf_c2s_bytes,
|
||||
direct_relay_copy_buf_s2c_bytes: cfg.general.direct_relay_copy_buf_s2c_bytes,
|
||||
me_health_interval_ms_unhealthy: cfg.general.me_health_interval_ms_unhealthy,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
#![allow(deprecated)]
|
||||
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
use std::collections::{BTreeSet, HashMap, HashSet};
|
||||
use std::hash::{DefaultHasher, Hash, Hasher};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
|
@ -444,8 +444,7 @@ impl ProxyConfig {
|
|||
|
||||
if !(5..=50).contains(&config.censorship.mask_classifier_prefetch_timeout_ms) {
|
||||
return Err(ProxyError::Config(
|
||||
"censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"
|
||||
.to_string(),
|
||||
"censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
|
|
@ -558,7 +557,9 @@ impl ProxyConfig {
|
|||
));
|
||||
}
|
||||
|
||||
if !(4096..=16 * 1024 * 1024).contains(&config.general.me_d2c_frame_buf_shrink_threshold_bytes) {
|
||||
if !(4096..=16 * 1024 * 1024)
|
||||
.contains(&config.general.me_d2c_frame_buf_shrink_threshold_bytes)
|
||||
{
|
||||
return Err(ProxyError::Config(
|
||||
"general.me_d2c_frame_buf_shrink_threshold_bytes must be within [4096, 16777216]"
|
||||
.to_string(),
|
||||
|
|
@ -976,6 +977,28 @@ impl ProxyConfig {
|
|||
// Normalize optional TLS fetch scope: whitespace-only values disable scoped routing.
|
||||
config.censorship.tls_fetch_scope = config.censorship.tls_fetch_scope.trim().to_string();
|
||||
|
||||
if config.censorship.tls_fetch.profiles.is_empty() {
|
||||
config.censorship.tls_fetch.profiles = TlsFetchConfig::default().profiles;
|
||||
} else {
|
||||
let mut seen = HashSet::new();
|
||||
config
|
||||
.censorship
|
||||
.tls_fetch
|
||||
.profiles
|
||||
.retain(|profile| seen.insert(*profile));
|
||||
}
|
||||
|
||||
if config.censorship.tls_fetch.attempt_timeout_ms == 0 {
|
||||
return Err(ProxyError::Config(
|
||||
"censorship.tls_fetch.attempt_timeout_ms must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
if config.censorship.tls_fetch.total_budget_ms == 0 {
|
||||
return Err(ProxyError::Config(
|
||||
"censorship.tls_fetch.total_budget_ms must be > 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Merge primary + extra TLS domains, deduplicate (primary always first).
|
||||
if !config.censorship.tls_domains.is_empty() {
|
||||
let mut all = Vec::with_capacity(1 + config.censorship.tls_domains.len());
|
||||
|
|
@ -1262,6 +1285,11 @@ mod tests {
|
|||
assert_eq!(cfg.general.update_every, default_update_every());
|
||||
assert_eq!(cfg.server.listen_addr_ipv4, default_listen_addr_ipv4());
|
||||
assert_eq!(cfg.server.listen_addr_ipv6, default_listen_addr_ipv6_opt());
|
||||
assert_eq!(
|
||||
cfg.server.proxy_protocol_trusted_cidrs,
|
||||
default_proxy_protocol_trusted_cidrs()
|
||||
);
|
||||
assert_eq!(cfg.censorship.unknown_sni_action, UnknownSniAction::Drop);
|
||||
assert_eq!(cfg.server.api.listen, default_api_listen());
|
||||
assert_eq!(cfg.server.api.whitelist, default_api_whitelist());
|
||||
assert_eq!(
|
||||
|
|
@ -1394,6 +1422,14 @@ mod tests {
|
|||
|
||||
let server = ServerConfig::default();
|
||||
assert_eq!(server.listen_addr_ipv6, Some(default_listen_addr_ipv6()));
|
||||
assert_eq!(
|
||||
server.proxy_protocol_trusted_cidrs,
|
||||
default_proxy_protocol_trusted_cidrs()
|
||||
);
|
||||
assert_eq!(
|
||||
AntiCensorshipConfig::default().unknown_sni_action,
|
||||
UnknownSniAction::Drop
|
||||
);
|
||||
assert_eq!(server.api.listen, default_api_listen());
|
||||
assert_eq!(server.api.whitelist, default_api_whitelist());
|
||||
assert_eq!(
|
||||
|
|
@ -1429,6 +1465,75 @@ mod tests {
|
|||
assert_eq!(access.users, default_access_users());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn proxy_protocol_trusted_cidrs_missing_uses_trust_all_but_explicit_empty_stays_empty() {
|
||||
let cfg_missing: ProxyConfig = toml::from_str(
|
||||
r#"
|
||||
[server]
|
||||
[general]
|
||||
[network]
|
||||
[access]
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
cfg_missing.server.proxy_protocol_trusted_cidrs,
|
||||
default_proxy_protocol_trusted_cidrs()
|
||||
);
|
||||
|
||||
let cfg_explicit_empty: ProxyConfig = toml::from_str(
|
||||
r#"
|
||||
[server]
|
||||
proxy_protocol_trusted_cidrs = []
|
||||
|
||||
[general]
|
||||
[network]
|
||||
[access]
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
assert!(
|
||||
cfg_explicit_empty
|
||||
.server
|
||||
.proxy_protocol_trusted_cidrs
|
||||
.is_empty()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_sni_action_parses_and_defaults_to_drop() {
|
||||
let cfg_default: ProxyConfig = toml::from_str(
|
||||
r#"
|
||||
[server]
|
||||
[general]
|
||||
[network]
|
||||
[access]
|
||||
[censorship]
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
cfg_default.censorship.unknown_sni_action,
|
||||
UnknownSniAction::Drop
|
||||
);
|
||||
|
||||
let cfg_mask: ProxyConfig = toml::from_str(
|
||||
r#"
|
||||
[server]
|
||||
[general]
|
||||
[network]
|
||||
[access]
|
||||
[censorship]
|
||||
unknown_sni_action = "mask"
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
cfg_mask.censorship.unknown_sni_action,
|
||||
UnknownSniAction::Mask
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dc_overrides_allow_string_and_array() {
|
||||
let toml = r#"
|
||||
|
|
@ -2376,6 +2481,94 @@ mod tests {
|
|||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tls_fetch_defaults_are_applied() {
|
||||
let toml = r#"
|
||||
[censorship]
|
||||
tls_domain = "example.com"
|
||||
|
||||
[access.users]
|
||||
user = "00000000000000000000000000000000"
|
||||
"#;
|
||||
let dir = std::env::temp_dir();
|
||||
let path = dir.join("telemt_tls_fetch_defaults_test.toml");
|
||||
std::fs::write(&path, toml).unwrap();
|
||||
let cfg = ProxyConfig::load(&path).unwrap();
|
||||
assert_eq!(
|
||||
cfg.censorship.tls_fetch.profiles,
|
||||
TlsFetchConfig::default().profiles
|
||||
);
|
||||
assert!(cfg.censorship.tls_fetch.strict_route);
|
||||
assert_eq!(cfg.censorship.tls_fetch.attempt_timeout_ms, 5_000);
|
||||
assert_eq!(cfg.censorship.tls_fetch.total_budget_ms, 15_000);
|
||||
assert_eq!(cfg.censorship.tls_fetch.profile_cache_ttl_secs, 600);
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tls_fetch_profiles_are_deduplicated_preserving_order() {
|
||||
let toml = r#"
|
||||
[censorship]
|
||||
tls_domain = "example.com"
|
||||
[censorship.tls_fetch]
|
||||
profiles = ["compat_tls12", "modern_chrome_like", "compat_tls12", "legacy_minimal"]
|
||||
|
||||
[access.users]
|
||||
user = "00000000000000000000000000000000"
|
||||
"#;
|
||||
let dir = std::env::temp_dir();
|
||||
let path = dir.join("telemt_tls_fetch_profiles_dedup_test.toml");
|
||||
std::fs::write(&path, toml).unwrap();
|
||||
let cfg = ProxyConfig::load(&path).unwrap();
|
||||
assert_eq!(
|
||||
cfg.censorship.tls_fetch.profiles,
|
||||
vec![
|
||||
TlsFetchProfile::CompatTls12,
|
||||
TlsFetchProfile::ModernChromeLike,
|
||||
TlsFetchProfile::LegacyMinimal
|
||||
]
|
||||
);
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tls_fetch_attempt_timeout_zero_is_rejected() {
|
||||
let toml = r#"
|
||||
[censorship]
|
||||
tls_domain = "example.com"
|
||||
[censorship.tls_fetch]
|
||||
attempt_timeout_ms = 0
|
||||
|
||||
[access.users]
|
||||
user = "00000000000000000000000000000000"
|
||||
"#;
|
||||
let dir = std::env::temp_dir();
|
||||
let path = dir.join("telemt_tls_fetch_attempt_timeout_zero_test.toml");
|
||||
std::fs::write(&path, toml).unwrap();
|
||||
let err = ProxyConfig::load(&path).unwrap_err().to_string();
|
||||
assert!(err.contains("censorship.tls_fetch.attempt_timeout_ms must be > 0"));
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tls_fetch_total_budget_zero_is_rejected() {
|
||||
let toml = r#"
|
||||
[censorship]
|
||||
tls_domain = "example.com"
|
||||
[censorship.tls_fetch]
|
||||
total_budget_ms = 0
|
||||
|
||||
[access.users]
|
||||
user = "00000000000000000000000000000000"
|
||||
"#;
|
||||
let dir = std::env::temp_dir();
|
||||
let path = dir.join("telemt_tls_fetch_total_budget_zero_test.toml");
|
||||
std::fs::write(&path, toml).unwrap();
|
||||
let err = ProxyConfig::load(&path).unwrap_err().to_string();
|
||||
assert!(err.contains("censorship.tls_fetch.total_budget_ms must be > 0"));
|
||||
let _ = std::fs::remove_file(path);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_ad_tag_is_disabled_during_load() {
|
||||
let toml = r#"
|
||||
|
|
|
|||
|
|
@ -8,8 +8,9 @@ fn write_temp_config(contents: &str) -> PathBuf {
|
|||
.duration_since(UNIX_EPOCH)
|
||||
.expect("system time must be after unix epoch")
|
||||
.as_nanos();
|
||||
let path = std::env::temp_dir()
|
||||
.join(format!("telemt-load-mask-prefetch-timeout-security-{nonce}.toml"));
|
||||
let path = std::env::temp_dir().join(format!(
|
||||
"telemt-load-mask-prefetch-timeout-security-{nonce}.toml"
|
||||
));
|
||||
fs::write(&path, contents).expect("temp config write must succeed");
|
||||
path
|
||||
}
|
||||
|
|
@ -67,8 +68,8 @@ mask_classifier_prefetch_timeout_ms = 20
|
|||
"#,
|
||||
);
|
||||
|
||||
let cfg = ProxyConfig::load(&path)
|
||||
.expect("prefetch timeout within security bounds must be accepted");
|
||||
let cfg =
|
||||
ProxyConfig::load(&path).expect("prefetch timeout within security bounds must be accepted");
|
||||
assert_eq!(cfg.censorship.mask_classifier_prefetch_timeout_ms, 20);
|
||||
|
||||
remove_temp_config(&path);
|
||||
|
|
|
|||
|
|
@ -265,8 +265,8 @@ mask_relay_max_bytes = 67108865
|
|||
"#,
|
||||
);
|
||||
|
||||
let err = ProxyConfig::load(&path)
|
||||
.expect_err("mask_relay_max_bytes above hard cap must be rejected");
|
||||
let err =
|
||||
ProxyConfig::load(&path).expect_err("mask_relay_max_bytes above hard cap must be rejected");
|
||||
let msg = err.to_string();
|
||||
assert!(
|
||||
msg.contains("censorship.mask_relay_max_bytes must be <= 67108864"),
|
||||
|
|
|
|||
|
|
@ -954,7 +954,8 @@ impl Default for GeneralConfig {
|
|||
me_d2c_flush_batch_max_delay_us: default_me_d2c_flush_batch_max_delay_us(),
|
||||
me_d2c_ack_flush_immediate: default_me_d2c_ack_flush_immediate(),
|
||||
me_quota_soft_overshoot_bytes: default_me_quota_soft_overshoot_bytes(),
|
||||
me_d2c_frame_buf_shrink_threshold_bytes: default_me_d2c_frame_buf_shrink_threshold_bytes(),
|
||||
me_d2c_frame_buf_shrink_threshold_bytes:
|
||||
default_me_d2c_frame_buf_shrink_threshold_bytes(),
|
||||
direct_relay_copy_buf_c2s_bytes: default_direct_relay_copy_buf_c2s_bytes(),
|
||||
direct_relay_copy_buf_s2c_bytes: default_direct_relay_copy_buf_s2c_bytes(),
|
||||
me_warmup_stagger_enabled: default_true(),
|
||||
|
|
@ -1239,9 +1240,10 @@ pub struct ServerConfig {
|
|||
|
||||
/// Trusted source CIDRs allowed to send incoming PROXY protocol headers.
|
||||
///
|
||||
/// When non-empty, connections from addresses outside this allowlist are
|
||||
/// rejected before `src_addr` is applied.
|
||||
#[serde(default)]
|
||||
/// If this field is omitted in config, it defaults to trust-all CIDRs
|
||||
/// (`0.0.0.0/0` and `::/0`). If it is explicitly set to an empty list,
|
||||
/// all PROXY protocol headers are rejected.
|
||||
#[serde(default = "default_proxy_protocol_trusted_cidrs")]
|
||||
pub proxy_protocol_trusted_cidrs: Vec<IpNetwork>,
|
||||
|
||||
/// Port for the Prometheus-compatible metrics endpoint.
|
||||
|
|
@ -1286,7 +1288,7 @@ impl Default for ServerConfig {
|
|||
listen_tcp: None,
|
||||
proxy_protocol: false,
|
||||
proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(),
|
||||
proxy_protocol_trusted_cidrs: Vec::new(),
|
||||
proxy_protocol_trusted_cidrs: default_proxy_protocol_trusted_cidrs(),
|
||||
metrics_port: None,
|
||||
metrics_listen: None,
|
||||
metrics_whitelist: default_metrics_whitelist(),
|
||||
|
|
@ -1357,6 +1359,90 @@ impl Default for TimeoutsConfig {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum UnknownSniAction {
|
||||
#[default]
|
||||
Drop,
|
||||
Mask,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TlsFetchProfile {
|
||||
ModernChromeLike,
|
||||
ModernFirefoxLike,
|
||||
CompatTls12,
|
||||
LegacyMinimal,
|
||||
}
|
||||
|
||||
impl TlsFetchProfile {
|
||||
pub fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
TlsFetchProfile::ModernChromeLike => "modern_chrome_like",
|
||||
TlsFetchProfile::ModernFirefoxLike => "modern_firefox_like",
|
||||
TlsFetchProfile::CompatTls12 => "compat_tls12",
|
||||
TlsFetchProfile::LegacyMinimal => "legacy_minimal",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_tls_fetch_profiles() -> Vec<TlsFetchProfile> {
|
||||
vec![
|
||||
TlsFetchProfile::ModernChromeLike,
|
||||
TlsFetchProfile::ModernFirefoxLike,
|
||||
TlsFetchProfile::CompatTls12,
|
||||
TlsFetchProfile::LegacyMinimal,
|
||||
]
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TlsFetchConfig {
|
||||
/// Ordered list of ClientHello profiles used for adaptive fallback.
|
||||
#[serde(default = "default_tls_fetch_profiles")]
|
||||
pub profiles: Vec<TlsFetchProfile>,
|
||||
|
||||
/// When true and upstream route is configured, TLS fetch fails closed on
|
||||
/// upstream connect errors and does not fallback to direct TCP.
|
||||
#[serde(default = "default_tls_fetch_strict_route")]
|
||||
pub strict_route: bool,
|
||||
|
||||
/// Timeout per one profile attempt in milliseconds.
|
||||
#[serde(default = "default_tls_fetch_attempt_timeout_ms")]
|
||||
pub attempt_timeout_ms: u64,
|
||||
|
||||
/// Total wall-clock budget in milliseconds across all profile attempts.
|
||||
#[serde(default = "default_tls_fetch_total_budget_ms")]
|
||||
pub total_budget_ms: u64,
|
||||
|
||||
/// Adds GREASE-style values into selected ClientHello extensions.
|
||||
#[serde(default)]
|
||||
pub grease_enabled: bool,
|
||||
|
||||
/// Produces deterministic ClientHello randomness for debugging/tests.
|
||||
#[serde(default)]
|
||||
pub deterministic: bool,
|
||||
|
||||
/// TTL for winner-profile cache entries in seconds.
|
||||
/// Set to 0 to disable profile cache.
|
||||
#[serde(default = "default_tls_fetch_profile_cache_ttl_secs")]
|
||||
pub profile_cache_ttl_secs: u64,
|
||||
}
|
||||
|
||||
impl Default for TlsFetchConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
profiles: default_tls_fetch_profiles(),
|
||||
strict_route: default_tls_fetch_strict_route(),
|
||||
attempt_timeout_ms: default_tls_fetch_attempt_timeout_ms(),
|
||||
total_budget_ms: default_tls_fetch_total_budget_ms(),
|
||||
grease_enabled: false,
|
||||
deterministic: false,
|
||||
profile_cache_ttl_secs: default_tls_fetch_profile_cache_ttl_secs(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AntiCensorshipConfig {
|
||||
#[serde(default = "default_tls_domain")]
|
||||
|
|
@ -1366,11 +1452,19 @@ pub struct AntiCensorshipConfig {
|
|||
#[serde(default)]
|
||||
pub tls_domains: Vec<String>,
|
||||
|
||||
/// Policy for TLS ClientHello with unknown (non-configured) SNI.
|
||||
#[serde(default)]
|
||||
pub unknown_sni_action: UnknownSniAction,
|
||||
|
||||
/// Upstream scope used for TLS front metadata fetches.
|
||||
/// Empty value keeps default upstream routing behavior.
|
||||
#[serde(default = "default_tls_fetch_scope")]
|
||||
pub tls_fetch_scope: String,
|
||||
|
||||
/// Fetch strategy for TLS front metadata bootstrap and periodic refresh.
|
||||
#[serde(default)]
|
||||
pub tls_fetch: TlsFetchConfig,
|
||||
|
||||
#[serde(default = "default_true")]
|
||||
pub mask: bool,
|
||||
|
||||
|
|
@ -1476,7 +1570,9 @@ impl Default for AntiCensorshipConfig {
|
|||
Self {
|
||||
tls_domain: default_tls_domain(),
|
||||
tls_domains: Vec::new(),
|
||||
unknown_sni_action: UnknownSniAction::Drop,
|
||||
tls_fetch_scope: default_tls_fetch_scope(),
|
||||
tls_fetch: TlsFetchConfig::default(),
|
||||
mask: default_true(),
|
||||
mask_host: None,
|
||||
mask_port: default_mask_port(),
|
||||
|
|
|
|||
|
|
@ -216,6 +216,9 @@ pub enum ProxyError {
|
|||
#[error("Invalid proxy protocol header")]
|
||||
InvalidProxyProtocol,
|
||||
|
||||
#[error("Unknown TLS SNI")]
|
||||
UnknownTlsSni,
|
||||
|
||||
#[error("Proxy error: {0}")]
|
||||
Proxy(String),
|
||||
|
||||
|
|
|
|||
|
|
@ -8,8 +8,10 @@ use tracing::{debug, error, info, warn};
|
|||
|
||||
use crate::cli;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::transport::UpstreamManager;
|
||||
use crate::transport::middle_proxy::{
|
||||
ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache,
|
||||
ProxyConfigData, fetch_proxy_config_with_raw_via_upstream, load_proxy_config_cache,
|
||||
save_proxy_config_cache,
|
||||
};
|
||||
|
||||
pub(crate) fn resolve_runtime_config_path(
|
||||
|
|
@ -288,9 +290,10 @@ pub(crate) async fn load_startup_proxy_config_snapshot(
|
|||
cache_path: Option<&str>,
|
||||
me2dc_fallback: bool,
|
||||
label: &'static str,
|
||||
upstream: Option<std::sync::Arc<UpstreamManager>>,
|
||||
) -> Option<ProxyConfigData> {
|
||||
loop {
|
||||
match fetch_proxy_config_with_raw(url).await {
|
||||
match fetch_proxy_config_with_raw_via_upstream(url, upstream.clone()).await {
|
||||
Ok((cfg, raw)) => {
|
||||
if !cfg.map.is_empty() {
|
||||
if let Some(path) = cache_path
|
||||
|
|
|
|||
|
|
@ -63,9 +63,10 @@ pub(crate) async fn initialize_me_pool(
|
|||
let proxy_secret_path = config.general.proxy_secret_path.as_deref();
|
||||
let pool_size = config.general.middle_proxy_pool_size.max(1);
|
||||
let proxy_secret = loop {
|
||||
match crate::transport::middle_proxy::fetch_proxy_secret(
|
||||
match crate::transport::middle_proxy::fetch_proxy_secret_with_upstream(
|
||||
proxy_secret_path,
|
||||
config.general.proxy_secret_len_max,
|
||||
Some(upstream_manager.clone()),
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
|
@ -129,6 +130,7 @@ pub(crate) async fn initialize_me_pool(
|
|||
config.general.proxy_config_v4_cache_path.as_deref(),
|
||||
me2dc_fallback,
|
||||
"getProxyConfig",
|
||||
Some(upstream_manager.clone()),
|
||||
)
|
||||
.await;
|
||||
if cfg_v4.is_some() {
|
||||
|
|
@ -160,6 +162,7 @@ pub(crate) async fn initialize_me_pool(
|
|||
config.general.proxy_config_v6_cache_path.as_deref(),
|
||||
me2dc_fallback,
|
||||
"getProxyConfigV6",
|
||||
Some(upstream_manager.clone()),
|
||||
)
|
||||
.await;
|
||||
if cfg_v6.is_some() {
|
||||
|
|
|
|||
|
|
@ -32,14 +32,6 @@ pub(crate) struct RuntimeWatches {
|
|||
pub(crate) detected_ip_v6: Option<IpAddr>,
|
||||
}
|
||||
|
||||
const QUOTA_USER_LOCK_EVICT_INTERVAL_SECS: u64 = 60;
|
||||
|
||||
fn spawn_quota_lock_maintenance_task() -> tokio::task::JoinHandle<()> {
|
||||
crate::proxy::relay::spawn_quota_user_lock_evictor(std::time::Duration::from_secs(
|
||||
QUOTA_USER_LOCK_EVICT_INTERVAL_SECS,
|
||||
))
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn spawn_runtime_tasks(
|
||||
config: &Arc<ProxyConfig>,
|
||||
|
|
@ -77,8 +69,6 @@ pub(crate) async fn spawn_runtime_tasks(
|
|||
rc_clone.run_periodic_cleanup().await;
|
||||
});
|
||||
|
||||
spawn_quota_lock_maintenance_task();
|
||||
|
||||
let detected_ip_v4: Option<IpAddr> = probe.detected_ipv4.map(IpAddr::V4);
|
||||
let detected_ip_v6: Option<IpAddr> = probe.detected_ipv6.map(IpAddr::V6);
|
||||
debug!(
|
||||
|
|
@ -370,24 +360,3 @@ pub(crate) async fn mark_runtime_ready(startup_tracker: &Arc<StartupTracker>) {
|
|||
.await;
|
||||
startup_tracker.mark_ready().await;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn tdd_runtime_quota_lock_maintenance_path_spawns_single_evictor_task() {
|
||||
crate::proxy::relay::reset_quota_user_lock_evictor_spawn_count_for_tests();
|
||||
|
||||
let handle = spawn_quota_lock_maintenance_task();
|
||||
tokio::time::sleep(std::time::Duration::from_millis(5)).await;
|
||||
|
||||
assert_eq!(
|
||||
crate::proxy::relay::quota_user_lock_evictor_spawn_count_for_tests(),
|
||||
1,
|
||||
"runtime maintenance path must spawn exactly one quota lock evictor task per call"
|
||||
);
|
||||
|
||||
handle.abort();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ use tracing::warn;
|
|||
use crate::config::ProxyConfig;
|
||||
use crate::startup::{COMPONENT_TLS_FRONT_BOOTSTRAP, StartupTracker};
|
||||
use crate::tls_front::TlsFrontCache;
|
||||
use crate::tls_front::fetcher::TlsFetchStrategy;
|
||||
use crate::transport::UpstreamManager;
|
||||
|
||||
pub(crate) async fn bootstrap_tls_front(
|
||||
|
|
@ -40,7 +41,17 @@ pub(crate) async fn bootstrap_tls_front(
|
|||
let mask_unix_sock = config.censorship.mask_unix_sock.clone();
|
||||
let tls_fetch_scope = (!config.censorship.tls_fetch_scope.is_empty())
|
||||
.then(|| config.censorship.tls_fetch_scope.clone());
|
||||
let fetch_timeout = Duration::from_secs(5);
|
||||
let tls_fetch = config.censorship.tls_fetch.clone();
|
||||
let fetch_strategy = TlsFetchStrategy {
|
||||
profiles: tls_fetch.profiles,
|
||||
strict_route: tls_fetch.strict_route,
|
||||
attempt_timeout: Duration::from_millis(tls_fetch.attempt_timeout_ms.max(1)),
|
||||
total_budget: Duration::from_millis(tls_fetch.total_budget_ms.max(1)),
|
||||
grease_enabled: tls_fetch.grease_enabled,
|
||||
deterministic: tls_fetch.deterministic,
|
||||
profile_cache_ttl: Duration::from_secs(tls_fetch.profile_cache_ttl_secs),
|
||||
};
|
||||
let fetch_timeout = fetch_strategy.total_budget;
|
||||
|
||||
let cache_initial = cache.clone();
|
||||
let domains_initial = tls_domains.to_vec();
|
||||
|
|
@ -48,6 +59,7 @@ pub(crate) async fn bootstrap_tls_front(
|
|||
let unix_sock_initial = mask_unix_sock.clone();
|
||||
let scope_initial = tls_fetch_scope.clone();
|
||||
let upstream_initial = upstream_manager.clone();
|
||||
let strategy_initial = fetch_strategy.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut join = tokio::task::JoinSet::new();
|
||||
for domain in domains_initial {
|
||||
|
|
@ -56,12 +68,13 @@ pub(crate) async fn bootstrap_tls_front(
|
|||
let unix_sock_domain = unix_sock_initial.clone();
|
||||
let scope_domain = scope_initial.clone();
|
||||
let upstream_domain = upstream_initial.clone();
|
||||
let strategy_domain = strategy_initial.clone();
|
||||
join.spawn(async move {
|
||||
match crate::tls_front::fetcher::fetch_real_tls(
|
||||
match crate::tls_front::fetcher::fetch_real_tls_with_strategy(
|
||||
&host_domain,
|
||||
port,
|
||||
&domain,
|
||||
fetch_timeout,
|
||||
&strategy_domain,
|
||||
Some(upstream_domain),
|
||||
scope_domain.as_deref(),
|
||||
proxy_protocol,
|
||||
|
|
@ -107,6 +120,7 @@ pub(crate) async fn bootstrap_tls_front(
|
|||
let unix_sock_refresh = mask_unix_sock.clone();
|
||||
let scope_refresh = tls_fetch_scope.clone();
|
||||
let upstream_refresh = upstream_manager.clone();
|
||||
let strategy_refresh = fetch_strategy.clone();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let base_secs = rand::rng().random_range(4 * 3600..=6 * 3600);
|
||||
|
|
@ -120,12 +134,13 @@ pub(crate) async fn bootstrap_tls_front(
|
|||
let unix_sock_domain = unix_sock_refresh.clone();
|
||||
let scope_domain = scope_refresh.clone();
|
||||
let upstream_domain = upstream_refresh.clone();
|
||||
let strategy_domain = strategy_refresh.clone();
|
||||
join.spawn(async move {
|
||||
match crate::tls_front::fetcher::fetch_real_tls(
|
||||
match crate::tls_front::fetcher::fetch_real_tls_with_strategy(
|
||||
&host_domain,
|
||||
port,
|
||||
&domain,
|
||||
fetch_timeout,
|
||||
&strategy_domain,
|
||||
Some(upstream_domain),
|
||||
scope_domain.as_deref(),
|
||||
proxy_protocol,
|
||||
|
|
|
|||
|
|
@ -7,12 +7,12 @@ mod crypto;
|
|||
mod error;
|
||||
mod ip_tracker;
|
||||
#[cfg(test)]
|
||||
#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"]
|
||||
mod ip_tracker_hotpath_adversarial_tests;
|
||||
#[cfg(test)]
|
||||
#[path = "tests/ip_tracker_encapsulation_adversarial_tests.rs"]
|
||||
mod ip_tracker_encapsulation_adversarial_tests;
|
||||
#[cfg(test)]
|
||||
#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"]
|
||||
mod ip_tracker_hotpath_adversarial_tests;
|
||||
#[cfg(test)]
|
||||
#[path = "tests/ip_tracker_regression_tests.rs"]
|
||||
mod ip_tracker_regression_tests;
|
||||
mod maestro;
|
||||
|
|
@ -29,5 +29,6 @@ mod util;
|
|||
|
||||
#[tokio::main]
|
||||
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = rustls::crypto::ring::default_provider().install_default();
|
||||
maestro::run().await
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1233,10 +1233,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
|
|||
out,
|
||||
"# HELP telemt_me_d2c_batch_bytes_bucket_total DC->Client batch byte size buckets"
|
||||
);
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"# TYPE telemt_me_d2c_batch_bytes_bucket_total counter"
|
||||
);
|
||||
let _ = writeln!(out, "# TYPE telemt_me_d2c_batch_bytes_bucket_total counter");
|
||||
let _ = writeln!(
|
||||
out,
|
||||
"telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"0_1k\"}} {}",
|
||||
|
|
|
|||
|
|
@ -210,7 +210,9 @@ fn should_prefetch_mask_classifier_window(initial_data: &[u8]) -> bool {
|
|||
return false;
|
||||
}
|
||||
|
||||
initial_data.iter().all(|b| b.is_ascii_alphabetic() || *b == b' ')
|
||||
initial_data
|
||||
.iter()
|
||||
.all(|b| b.is_ascii_alphabetic() || *b == b' ')
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -218,16 +220,19 @@ async fn extend_masking_initial_window<R>(reader: &mut R, initial_data: &mut Vec
|
|||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
extend_masking_initial_window_with_timeout(reader, initial_data, MASK_CLASSIFIER_PREFETCH_TIMEOUT)
|
||||
.await;
|
||||
extend_masking_initial_window_with_timeout(
|
||||
reader,
|
||||
initial_data,
|
||||
MASK_CLASSIFIER_PREFETCH_TIMEOUT,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn extend_masking_initial_window_with_timeout<R>(
|
||||
reader: &mut R,
|
||||
initial_data: &mut Vec<u8>,
|
||||
prefetch_timeout: Duration,
|
||||
)
|
||||
where
|
||||
) where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
if !should_prefetch_mask_classifier_window(initial_data) {
|
||||
|
|
@ -312,13 +317,20 @@ fn record_handshake_failure_class(
|
|||
record_beobachten_class(beobachten, config, peer_ip, class);
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn increment_bad_on_unknown_tls_sni(stats: &Stats, error: &ProxyError) {
|
||||
if matches!(error, ProxyError::UnknownTlsSni) {
|
||||
stats.increment_connects_bad();
|
||||
}
|
||||
}
|
||||
|
||||
fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool {
|
||||
if trusted.is_empty() {
|
||||
static EMPTY_PROXY_TRUST_WARNED: OnceLock<AtomicBool> = OnceLock::new();
|
||||
let warned = EMPTY_PROXY_TRUST_WARNED.get_or_init(|| AtomicBool::new(false));
|
||||
if !warned.swap(true, Ordering::Relaxed) {
|
||||
warn!(
|
||||
"PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers by default"
|
||||
"PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers"
|
||||
);
|
||||
}
|
||||
return false;
|
||||
|
|
@ -503,7 +515,10 @@ where
|
|||
beobachten.clone(),
|
||||
));
|
||||
}
|
||||
HandshakeResult::Error(e) => return Err(e),
|
||||
HandshakeResult::Error(e) => {
|
||||
increment_bad_on_unknown_tls_sni(stats.as_ref(), &e);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
debug!(peer = %peer, "Reading MTProto handshake through TLS");
|
||||
|
|
@ -954,7 +969,10 @@ impl RunningClientHandler {
|
|||
self.beobachten.clone(),
|
||||
));
|
||||
}
|
||||
HandshakeResult::Error(e) => return Err(e),
|
||||
HandshakeResult::Error(e) => {
|
||||
increment_bad_on_unknown_tls_sni(stats.as_ref(), &e);
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
debug!(peer = %peer, "Reading MTProto handshake through TLS");
|
||||
|
|
@ -1223,7 +1241,7 @@ impl RunningClientHandler {
|
|||
}
|
||||
|
||||
if let Some(quota) = config.access.user_data_quota.get(user)
|
||||
&& stats.get_user_total_octets(user) >= *quota
|
||||
&& stats.get_user_quota_used(user) >= *quota
|
||||
{
|
||||
return Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.to_string(),
|
||||
|
|
@ -1282,7 +1300,7 @@ impl RunningClientHandler {
|
|||
}
|
||||
|
||||
if let Some(quota) = config.access.user_data_quota.get(user)
|
||||
&& stats.get_user_total_octets(user) >= *quota
|
||||
&& stats.get_user_quota_used(user) >= *quota
|
||||
{
|
||||
return Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.to_string(),
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
|||
use tracing::{debug, trace, warn};
|
||||
use zeroize::{Zeroize, Zeroizing};
|
||||
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::config::{ProxyConfig, UnknownSniAction};
|
||||
use crate::crypto::{AesCtr, SecureRandom, sha256};
|
||||
use crate::error::{HandshakeResult, ProxyError};
|
||||
use crate::protocol::constants::*;
|
||||
|
|
@ -282,30 +282,9 @@ fn auth_probe_record_failure_with_state(
|
|||
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
|
||||
let state_len = state.len();
|
||||
let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT);
|
||||
let start_offset = auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit);
|
||||
|
||||
let mut scanned = 0usize;
|
||||
for entry in state.iter().skip(start_offset) {
|
||||
let key = *entry.key();
|
||||
let fail_streak = entry.value().fail_streak;
|
||||
let last_seen = entry.value().last_seen;
|
||||
match eviction_candidate {
|
||||
Some((_, current_fail, current_seen))
|
||||
if fail_streak > current_fail
|
||||
|| (fail_streak == current_fail && last_seen >= current_seen) => {}
|
||||
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
|
||||
}
|
||||
if auth_probe_state_expired(entry.value(), now) {
|
||||
stale_keys.push(key);
|
||||
}
|
||||
scanned += 1;
|
||||
if scanned >= scan_limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if scanned < scan_limit {
|
||||
for entry in state.iter().take(scan_limit - scanned) {
|
||||
if state_len <= AUTH_PROBE_PRUNE_SCAN_LIMIT {
|
||||
for entry in state.iter() {
|
||||
let key = *entry.key();
|
||||
let fail_streak = entry.value().fail_streak;
|
||||
let last_seen = entry.value().last_seen;
|
||||
|
|
@ -319,6 +298,46 @@ fn auth_probe_record_failure_with_state(
|
|||
stale_keys.push(key);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let start_offset =
|
||||
auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit);
|
||||
let mut scanned = 0usize;
|
||||
for entry in state.iter().skip(start_offset) {
|
||||
let key = *entry.key();
|
||||
let fail_streak = entry.value().fail_streak;
|
||||
let last_seen = entry.value().last_seen;
|
||||
match eviction_candidate {
|
||||
Some((_, current_fail, current_seen))
|
||||
if fail_streak > current_fail
|
||||
|| (fail_streak == current_fail && last_seen >= current_seen) => {}
|
||||
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
|
||||
}
|
||||
if auth_probe_state_expired(entry.value(), now) {
|
||||
stale_keys.push(key);
|
||||
}
|
||||
scanned += 1;
|
||||
if scanned >= scan_limit {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if scanned < scan_limit {
|
||||
for entry in state.iter().take(scan_limit - scanned) {
|
||||
let key = *entry.key();
|
||||
let fail_streak = entry.value().fail_streak;
|
||||
let last_seen = entry.value().last_seen;
|
||||
match eviction_candidate {
|
||||
Some((_, current_fail, current_seen))
|
||||
if fail_streak > current_fail
|
||||
|| (fail_streak == current_fail
|
||||
&& last_seen >= current_seen) => {}
|
||||
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
|
||||
}
|
||||
if auth_probe_state_expired(entry.value(), now) {
|
||||
stale_keys.push(key);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for stale_key in stale_keys {
|
||||
|
|
@ -510,6 +529,21 @@ fn decode_user_secrets(
|
|||
secrets
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn find_matching_tls_domain<'a>(config: &'a ProxyConfig, sni: &str) -> Option<&'a str> {
|
||||
if config.censorship.tls_domain.eq_ignore_ascii_case(sni) {
|
||||
return Some(config.censorship.tls_domain.as_str());
|
||||
}
|
||||
|
||||
for domain in &config.censorship.tls_domains {
|
||||
if domain.eq_ignore_ascii_case(sni) {
|
||||
return Some(domain.as_str());
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
async fn maybe_apply_server_hello_delay(config: &ProxyConfig) {
|
||||
if config.censorship.server_hello_delay_max_ms == 0 {
|
||||
return;
|
||||
|
|
@ -593,60 +627,12 @@ where
|
|||
}
|
||||
|
||||
let client_sni = tls::extract_sni_from_client_hello(handshake);
|
||||
let secrets = decode_user_secrets(config, client_sni.as_deref());
|
||||
|
||||
let validation = match tls::validate_tls_handshake_with_replay_window(
|
||||
handshake,
|
||||
&secrets,
|
||||
config.access.ignore_time_skew,
|
||||
config.access.replay_window_secs,
|
||||
) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
debug!(
|
||||
peer = %peer,
|
||||
ignore_time_skew = config.access.ignore_time_skew,
|
||||
"TLS handshake validation failed - no matching user or time skew"
|
||||
);
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
};
|
||||
|
||||
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
|
||||
Some((_, s)) => s,
|
||||
None => {
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
};
|
||||
|
||||
let cached = if config.censorship.tls_emulation {
|
||||
if let Some(cache) = tls_cache.as_ref() {
|
||||
let selected_domain = if let Some(sni) = client_sni.as_ref() {
|
||||
if cache.contains_domain(sni).await {
|
||||
sni.clone()
|
||||
} else {
|
||||
config.censorship.tls_domain.clone()
|
||||
}
|
||||
} else {
|
||||
config.censorship.tls_domain.clone()
|
||||
};
|
||||
let cached_entry = cache.get(&selected_domain).await;
|
||||
let use_full_cert_payload = cache
|
||||
.take_full_cert_budget_for_ip(
|
||||
peer.ip(),
|
||||
Duration::from_secs(config.censorship.tls_full_cert_ttl_secs),
|
||||
)
|
||||
.await;
|
||||
Some((cached_entry, use_full_cert_payload))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let preferred_user_hint = client_sni
|
||||
.as_deref()
|
||||
.filter(|sni| config.access.users.contains_key(*sni));
|
||||
let matched_tls_domain = client_sni
|
||||
.as_deref()
|
||||
.and_then(|sni| find_matching_tls_domain(config, sni));
|
||||
|
||||
let alpn_list = if config.censorship.alpn_enforce {
|
||||
tls::extract_alpn_from_client_hello(handshake)
|
||||
|
|
@ -669,16 +655,81 @@ where
|
|||
None
|
||||
};
|
||||
|
||||
// Replay tracking is applied only after full policy validation (including
|
||||
// ALPN checks) so rejected handshakes cannot poison replay state.
|
||||
if client_sni.is_some() && matched_tls_domain.is_none() && preferred_user_hint.is_none() {
|
||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
debug!(
|
||||
peer = %peer,
|
||||
sni = ?client_sni,
|
||||
action = ?config.censorship.unknown_sni_action,
|
||||
"TLS handshake rejected by unknown SNI policy"
|
||||
);
|
||||
return match config.censorship.unknown_sni_action {
|
||||
UnknownSniAction::Drop => HandshakeResult::Error(ProxyError::UnknownTlsSni),
|
||||
UnknownSniAction::Mask => HandshakeResult::BadClient { reader, writer },
|
||||
};
|
||||
}
|
||||
|
||||
let secrets = decode_user_secrets(config, preferred_user_hint);
|
||||
|
||||
let validation = match tls::validate_tls_handshake_with_replay_window(
|
||||
handshake,
|
||||
&secrets,
|
||||
config.access.ignore_time_skew,
|
||||
config.access.replay_window_secs,
|
||||
) {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
debug!(
|
||||
peer = %peer,
|
||||
ignore_time_skew = config.access.ignore_time_skew,
|
||||
"TLS handshake validation failed - no matching user or time skew"
|
||||
);
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
};
|
||||
|
||||
// Reject known replay digests before expensive cache/domain/ALPN policy work.
|
||||
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
|
||||
if replay_checker.check_and_add_tls_digest(digest_half) {
|
||||
if replay_checker.check_tls_digest(digest_half) {
|
||||
auth_probe_record_failure(peer.ip(), Instant::now());
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
|
||||
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
|
||||
Some((_, s)) => s,
|
||||
None => {
|
||||
maybe_apply_server_hello_delay(config).await;
|
||||
return HandshakeResult::BadClient { reader, writer };
|
||||
}
|
||||
};
|
||||
|
||||
let cached = if config.censorship.tls_emulation {
|
||||
if let Some(cache) = tls_cache.as_ref() {
|
||||
let selected_domain =
|
||||
matched_tls_domain.unwrap_or(config.censorship.tls_domain.as_str());
|
||||
let cached_entry = cache.get(selected_domain).await;
|
||||
let use_full_cert_payload = cache
|
||||
.take_full_cert_budget_for_ip(
|
||||
peer.ip(),
|
||||
Duration::from_secs(config.censorship.tls_full_cert_ttl_secs),
|
||||
)
|
||||
.await;
|
||||
Some((cached_entry, use_full_cert_payload))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Add replay digest only for policy-valid handshakes.
|
||||
replay_checker.add_tls_digest(digest_half);
|
||||
|
||||
let response = if let Some((cached_entry, use_full_cert_payload)) = cached {
|
||||
emulator::build_emulated_server_hello(
|
||||
secret,
|
||||
|
|
|
|||
|
|
@ -10,10 +10,10 @@ use rand::rngs::StdRng;
|
|||
use rand::{Rng, RngExt, SeedableRng};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::str;
|
||||
#[cfg(unix)]
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
#[cfg(test)]
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
#[cfg(unix)]
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
use std::time::{Duration, Instant as StdInstant};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
|
|
@ -60,7 +60,7 @@ where
|
|||
R: AsyncRead + Unpin,
|
||||
W: AsyncWrite + Unpin,
|
||||
{
|
||||
let mut buf = [0u8; MASK_BUFFER_SIZE];
|
||||
let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]);
|
||||
let mut total = 0usize;
|
||||
let mut ended_by_eof = false;
|
||||
|
||||
|
|
@ -107,15 +107,7 @@ where
|
|||
fn is_http_probe(data: &[u8]) -> bool {
|
||||
// RFC 7540 section 3.5: HTTP/2 client preface starts with "PRI ".
|
||||
const HTTP_METHODS: [&[u8]; 10] = [
|
||||
b"GET ",
|
||||
b"POST",
|
||||
b"HEAD",
|
||||
b"PUT ",
|
||||
b"DELETE",
|
||||
b"OPTIONS",
|
||||
b"CONNECT",
|
||||
b"TRACE",
|
||||
b"PATCH",
|
||||
b"GET ", b"POST", b"HEAD", b"PUT ", b"DELETE", b"OPTIONS", b"CONNECT", b"TRACE", b"PATCH",
|
||||
b"PRI ",
|
||||
];
|
||||
|
||||
|
|
@ -262,7 +254,11 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration {
|
|||
let floor = config.censorship.mask_timing_normalization_floor_ms;
|
||||
let ceiling = config.censorship.mask_timing_normalization_ceiling_ms;
|
||||
if floor == 0 {
|
||||
return MASK_TIMEOUT;
|
||||
if ceiling == 0 {
|
||||
return Duration::from_millis(0);
|
||||
}
|
||||
let mut rng = rand::rng();
|
||||
return Duration::from_millis(rng.random_range(0..=ceiling));
|
||||
}
|
||||
if ceiling > floor {
|
||||
let mut rng = rand::rng();
|
||||
|
|
@ -324,7 +320,10 @@ fn parse_mask_host_ip_literal(host: &str) -> Option<IpAddr> {
|
|||
|
||||
fn canonical_ip(ip: IpAddr) -> IpAddr {
|
||||
match ip {
|
||||
IpAddr::V6(v6) => v6.to_ipv4_mapped().map(IpAddr::V4).unwrap_or(IpAddr::V6(v6)),
|
||||
IpAddr::V6(v6) => v6
|
||||
.to_ipv4_mapped()
|
||||
.map(IpAddr::V4)
|
||||
.unwrap_or(IpAddr::V6(v6)),
|
||||
IpAddr::V4(v4) => IpAddr::V4(v4),
|
||||
}
|
||||
}
|
||||
|
|
@ -660,12 +659,20 @@ pub async fn handle_bad_client<R, W>(
|
|||
Ok(Err(e)) => {
|
||||
wait_mask_connect_budget_if_needed(connect_started, config).await;
|
||||
debug!(error = %e, "Failed to connect to mask unix socket");
|
||||
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
|
||||
consume_client_data_with_timeout_and_cap(
|
||||
reader,
|
||||
config.censorship.mask_relay_max_bytes,
|
||||
)
|
||||
.await;
|
||||
wait_mask_outcome_budget(outcome_started, config).await;
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Timeout connecting to mask unix socket");
|
||||
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
|
||||
consume_client_data_with_timeout_and_cap(
|
||||
reader,
|
||||
config.censorship.mask_relay_max_bytes,
|
||||
)
|
||||
.await;
|
||||
wait_mask_outcome_budget(outcome_started, config).await;
|
||||
}
|
||||
}
|
||||
|
|
@ -694,7 +701,8 @@ pub async fn handle_bad_client<R, W>(
|
|||
local = %local_addr,
|
||||
"Mask target resolves to local listener; refusing self-referential masking fallback"
|
||||
);
|
||||
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
|
||||
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes)
|
||||
.await;
|
||||
wait_mask_outcome_budget(outcome_started, config).await;
|
||||
return;
|
||||
}
|
||||
|
|
@ -754,12 +762,20 @@ pub async fn handle_bad_client<R, W>(
|
|||
Ok(Err(e)) => {
|
||||
wait_mask_connect_budget_if_needed(connect_started, config).await;
|
||||
debug!(error = %e, "Failed to connect to mask host");
|
||||
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
|
||||
consume_client_data_with_timeout_and_cap(
|
||||
reader,
|
||||
config.censorship.mask_relay_max_bytes,
|
||||
)
|
||||
.await;
|
||||
wait_mask_outcome_budget(outcome_started, config).await;
|
||||
}
|
||||
Err(_) => {
|
||||
debug!("Timeout connecting to mask host");
|
||||
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await;
|
||||
consume_client_data_with_timeout_and_cap(
|
||||
reader,
|
||||
config.censorship.mask_relay_max_bytes,
|
||||
)
|
||||
.await;
|
||||
wait_mask_outcome_budget(outcome_started, config).await;
|
||||
}
|
||||
}
|
||||
|
|
@ -838,7 +854,7 @@ async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R, byte_cap: usiz
|
|||
}
|
||||
|
||||
// Keep drain path fail-closed under slow-loris stalls.
|
||||
let mut buf = [0u8; MASK_BUFFER_SIZE];
|
||||
let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]);
|
||||
let mut total = 0usize;
|
||||
|
||||
loop {
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ use std::time::{Duration, Instant};
|
|||
|
||||
use dashmap::DashMap;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::sync::{Mutex as AsyncMutex, mpsc, oneshot, watch};
|
||||
use tokio::sync::{mpsc, oneshot, watch};
|
||||
use tokio::time::timeout;
|
||||
use tracing::{debug, info, trace, warn};
|
||||
|
||||
|
|
@ -23,7 +23,9 @@ use crate::proxy::route_mode::{
|
|||
ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state,
|
||||
cutover_stagger_delay,
|
||||
};
|
||||
use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, Stats};
|
||||
use crate::stats::{
|
||||
MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, QuotaReserveError, Stats, UserStats,
|
||||
};
|
||||
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
|
||||
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
|
||||
|
||||
|
|
@ -53,20 +55,11 @@ const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
|
|||
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
|
||||
const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2;
|
||||
const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024;
|
||||
#[cfg(test)]
|
||||
const QUOTA_USER_LOCKS_MAX: usize = 64;
|
||||
#[cfg(not(test))]
|
||||
const QUOTA_USER_LOCKS_MAX: usize = 4_096;
|
||||
#[cfg(test)]
|
||||
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
|
||||
#[cfg(not(test))]
|
||||
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
|
||||
const QUOTA_RESERVE_SPIN_RETRIES: usize = 32;
|
||||
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
|
||||
static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new();
|
||||
static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();
|
||||
static DESYNC_DEDUP_EVER_SATURATED: OnceLock<AtomicBool> = OnceLock::new();
|
||||
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<AsyncMutex<()>>>> = OnceLock::new();
|
||||
static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<AsyncMutex<()>>>> = OnceLock::new();
|
||||
static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock<Mutex<RelayIdleCandidateRegistry>> = OnceLock::new();
|
||||
static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
|
|
@ -100,7 +93,8 @@ fn relay_idle_candidate_registry() -> &'static Mutex<RelayIdleCandidateRegistry>
|
|||
RELAY_IDLE_CANDIDATE_REGISTRY.get_or_init(|| Mutex::new(RelayIdleCandidateRegistry::default()))
|
||||
}
|
||||
|
||||
fn relay_idle_candidate_registry_lock() -> std::sync::MutexGuard<'static, RelayIdleCandidateRegistry> {
|
||||
fn relay_idle_candidate_registry_lock() -> std::sync::MutexGuard<'static, RelayIdleCandidateRegistry>
|
||||
{
|
||||
let registry = relay_idle_candidate_registry();
|
||||
match registry.lock() {
|
||||
Ok(guard) => guard,
|
||||
|
|
@ -538,36 +532,28 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
|
|||
has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
|
||||
}
|
||||
|
||||
fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option<u64>) -> bool {
|
||||
quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota)
|
||||
}
|
||||
|
||||
#[cfg_attr(not(test), allow(dead_code))]
|
||||
fn quota_would_be_exceeded_for_user(
|
||||
stats: &Stats,
|
||||
user: &str,
|
||||
quota_limit: Option<u64>,
|
||||
bytes: u64,
|
||||
) -> bool {
|
||||
quota_limit.is_some_and(|quota| {
|
||||
let used = stats.get_user_total_octets(user);
|
||||
used >= quota || bytes > quota.saturating_sub(used)
|
||||
})
|
||||
}
|
||||
|
||||
fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 {
|
||||
limit.saturating_add(overshoot)
|
||||
}
|
||||
|
||||
fn quota_would_be_exceeded_for_user_soft(
|
||||
stats: &Stats,
|
||||
user: &str,
|
||||
quota_limit: Option<u64>,
|
||||
async fn reserve_user_quota_with_yield(
|
||||
user_stats: &UserStats,
|
||||
bytes: u64,
|
||||
overshoot: u64,
|
||||
) -> bool {
|
||||
let capped_limit = quota_limit.map(|quota| quota_soft_cap(quota, overshoot));
|
||||
quota_would_be_exceeded_for_user(stats, user, capped_limit, bytes)
|
||||
limit: u64,
|
||||
) -> std::result::Result<u64, QuotaReserveError> {
|
||||
loop {
|
||||
for _ in 0..QUOTA_RESERVE_SPIN_RETRIES {
|
||||
match user_stats.quota_try_reserve(bytes, limit) {
|
||||
Ok(total) => return Ok(total),
|
||||
Err(QuotaReserveError::LimitExceeded) => {
|
||||
return Err(QuotaReserveError::LimitExceeded);
|
||||
}
|
||||
Err(QuotaReserveError::Contended) => std::hint::spin_loop(),
|
||||
}
|
||||
}
|
||||
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
|
||||
fn classify_me_d2c_flush_reason(
|
||||
|
|
@ -613,29 +599,6 @@ fn observe_me_d2c_flush_event(
|
|||
}
|
||||
}
|
||||
|
||||
fn rollback_me2c_quota_reservation(
|
||||
stats: &Stats,
|
||||
user: &str,
|
||||
bytes_me2c: &AtomicU64,
|
||||
reserved_bytes: u64,
|
||||
) {
|
||||
stats.sub_user_octets_to(user, reserved_bytes);
|
||||
bytes_me2c.fetch_sub(reserved_bytes, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
|
||||
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
TEST_LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
|
||||
quota_user_lock_test_guard()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn relay_idle_pressure_test_guard() -> &'static Mutex<()> {
|
||||
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
|
|
@ -649,46 +612,6 @@ pub(crate) fn relay_idle_pressure_test_scope() -> std::sync::MutexGuard<'static,
|
|||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
fn quota_overflow_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
|
||||
let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
|
||||
(0..QUOTA_OVERFLOW_LOCK_STRIPES)
|
||||
.map(|_| Arc::new(AsyncMutex::new(())))
|
||||
.collect()
|
||||
});
|
||||
|
||||
let hash = crc32fast::hash(user.as_bytes()) as usize;
|
||||
Arc::clone(&stripes[hash % stripes.len()])
|
||||
}
|
||||
|
||||
fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
|
||||
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
if let Some(existing) = locks.get(user) {
|
||||
return Arc::clone(existing.value());
|
||||
}
|
||||
|
||||
if locks.len() >= QUOTA_USER_LOCKS_MAX {
|
||||
locks.retain(|_, value| Arc::strong_count(value) > 1);
|
||||
}
|
||||
|
||||
if locks.len() >= QUOTA_USER_LOCKS_MAX {
|
||||
return quota_overflow_user_lock(user);
|
||||
}
|
||||
|
||||
let created = Arc::new(AsyncMutex::new(()));
|
||||
match locks.entry(user.to_string()) {
|
||||
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
|
||||
dashmap::mapref::entry::Entry::Vacant(entry) => {
|
||||
entry.insert(Arc::clone(&created));
|
||||
created
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<AsyncMutex<()>> {
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user)
|
||||
}
|
||||
|
||||
async fn enqueue_c2me_command(
|
||||
tx: &mpsc::Sender<C2MeCommand>,
|
||||
cmd: C2MeCommand,
|
||||
|
|
@ -744,8 +667,7 @@ where
|
|||
{
|
||||
let user = success.user.clone();
|
||||
let quota_limit = config.access.user_data_quota.get(&user).copied();
|
||||
let cross_mode_quota_lock =
|
||||
quota_limit.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
|
||||
let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user));
|
||||
let peer = success.peer;
|
||||
let proto_tag = success.proto_tag;
|
||||
let pool_generation = me_pool.current_generation();
|
||||
|
|
@ -872,7 +794,7 @@ where
|
|||
let stats_clone = stats.clone();
|
||||
let rng_clone = rng.clone();
|
||||
let user_clone = user.clone();
|
||||
let cross_mode_quota_lock_me_writer = cross_mode_quota_lock.clone();
|
||||
let quota_user_stats_me_writer = quota_user_stats.clone();
|
||||
let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone();
|
||||
let bytes_me2c_clone = bytes_me2c.clone();
|
||||
let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config);
|
||||
|
|
@ -894,7 +816,7 @@ where
|
|||
|
||||
let first_is_downstream_activity =
|
||||
matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_));
|
||||
match process_me_writer_response_with_cross_mode_lock(
|
||||
match process_me_writer_response(
|
||||
first,
|
||||
&mut writer,
|
||||
proto_tag,
|
||||
|
|
@ -902,9 +824,9 @@ where
|
|||
&mut frame_buf,
|
||||
stats_clone.as_ref(),
|
||||
&user_clone,
|
||||
quota_user_stats_me_writer.as_deref(),
|
||||
quota_limit,
|
||||
d2c_flush_policy.quota_soft_overshoot_bytes,
|
||||
cross_mode_quota_lock_me_writer.as_ref(),
|
||||
bytes_me2c_clone.as_ref(),
|
||||
conn_id,
|
||||
d2c_flush_policy.ack_flush_immediate,
|
||||
|
|
@ -953,7 +875,7 @@ where
|
|||
|
||||
let next_is_downstream_activity =
|
||||
matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
|
||||
match process_me_writer_response_with_cross_mode_lock(
|
||||
match process_me_writer_response(
|
||||
next,
|
||||
&mut writer,
|
||||
proto_tag,
|
||||
|
|
@ -961,9 +883,9 @@ where
|
|||
&mut frame_buf,
|
||||
stats_clone.as_ref(),
|
||||
&user_clone,
|
||||
quota_user_stats_me_writer.as_deref(),
|
||||
quota_limit,
|
||||
d2c_flush_policy.quota_soft_overshoot_bytes,
|
||||
cross_mode_quota_lock_me_writer.as_ref(),
|
||||
bytes_me2c_clone.as_ref(),
|
||||
conn_id,
|
||||
d2c_flush_policy.ack_flush_immediate,
|
||||
|
|
@ -1015,7 +937,7 @@ where
|
|||
Ok(Some(next)) => {
|
||||
let next_is_downstream_activity =
|
||||
matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_));
|
||||
match process_me_writer_response_with_cross_mode_lock(
|
||||
match process_me_writer_response(
|
||||
next,
|
||||
&mut writer,
|
||||
proto_tag,
|
||||
|
|
@ -1023,9 +945,9 @@ where
|
|||
&mut frame_buf,
|
||||
stats_clone.as_ref(),
|
||||
&user_clone,
|
||||
quota_user_stats_me_writer.as_deref(),
|
||||
quota_limit,
|
||||
d2c_flush_policy.quota_soft_overshoot_bytes,
|
||||
cross_mode_quota_lock_me_writer.as_ref(),
|
||||
bytes_me2c_clone.as_ref(),
|
||||
conn_id,
|
||||
d2c_flush_policy.ack_flush_immediate,
|
||||
|
|
@ -1079,7 +1001,7 @@ where
|
|||
|
||||
let extra_is_downstream_activity =
|
||||
matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_));
|
||||
match process_me_writer_response_with_cross_mode_lock(
|
||||
match process_me_writer_response(
|
||||
extra,
|
||||
&mut writer,
|
||||
proto_tag,
|
||||
|
|
@ -1087,9 +1009,9 @@ where
|
|||
&mut frame_buf,
|
||||
stats_clone.as_ref(),
|
||||
&user_clone,
|
||||
quota_user_stats_me_writer.as_deref(),
|
||||
quota_limit,
|
||||
d2c_flush_policy.quota_soft_overshoot_bytes,
|
||||
cross_mode_quota_lock_me_writer.as_ref(),
|
||||
bytes_me2c_clone.as_ref(),
|
||||
conn_id,
|
||||
d2c_flush_policy.ack_flush_immediate,
|
||||
|
|
@ -1259,24 +1181,23 @@ where
|
|||
forensics.bytes_c2me = forensics
|
||||
.bytes_c2me
|
||||
.saturating_add(payload.len() as u64);
|
||||
if let Some(limit) = quota_limit {
|
||||
let quota_lock = quota_user_lock(&user);
|
||||
let _quota_guard = quota_lock.lock().await;
|
||||
let Some(cross_mode_lock) = cross_mode_quota_lock.as_ref() else {
|
||||
main_result = Err(ProxyError::Proxy(
|
||||
"cross-mode quota lock missing for quota-limited session"
|
||||
.to_string(),
|
||||
));
|
||||
break;
|
||||
};
|
||||
let _cross_mode_quota_guard = cross_mode_lock.lock().await;
|
||||
stats.add_user_octets_from(&user, payload.len() as u64);
|
||||
if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) {
|
||||
if let (Some(limit), Some(user_stats)) =
|
||||
(quota_limit, quota_user_stats.as_deref())
|
||||
{
|
||||
if reserve_user_quota_with_yield(
|
||||
user_stats,
|
||||
payload.len() as u64,
|
||||
limit,
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
main_result = Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.clone(),
|
||||
});
|
||||
break;
|
||||
}
|
||||
stats.add_user_octets_from_handle(user_stats, payload.len() as u64);
|
||||
} else {
|
||||
stats.add_user_octets_from(&user, payload.len() as u64);
|
||||
}
|
||||
|
|
@ -1602,8 +1523,7 @@ where
|
|||
}
|
||||
|
||||
if !idle_policy.enabled {
|
||||
consecutive_zero_len_frames =
|
||||
consecutive_zero_len_frames.saturating_add(1);
|
||||
consecutive_zero_len_frames = consecutive_zero_len_frames.saturating_add(1);
|
||||
if consecutive_zero_len_frames > LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES {
|
||||
stats.increment_relay_protocol_desync_close_total();
|
||||
return Err(ProxyError::Proxy(
|
||||
|
|
@ -1755,7 +1675,6 @@ enum MeWriterResponseOutcome {
|
|||
Close,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn process_me_writer_response<W>(
|
||||
response: MeResponse,
|
||||
client_writer: &mut CryptoWriter<W>,
|
||||
|
|
@ -1764,6 +1683,7 @@ async fn process_me_writer_response<W>(
|
|||
frame_buf: &mut Vec<u8>,
|
||||
stats: &Stats,
|
||||
user: &str,
|
||||
quota_user_stats: Option<&UserStats>,
|
||||
quota_limit: Option<u64>,
|
||||
quota_soft_overshoot_bytes: u64,
|
||||
bytes_me2c: &AtomicU64,
|
||||
|
|
@ -1771,44 +1691,6 @@ async fn process_me_writer_response<W>(
|
|||
ack_flush_immediate: bool,
|
||||
batched: bool,
|
||||
) -> Result<MeWriterResponseOutcome>
|
||||
where
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
response,
|
||||
client_writer,
|
||||
proto_tag,
|
||||
rng,
|
||||
frame_buf,
|
||||
stats,
|
||||
user,
|
||||
quota_limit,
|
||||
quota_soft_overshoot_bytes,
|
||||
None,
|
||||
bytes_me2c,
|
||||
conn_id,
|
||||
ack_flush_immediate,
|
||||
batched,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn process_me_writer_response_with_cross_mode_lock<W>(
|
||||
response: MeResponse,
|
||||
client_writer: &mut CryptoWriter<W>,
|
||||
proto_tag: ProtoTag,
|
||||
rng: &SecureRandom,
|
||||
frame_buf: &mut Vec<u8>,
|
||||
stats: &Stats,
|
||||
user: &str,
|
||||
quota_limit: Option<u64>,
|
||||
quota_soft_overshoot_bytes: u64,
|
||||
cross_mode_quota_lock: Option<&Arc<AsyncMutex<()>>>,
|
||||
bytes_me2c: &AtomicU64,
|
||||
conn_id: u64,
|
||||
ack_flush_immediate: bool,
|
||||
batched: bool,
|
||||
) -> Result<MeWriterResponseOutcome>
|
||||
where
|
||||
W: AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
|
|
@ -1820,78 +1702,43 @@ where
|
|||
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
|
||||
}
|
||||
let data_len = data.len() as u64;
|
||||
if let Some(limit) = quota_limit {
|
||||
let owned_cross_mode_lock;
|
||||
let cross_mode_lock = if let Some(lock) = cross_mode_quota_lock {
|
||||
lock
|
||||
} else {
|
||||
owned_cross_mode_lock =
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user);
|
||||
&owned_cross_mode_lock
|
||||
};
|
||||
let cross_mode_quota_guard = cross_mode_lock.lock().await;
|
||||
if let (Some(limit), Some(user_stats)) = (quota_limit, quota_user_stats) {
|
||||
let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes);
|
||||
if quota_would_be_exceeded_for_user_soft(
|
||||
stats,
|
||||
user,
|
||||
Some(limit),
|
||||
data_len,
|
||||
quota_soft_overshoot_bytes,
|
||||
) {
|
||||
if reserve_user_quota_with_yield(user_stats, data_len, soft_limit)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
|
||||
return Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Reserve quota before awaiting network I/O to avoid same-user HoL stalls.
|
||||
// If reservation loses a race or write fails, we roll back immediately.
|
||||
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
|
||||
stats.add_user_octets_to(user, data_len);
|
||||
|
||||
if stats.get_user_total_octets(user) > soft_limit {
|
||||
rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len);
|
||||
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
|
||||
return Err(ProxyError::DataQuotaExceeded {
|
||||
user: user.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Keep cross-mode lock scope explicit and minimal: quota reservation is serialized,
|
||||
// but socket I/O proceeds without holding same-user cross-mode admission lock.
|
||||
drop(cross_mode_quota_guard);
|
||||
|
||||
let write_mode =
|
||||
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
|
||||
.await
|
||||
{
|
||||
Ok(mode) => mode,
|
||||
Err(err) => {
|
||||
rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
stats.increment_me_d2c_data_frames_total();
|
||||
stats.add_me_d2c_payload_bytes_total(data_len);
|
||||
stats.increment_me_d2c_write_mode(write_mode);
|
||||
|
||||
// Do not fail immediately on exact boundary after a successful write.
|
||||
// Returning an error here can bypass batch flush in the caller and risk
|
||||
// dropping buffered ciphertext from CryptoWriter. The next frame is
|
||||
// rejected by the pre-check at function entry.
|
||||
} else {
|
||||
let write_mode =
|
||||
write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
|
||||
.await?;
|
||||
|
||||
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
|
||||
stats.add_user_octets_to(user, data_len);
|
||||
stats.increment_me_d2c_data_frames_total();
|
||||
stats.add_me_d2c_payload_bytes_total(data_len);
|
||||
stats.increment_me_d2c_write_mode(write_mode);
|
||||
}
|
||||
|
||||
let write_mode =
|
||||
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
|
||||
.await
|
||||
{
|
||||
Ok(mode) => mode,
|
||||
Err(err) => {
|
||||
if quota_limit.is_some() {
|
||||
stats.add_quota_write_fail_bytes_total(data_len);
|
||||
stats.increment_quota_write_fail_events_total();
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
|
||||
if let Some(user_stats) = quota_user_stats {
|
||||
stats.add_user_octets_to_handle(user_stats, data_len);
|
||||
} else {
|
||||
stats.add_user_octets_to(user, data_len);
|
||||
}
|
||||
stats.increment_me_d2c_data_frames_total();
|
||||
stats.add_me_d2c_payload_bytes_total(data_len);
|
||||
stats.increment_me_d2c_write_mode(write_mode);
|
||||
|
||||
Ok(MeWriterResponseOutcome::Continue {
|
||||
frames: 1,
|
||||
bytes: data.len(),
|
||||
|
|
@ -1990,8 +1837,14 @@ where
|
|||
MeD2cWriteMode::Coalesced
|
||||
} else {
|
||||
let header = [first];
|
||||
client_writer.write_all(&header).await.map_err(ProxyError::Io)?;
|
||||
client_writer.write_all(data).await.map_err(ProxyError::Io)?;
|
||||
client_writer
|
||||
.write_all(&header)
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
client_writer
|
||||
.write_all(data)
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
MeD2cWriteMode::Split
|
||||
}
|
||||
} else if len_words < (1 << 24) {
|
||||
|
|
@ -2013,8 +1866,14 @@ where
|
|||
MeD2cWriteMode::Coalesced
|
||||
} else {
|
||||
let header = [first, lw[0], lw[1], lw[2]];
|
||||
client_writer.write_all(&header).await.map_err(ProxyError::Io)?;
|
||||
client_writer.write_all(data).await.map_err(ProxyError::Io)?;
|
||||
client_writer
|
||||
.write_all(&header)
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
client_writer
|
||||
.write_all(data)
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
MeD2cWriteMode::Split
|
||||
}
|
||||
} else {
|
||||
|
|
@ -2056,8 +1915,14 @@ where
|
|||
MeD2cWriteMode::Coalesced
|
||||
} else {
|
||||
let header = len_val.to_le_bytes();
|
||||
client_writer.write_all(&header).await.map_err(ProxyError::Io)?;
|
||||
client_writer.write_all(data).await.map_err(ProxyError::Io)?;
|
||||
client_writer
|
||||
.write_all(&header)
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
client_writer
|
||||
.write_all(data)
|
||||
.await
|
||||
.map_err(ProxyError::Io)?;
|
||||
if padding_len > 0 {
|
||||
frame_buf.clear();
|
||||
if frame_buf.capacity() < padding_len {
|
||||
|
|
@ -2097,10 +1962,6 @@ where
|
|||
.map_err(ProxyError::Io)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_security_tests.rs"]
|
||||
mod security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_idle_policy_security_tests.rs"]
|
||||
mod idle_policy_security_tests;
|
||||
|
|
@ -2113,30 +1974,10 @@ mod desync_all_full_dedup_security_tests;
|
|||
#[path = "tests/middle_relay_stub_completion_security_tests.rs"]
|
||||
mod stub_completion_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_coverage_high_risk_security_tests.rs"]
|
||||
mod coverage_high_risk_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_quota_overflow_lock_security_tests.rs"]
|
||||
mod quota_overflow_lock_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_length_cast_hardening_security_tests.rs"]
|
||||
mod length_cast_hardening_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"]
|
||||
mod blackhat_campaign_integration_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_hol_quota_security_tests.rs"]
|
||||
mod hol_quota_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_quota_reservation_adversarial_tests.rs"]
|
||||
mod quota_reservation_adversarial_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_idle_registry_poison_security_tests.rs"]
|
||||
mod middle_relay_idle_registry_poison_security_tests;
|
||||
|
|
@ -2158,25 +1999,5 @@ mod middle_relay_tiny_frame_debt_concurrency_security_tests;
|
|||
mod middle_relay_tiny_frame_debt_proto_chunking_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_cross_mode_quota_reservation_security_tests.rs"]
|
||||
mod middle_relay_cross_mode_quota_reservation_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs"]
|
||||
mod middle_relay_cross_mode_quota_lock_matrix_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs"]
|
||||
mod middle_relay_cross_mode_lookup_efficiency_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs"]
|
||||
mod middle_relay_cross_mode_lock_release_regression_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_quota_extended_attack_surface_security_tests.rs"]
|
||||
mod middle_relay_quota_extended_attack_surface_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/middle_relay_quota_reservation_extreme_security_tests.rs"]
|
||||
mod middle_relay_quota_reservation_extreme_security_tests;
|
||||
#[path = "tests/middle_relay_atomic_quota_invariant_tests.rs"]
|
||||
mod middle_relay_atomic_quota_invariant_tests;
|
||||
|
|
|
|||
101
src/proxy/mod.rs
101
src/proxy/mod.rs
|
|
@ -4,58 +4,58 @@
|
|||
#![cfg_attr(test, allow(warnings))]
|
||||
#![cfg_attr(not(test), forbid(clippy::undocumented_unsafe_blocks))]
|
||||
#![cfg_attr(
|
||||
not(test),
|
||||
deny(
|
||||
clippy::unwrap_used,
|
||||
clippy::expect_used,
|
||||
clippy::panic,
|
||||
clippy::todo,
|
||||
clippy::unimplemented,
|
||||
clippy::correctness,
|
||||
clippy::option_if_let_else,
|
||||
clippy::or_fun_call,
|
||||
clippy::branches_sharing_code,
|
||||
clippy::single_option_map,
|
||||
clippy::useless_let_if_seq,
|
||||
clippy::redundant_locals,
|
||||
clippy::cloned_ref_to_slice_refs,
|
||||
unsafe_code,
|
||||
clippy::await_holding_lock,
|
||||
clippy::await_holding_refcell_ref,
|
||||
clippy::debug_assert_with_mut_call,
|
||||
clippy::macro_use_imports,
|
||||
clippy::cast_ptr_alignment,
|
||||
clippy::cast_lossless,
|
||||
clippy::ptr_as_ptr,
|
||||
clippy::large_stack_arrays,
|
||||
clippy::same_functions_in_if_condition,
|
||||
trivial_casts,
|
||||
trivial_numeric_casts,
|
||||
unused_extern_crates,
|
||||
unused_import_braces,
|
||||
rust_2018_idioms
|
||||
)
|
||||
not(test),
|
||||
deny(
|
||||
clippy::unwrap_used,
|
||||
clippy::expect_used,
|
||||
clippy::panic,
|
||||
clippy::todo,
|
||||
clippy::unimplemented,
|
||||
clippy::correctness,
|
||||
clippy::option_if_let_else,
|
||||
clippy::or_fun_call,
|
||||
clippy::branches_sharing_code,
|
||||
clippy::single_option_map,
|
||||
clippy::useless_let_if_seq,
|
||||
clippy::redundant_locals,
|
||||
clippy::cloned_ref_to_slice_refs,
|
||||
unsafe_code,
|
||||
clippy::await_holding_lock,
|
||||
clippy::await_holding_refcell_ref,
|
||||
clippy::debug_assert_with_mut_call,
|
||||
clippy::macro_use_imports,
|
||||
clippy::cast_ptr_alignment,
|
||||
clippy::cast_lossless,
|
||||
clippy::ptr_as_ptr,
|
||||
clippy::large_stack_arrays,
|
||||
clippy::same_functions_in_if_condition,
|
||||
trivial_casts,
|
||||
trivial_numeric_casts,
|
||||
unused_extern_crates,
|
||||
unused_import_braces,
|
||||
rust_2018_idioms
|
||||
)
|
||||
)]
|
||||
#![cfg_attr(
|
||||
not(test),
|
||||
allow(
|
||||
clippy::use_self,
|
||||
clippy::redundant_closure,
|
||||
clippy::too_many_arguments,
|
||||
clippy::doc_markdown,
|
||||
clippy::missing_const_for_fn,
|
||||
clippy::unnecessary_operation,
|
||||
clippy::redundant_pub_crate,
|
||||
clippy::derive_partial_eq_without_eq,
|
||||
clippy::type_complexity,
|
||||
clippy::new_ret_no_self,
|
||||
clippy::cast_possible_truncation,
|
||||
clippy::cast_possible_wrap,
|
||||
clippy::significant_drop_tightening,
|
||||
clippy::significant_drop_in_scrutinee,
|
||||
clippy::float_cmp,
|
||||
clippy::nursery
|
||||
)
|
||||
not(test),
|
||||
allow(
|
||||
clippy::use_self,
|
||||
clippy::redundant_closure,
|
||||
clippy::too_many_arguments,
|
||||
clippy::doc_markdown,
|
||||
clippy::missing_const_for_fn,
|
||||
clippy::unnecessary_operation,
|
||||
clippy::redundant_pub_crate,
|
||||
clippy::derive_partial_eq_without_eq,
|
||||
clippy::type_complexity,
|
||||
clippy::new_ret_no_self,
|
||||
clippy::cast_possible_truncation,
|
||||
clippy::cast_possible_wrap,
|
||||
clippy::significant_drop_tightening,
|
||||
clippy::significant_drop_in_scrutinee,
|
||||
clippy::float_cmp,
|
||||
clippy::nursery
|
||||
)
|
||||
)]
|
||||
|
||||
pub mod adaptive_buffers;
|
||||
|
|
@ -64,7 +64,6 @@ pub mod direct_relay;
|
|||
pub mod handshake;
|
||||
pub mod masking;
|
||||
pub mod middle_relay;
|
||||
pub mod quota_lock_registry;
|
||||
pub mod relay;
|
||||
pub mod route_mode;
|
||||
pub mod session_eviction;
|
||||
|
|
|
|||
|
|
@ -1,88 +0,0 @@
|
|||
use dashmap::DashMap;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
#[cfg(test)]
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
#[cfg(test)]
|
||||
const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 64;
|
||||
#[cfg(not(test))]
|
||||
const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 4_096;
|
||||
#[cfg(test)]
|
||||
const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
|
||||
#[cfg(not(test))]
|
||||
const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
|
||||
|
||||
static CROSS_MODE_QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
|
||||
static CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
|
||||
|
||||
#[cfg(test)]
|
||||
static CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS: AtomicUsize = AtomicUsize::new(0);
|
||||
#[cfg(test)]
|
||||
static CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER: OnceLock<DashMap<String, usize>> = OnceLock::new();
|
||||
|
||||
fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
|
||||
let stripes = CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
|
||||
(0..CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES)
|
||||
.map(|_| Arc::new(Mutex::new(())))
|
||||
.collect()
|
||||
});
|
||||
|
||||
let hash = crc32fast::hash(user.as_bytes()) as usize;
|
||||
Arc::clone(&stripes[hash % stripes.len()])
|
||||
}
|
||||
|
||||
pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc<Mutex<()>> {
|
||||
#[cfg(test)]
|
||||
{
|
||||
CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.fetch_add(1, Ordering::Relaxed);
|
||||
let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
|
||||
let mut entry = lookups.entry(user.to_string()).or_insert(0);
|
||||
*entry += 1;
|
||||
}
|
||||
|
||||
let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
if let Some(existing) = locks.get(user) {
|
||||
return Arc::clone(existing.value());
|
||||
}
|
||||
|
||||
if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX {
|
||||
locks.retain(|_, value| Arc::strong_count(value) > 1);
|
||||
}
|
||||
|
||||
if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX {
|
||||
return cross_mode_quota_overflow_user_lock(user);
|
||||
}
|
||||
|
||||
let created = Arc::new(Mutex::new(()));
|
||||
match locks.entry(user.to_string()) {
|
||||
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
|
||||
dashmap::mapref::entry::Entry::Vacant(entry) => {
|
||||
entry.insert(Arc::clone(&created));
|
||||
created
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn reset_cross_mode_quota_user_lock_lookup_count_for_tests() {
|
||||
CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.store(0, Ordering::Relaxed);
|
||||
let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
|
||||
lookups.clear();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_tests() -> usize {
|
||||
CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_user_for_tests(user: &str) -> usize {
|
||||
let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new);
|
||||
lookups.get(user).map(|entry| *entry).unwrap_or(0)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/quota_lock_registry_cross_mode_adversarial_tests.rs"]
|
||||
mod quota_lock_registry_cross_mode_adversarial_tests;
|
||||
|
|
@ -52,18 +52,16 @@
|
|||
//! - `SharedCounters` (atomics) let the watchdog read stats without locking
|
||||
|
||||
use crate::error::{ProxyError, Result};
|
||||
use crate::stats::Stats;
|
||||
use crate::stats::{Stats, UserStats};
|
||||
use crate::stream::BufferPool;
|
||||
use dashmap::DashMap;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use std::task::{Context, Poll};
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes};
|
||||
use tokio::sync::Mutex as AsyncMutex;
|
||||
use tokio::time::{Instant, Sleep};
|
||||
use tokio::time::Instant;
|
||||
use tracing::{debug, trace, warn};
|
||||
|
||||
// ============= Constants =============
|
||||
|
|
@ -210,16 +208,10 @@ struct StatsIo<S> {
|
|||
counters: Arc<SharedCounters>,
|
||||
stats: Arc<Stats>,
|
||||
user: String,
|
||||
quota_lock: Option<Arc<Mutex<()>>>,
|
||||
cross_mode_quota_lock: Option<Arc<AsyncMutex<()>>>,
|
||||
user_stats: Arc<UserStats>,
|
||||
quota_limit: Option<u64>,
|
||||
quota_exceeded: Arc<AtomicBool>,
|
||||
quota_read_wake_scheduled: bool,
|
||||
quota_write_wake_scheduled: bool,
|
||||
quota_read_retry_sleep: Option<Pin<Box<Sleep>>>,
|
||||
quota_write_retry_sleep: Option<Pin<Box<Sleep>>>,
|
||||
quota_read_retry_attempt: u8,
|
||||
quota_write_retry_attempt: u8,
|
||||
quota_bytes_since_check: u64,
|
||||
epoch: Instant,
|
||||
}
|
||||
|
||||
|
|
@ -235,24 +227,16 @@ impl<S> StatsIo<S> {
|
|||
) -> Self {
|
||||
// Mark initial activity so the watchdog doesn't fire before data flows
|
||||
counters.touch(Instant::now(), epoch);
|
||||
let quota_lock = quota_limit.map(|_| quota_user_lock(&user));
|
||||
let cross_mode_quota_lock = quota_limit
|
||||
.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
|
||||
let user_stats = stats.get_or_create_user_stats_handle(&user);
|
||||
Self {
|
||||
inner,
|
||||
counters,
|
||||
stats,
|
||||
user,
|
||||
quota_lock,
|
||||
cross_mode_quota_lock,
|
||||
user_stats,
|
||||
quota_limit,
|
||||
quota_exceeded,
|
||||
quota_read_wake_scheduled: false,
|
||||
quota_write_wake_scheduled: false,
|
||||
quota_read_retry_sleep: None,
|
||||
quota_write_retry_sleep: None,
|
||||
quota_read_retry_attempt: 0,
|
||||
quota_write_retry_attempt: 0,
|
||||
quota_bytes_since_check: 0,
|
||||
epoch,
|
||||
}
|
||||
}
|
||||
|
|
@ -281,193 +265,22 @@ fn is_quota_io_error(err: &io::Error) -> bool {
|
|||
.is_some()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1);
|
||||
#[cfg(not(test))]
|
||||
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2);
|
||||
#[cfg(test)]
|
||||
const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(16);
|
||||
#[cfg(not(test))]
|
||||
const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(64);
|
||||
const QUOTA_NEAR_LIMIT_BYTES: u64 = 64 * 1024;
|
||||
const QUOTA_LARGE_CHARGE_BYTES: u64 = 16 * 1024;
|
||||
const QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES: u64 = 4 * 1024;
|
||||
const QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES: u64 = 64 * 1024;
|
||||
|
||||
#[cfg(test)]
|
||||
static QUOTA_RETRY_SLEEP_ALLOCS: AtomicU64 = AtomicU64::new(0);
|
||||
#[cfg(test)]
|
||||
static QUOTA_RETRY_SLEEP_ALLOCS_BY_USER: OnceLock<DashMap<String, u64>> = OnceLock::new();
|
||||
#[cfg(test)]
|
||||
static QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn reset_quota_retry_sleep_allocs_for_tests() {
|
||||
QUOTA_RETRY_SLEEP_ALLOCS.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn reset_quota_retry_sleep_allocs_for_user_for_tests(user: &str) {
|
||||
let map = QUOTA_RETRY_SLEEP_ALLOCS_BY_USER.get_or_init(DashMap::new);
|
||||
map.remove(user);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn quota_retry_sleep_allocs_for_tests() -> u64 {
|
||||
QUOTA_RETRY_SLEEP_ALLOCS.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn quota_retry_sleep_allocs_for_user_for_tests(user: &str) -> u64 {
|
||||
let map = QUOTA_RETRY_SLEEP_ALLOCS_BY_USER.get_or_init(DashMap::new);
|
||||
map.get(user).map(|v| *v.value()).unwrap_or(0)
|
||||
#[inline]
|
||||
fn quota_adaptive_interval_bytes(remaining_before: u64) -> u64 {
|
||||
remaining_before.saturating_div(2).clamp(
|
||||
QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES,
|
||||
QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES,
|
||||
)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn quota_contention_retry_delay(retry_attempt: u8) -> Duration {
|
||||
let shift = u32::from(retry_attempt.min(5));
|
||||
let multiplier = 1_u32 << shift;
|
||||
QUOTA_CONTENTION_RETRY_INTERVAL
|
||||
.saturating_mul(multiplier)
|
||||
.min(QUOTA_CONTENTION_RETRY_MAX_INTERVAL)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn reset_quota_retry_scheduler(
|
||||
sleep_slot: &mut Option<Pin<Box<Sleep>>>,
|
||||
wake_scheduled: &mut bool,
|
||||
retry_attempt: &mut u8,
|
||||
) {
|
||||
*wake_scheduled = false;
|
||||
*sleep_slot = None;
|
||||
*retry_attempt = 0;
|
||||
}
|
||||
|
||||
fn poll_quota_retry_sleep(
|
||||
sleep_slot: &mut Option<Pin<Box<Sleep>>>,
|
||||
wake_scheduled: &mut bool,
|
||||
retry_attempt: &mut u8,
|
||||
user: &str,
|
||||
cx: &mut Context<'_>,
|
||||
) {
|
||||
#[cfg(not(test))]
|
||||
let _ = user;
|
||||
|
||||
if !*wake_scheduled {
|
||||
*wake_scheduled = true;
|
||||
#[cfg(test)]
|
||||
{
|
||||
QUOTA_RETRY_SLEEP_ALLOCS.fetch_add(1, Ordering::Relaxed);
|
||||
let map = QUOTA_RETRY_SLEEP_ALLOCS_BY_USER.get_or_init(DashMap::new);
|
||||
map.entry(user.to_string())
|
||||
.and_modify(|count| *count = count.saturating_add(1))
|
||||
.or_insert(1);
|
||||
}
|
||||
*sleep_slot = Some(Box::pin(tokio::time::sleep(quota_contention_retry_delay(
|
||||
*retry_attempt,
|
||||
))));
|
||||
}
|
||||
|
||||
if let Some(sleep) = sleep_slot.as_mut()
|
||||
&& sleep.as_mut().poll(cx).is_ready()
|
||||
{
|
||||
*sleep_slot = None;
|
||||
*wake_scheduled = false;
|
||||
*retry_attempt = retry_attempt.saturating_add(1);
|
||||
cx.waker().wake_by_ref();
|
||||
}
|
||||
}
|
||||
|
||||
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
|
||||
static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
|
||||
|
||||
#[cfg(test)]
|
||||
const QUOTA_USER_LOCKS_MAX: usize = 64;
|
||||
#[cfg(not(test))]
|
||||
const QUOTA_USER_LOCKS_MAX: usize = 4_096;
|
||||
#[cfg(test)]
|
||||
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
|
||||
#[cfg(not(test))]
|
||||
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
|
||||
|
||||
#[cfg(test)]
|
||||
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
|
||||
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
TEST_LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
|
||||
quota_user_lock_test_guard()
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
fn quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
|
||||
let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
|
||||
(0..QUOTA_OVERFLOW_LOCK_STRIPES)
|
||||
.map(|_| Arc::new(Mutex::new(())))
|
||||
.collect()
|
||||
});
|
||||
|
||||
let hash = crc32fast::hash(user.as_bytes()) as usize;
|
||||
Arc::clone(&stripes[hash % stripes.len()])
|
||||
}
|
||||
|
||||
pub(crate) fn quota_user_lock_evict() {
|
||||
if let Some(locks) = QUOTA_USER_LOCKS.get() {
|
||||
locks.retain(|_, value| Arc::strong_count(value) > 1);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn spawn_quota_user_lock_evictor(interval: Duration) -> tokio::task::JoinHandle<()> {
|
||||
let interval = interval.max(Duration::from_millis(1));
|
||||
#[cfg(test)]
|
||||
QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.fetch_add(1, Ordering::Relaxed);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
quota_user_lock_evict();
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn spawn_quota_user_lock_evictor_for_tests(
|
||||
interval: Duration,
|
||||
) -> tokio::task::JoinHandle<()> {
|
||||
spawn_quota_user_lock_evictor(interval)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn reset_quota_user_lock_evictor_spawn_count_for_tests() {
|
||||
QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.store(0, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn quota_user_lock_evictor_spawn_count_for_tests() -> u64 {
|
||||
QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
|
||||
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
if let Some(existing) = locks.get(user) {
|
||||
return Arc::clone(existing.value());
|
||||
}
|
||||
|
||||
if locks.len() >= QUOTA_USER_LOCKS_MAX {
|
||||
return quota_overflow_user_lock(user);
|
||||
}
|
||||
|
||||
let created = Arc::new(Mutex::new(()));
|
||||
match locks.entry(user.to_string()) {
|
||||
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
|
||||
dashmap::mapref::entry::Entry::Vacant(entry) => {
|
||||
entry.insert(Arc::clone(&created));
|
||||
created
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc<AsyncMutex<()>> {
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user)
|
||||
fn should_immediate_quota_check(remaining_before: u64, charge_bytes: u64) -> bool {
|
||||
remaining_before <= QUOTA_NEAR_LIMIT_BYTES || charge_bytes >= QUOTA_LARGE_CHARGE_BYTES
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
|
||||
|
|
@ -477,95 +290,60 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
|
|||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
let this = self.get_mut();
|
||||
if this.quota_exceeded.load(Ordering::Relaxed) {
|
||||
if this.quota_exceeded.load(Ordering::Acquire) {
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
|
||||
let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() {
|
||||
match lock.try_lock() {
|
||||
Ok(guard) => Some(guard),
|
||||
Err(_) => {
|
||||
poll_quota_retry_sleep(
|
||||
&mut this.quota_read_retry_sleep,
|
||||
&mut this.quota_read_wake_scheduled,
|
||||
&mut this.quota_read_retry_attempt,
|
||||
&this.user,
|
||||
cx,
|
||||
);
|
||||
return Poll::Pending;
|
||||
}
|
||||
let mut remaining_before = None;
|
||||
if let Some(limit) = this.quota_limit {
|
||||
let used_before = this.user_stats.quota_used();
|
||||
let remaining = limit.saturating_sub(used_before);
|
||||
if remaining == 0 {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() {
|
||||
match lock.try_lock() {
|
||||
Ok(guard) => Some(guard),
|
||||
Err(_) => {
|
||||
poll_quota_retry_sleep(
|
||||
&mut this.quota_read_retry_sleep,
|
||||
&mut this.quota_read_wake_scheduled,
|
||||
&mut this.quota_read_retry_attempt,
|
||||
&this.user,
|
||||
cx,
|
||||
);
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
reset_quota_retry_scheduler(
|
||||
&mut this.quota_read_retry_sleep,
|
||||
&mut this.quota_read_wake_scheduled,
|
||||
&mut this.quota_read_retry_attempt,
|
||||
);
|
||||
|
||||
if let Some(limit) = this.quota_limit
|
||||
&& this.stats.get_user_total_octets(&this.user) >= limit
|
||||
{
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
remaining_before = Some(remaining);
|
||||
}
|
||||
|
||||
let before = buf.filled().len();
|
||||
|
||||
match Pin::new(&mut this.inner).poll_read(cx, buf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
let n = buf.filled().len() - before;
|
||||
if n > 0 {
|
||||
let mut reached_quota_boundary = false;
|
||||
if let Some(limit) = this.quota_limit {
|
||||
let used = this.stats.get_user_total_octets(&this.user);
|
||||
if used >= limit {
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
|
||||
let remaining = limit - used;
|
||||
if (n as u64) > remaining {
|
||||
// Fail closed: when a single read chunk would cross quota,
|
||||
// stop relay immediately without accounting beyond the cap.
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
|
||||
reached_quota_boundary = (n as u64) == remaining;
|
||||
}
|
||||
let n_to_charge = n as u64;
|
||||
|
||||
// C→S: client sent data
|
||||
this.counters
|
||||
.c2s_bytes
|
||||
.fetch_add(n as u64, Ordering::Relaxed);
|
||||
.fetch_add(n_to_charge, Ordering::Relaxed);
|
||||
this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed);
|
||||
this.counters.touch(Instant::now(), this.epoch);
|
||||
|
||||
this.stats.add_user_octets_from(&this.user, n as u64);
|
||||
this.stats.increment_user_msgs_from(&this.user);
|
||||
this.stats
|
||||
.add_user_octets_from_handle(this.user_stats.as_ref(), n_to_charge);
|
||||
this.stats
|
||||
.increment_user_msgs_from_handle(this.user_stats.as_ref());
|
||||
|
||||
if reached_quota_boundary {
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) {
|
||||
this.stats
|
||||
.quota_charge_post_write(this.user_stats.as_ref(), n_to_charge);
|
||||
if should_immediate_quota_check(remaining, n_to_charge) {
|
||||
this.quota_bytes_since_check = 0;
|
||||
if this.user_stats.quota_used() >= limit {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
}
|
||||
} else {
|
||||
this.quota_bytes_since_check =
|
||||
this.quota_bytes_since_check.saturating_add(n_to_charge);
|
||||
let interval = quota_adaptive_interval_bytes(remaining);
|
||||
if this.quota_bytes_since_check >= interval {
|
||||
this.quota_bytes_since_check = 0;
|
||||
if this.user_stats.quota_used() >= limit {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trace!(user = %this.user, bytes = n, "C->S");
|
||||
|
|
@ -584,89 +362,57 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
|
|||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let this = self.get_mut();
|
||||
if this.quota_exceeded.load(Ordering::Relaxed) {
|
||||
if this.quota_exceeded.load(Ordering::Acquire) {
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
|
||||
let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() {
|
||||
match lock.try_lock() {
|
||||
Ok(guard) => Some(guard),
|
||||
Err(_) => {
|
||||
poll_quota_retry_sleep(
|
||||
&mut this.quota_write_retry_sleep,
|
||||
&mut this.quota_write_wake_scheduled,
|
||||
&mut this.quota_write_retry_attempt,
|
||||
&this.user,
|
||||
cx,
|
||||
);
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() {
|
||||
match lock.try_lock() {
|
||||
Ok(guard) => Some(guard),
|
||||
Err(_) => {
|
||||
poll_quota_retry_sleep(
|
||||
&mut this.quota_write_retry_sleep,
|
||||
&mut this.quota_write_wake_scheduled,
|
||||
&mut this.quota_write_retry_attempt,
|
||||
&this.user,
|
||||
cx,
|
||||
);
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
reset_quota_retry_scheduler(
|
||||
&mut this.quota_write_retry_sleep,
|
||||
&mut this.quota_write_wake_scheduled,
|
||||
&mut this.quota_write_retry_attempt,
|
||||
);
|
||||
|
||||
let write_buf = if let Some(limit) = this.quota_limit {
|
||||
let used = this.stats.get_user_total_octets(&this.user);
|
||||
if used >= limit {
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
let mut remaining_before = None;
|
||||
if let Some(limit) = this.quota_limit {
|
||||
let used_before = this.user_stats.quota_used();
|
||||
let remaining = limit.saturating_sub(used_before);
|
||||
if remaining == 0 {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
remaining_before = Some(remaining);
|
||||
}
|
||||
|
||||
let remaining = (limit - used) as usize;
|
||||
if buf.len() > remaining {
|
||||
// Fail closed: do not emit partial S->C payload when remaining
|
||||
// quota cannot accommodate the pending write request.
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
}
|
||||
buf
|
||||
} else {
|
||||
buf
|
||||
};
|
||||
|
||||
match Pin::new(&mut this.inner).poll_write(cx, write_buf) {
|
||||
match Pin::new(&mut this.inner).poll_write(cx, buf) {
|
||||
Poll::Ready(Ok(n)) => {
|
||||
if n > 0 {
|
||||
let n_to_charge = n as u64;
|
||||
|
||||
// S→C: data written to client
|
||||
this.counters
|
||||
.s2c_bytes
|
||||
.fetch_add(n as u64, Ordering::Relaxed);
|
||||
.fetch_add(n_to_charge, Ordering::Relaxed);
|
||||
this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed);
|
||||
this.counters.touch(Instant::now(), this.epoch);
|
||||
|
||||
this.stats.add_user_octets_to(&this.user, n as u64);
|
||||
this.stats.increment_user_msgs_to(&this.user);
|
||||
this.stats
|
||||
.add_user_octets_to_handle(this.user_stats.as_ref(), n_to_charge);
|
||||
this.stats
|
||||
.increment_user_msgs_to_handle(this.user_stats.as_ref());
|
||||
|
||||
if let Some(limit) = this.quota_limit
|
||||
&& this.stats.get_user_total_octets(&this.user) >= limit
|
||||
{
|
||||
this.quota_exceeded.store(true, Ordering::Relaxed);
|
||||
return Poll::Ready(Err(quota_io_error()));
|
||||
if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) {
|
||||
this.stats
|
||||
.quota_charge_post_write(this.user_stats.as_ref(), n_to_charge);
|
||||
if should_immediate_quota_check(remaining, n_to_charge) {
|
||||
this.quota_bytes_since_check = 0;
|
||||
if this.user_stats.quota_used() >= limit {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
}
|
||||
} else {
|
||||
this.quota_bytes_since_check =
|
||||
this.quota_bytes_since_check.saturating_add(n_to_charge);
|
||||
let interval = quota_adaptive_interval_bytes(remaining);
|
||||
if this.quota_bytes_since_check >= interval {
|
||||
this.quota_bytes_since_check = 0;
|
||||
if this.user_stats.quota_used() >= limit {
|
||||
this.quota_exceeded.store(true, Ordering::Release);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trace!(user = %this.user, bytes = n, "S->C");
|
||||
|
|
@ -760,7 +506,7 @@ where
|
|||
let now = Instant::now();
|
||||
let idle = wd_counters.idle_duration(now, epoch);
|
||||
|
||||
if wd_quota_exceeded.load(Ordering::Relaxed) {
|
||||
if wd_quota_exceeded.load(Ordering::Acquire) {
|
||||
warn!(user = %wd_user, "User data quota reached, closing relay");
|
||||
return;
|
||||
}
|
||||
|
|
@ -898,18 +644,10 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_security_tests.rs"]
|
||||
mod security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_adversarial_tests.rs"]
|
||||
mod adversarial_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_lock_pressure_adversarial_tests.rs"]
|
||||
mod relay_quota_lock_pressure_adversarial_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_boundary_blackhat_tests.rs"]
|
||||
mod relay_quota_boundary_blackhat_tests;
|
||||
|
|
@ -931,69 +669,5 @@ mod relay_quota_extended_attack_surface_security_tests;
|
|||
mod relay_watchdog_delta_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_waker_storm_adversarial_tests.rs"]
|
||||
mod relay_quota_waker_storm_adversarial_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"]
|
||||
mod relay_quota_wake_liveness_regression_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_lock_identity_security_tests.rs"]
|
||||
mod relay_quota_lock_identity_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_cross_mode_quota_lock_security_tests.rs"]
|
||||
mod relay_cross_mode_quota_lock_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_retry_scheduler_tdd_tests.rs"]
|
||||
mod relay_quota_retry_scheduler_tdd_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_cross_mode_quota_fairness_tdd_tests.rs"]
|
||||
mod relay_cross_mode_quota_fairness_tdd_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs"]
|
||||
mod relay_cross_mode_pipeline_hol_integration_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs"]
|
||||
mod relay_cross_mode_pipeline_latency_benchmark_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_retry_backoff_security_tests.rs"]
|
||||
mod relay_quota_retry_backoff_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_retry_backoff_benchmark_security_tests.rs"]
|
||||
mod relay_quota_retry_backoff_benchmark_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_dual_lock_backoff_regression_security_tests.rs"]
|
||||
mod relay_dual_lock_backoff_regression_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_dual_lock_contention_matrix_security_tests.rs"]
|
||||
mod relay_dual_lock_contention_matrix_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_dual_lock_race_harness_security_tests.rs"]
|
||||
mod relay_dual_lock_race_harness_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_dual_lock_alternating_contention_security_tests.rs"]
|
||||
mod relay_dual_lock_alternating_contention_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_retry_allocation_latency_security_tests.rs"]
|
||||
mod relay_quota_retry_allocation_latency_security_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs"]
|
||||
mod relay_quota_lock_eviction_lifecycle_tdd_tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "tests/relay_quota_lock_eviction_stress_security_tests.rs"]
|
||||
mod relay_quota_lock_eviction_stress_security_tests;
|
||||
#[path = "tests/relay_atomic_quota_invariant_tests.rs"]
|
||||
mod relay_atomic_quota_invariant_tests;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use super::*;
|
||||
use crate::config::{UpstreamConfig, UpstreamType, ProxyConfig};
|
||||
use crate::config::{ProxyConfig, UpstreamConfig, UpstreamType};
|
||||
use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE};
|
||||
use crate::stats::Stats;
|
||||
use crate::transport::UpstreamManager;
|
||||
|
|
@ -41,7 +41,9 @@ fn edge_handshake_timeout_with_mask_grace_saturating_add_prevents_overflow() {
|
|||
#[test]
|
||||
fn edge_tls_clienthello_len_in_bounds_exact_boundaries() {
|
||||
assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE));
|
||||
assert!(!tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE - 1));
|
||||
assert!(!tls_clienthello_len_in_bounds(
|
||||
MIN_TLS_CLIENT_HELLO_SIZE - 1
|
||||
));
|
||||
assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE));
|
||||
assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1));
|
||||
}
|
||||
|
|
@ -87,7 +89,15 @@ async fn adversarial_tls_handshake_timeout_during_masking_delay() {
|
|||
"198.51.100.1:55000".parse().unwrap(),
|
||||
config,
|
||||
stats.clone(),
|
||||
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
|
||||
Arc::new(UpstreamManager::new(
|
||||
vec![],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
)),
|
||||
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
|
||||
Arc::new(BufferPool::new()),
|
||||
Arc::new(SecureRandom::new()),
|
||||
|
|
@ -99,7 +109,10 @@ async fn adversarial_tls_handshake_timeout_during_masking_delay() {
|
|||
false,
|
||||
));
|
||||
|
||||
client_side.write_all(&[0x16, 0x03, 0x01, 0xFF, 0xFF]).await.unwrap();
|
||||
client_side
|
||||
.write_all(&[0x16, 0x03, 0x01, 0xFF, 0xFF])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let result = tokio::time::timeout(Duration::from_secs(4), handle)
|
||||
.await
|
||||
|
|
@ -123,7 +136,15 @@ async fn blackhat_proxy_protocol_slowloris_timeout() {
|
|||
"198.51.100.2:55000".parse().unwrap(),
|
||||
config,
|
||||
stats.clone(),
|
||||
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
|
||||
Arc::new(UpstreamManager::new(
|
||||
vec![],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
)),
|
||||
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
|
||||
Arc::new(BufferPool::new()),
|
||||
Arc::new(SecureRandom::new()),
|
||||
|
|
@ -167,7 +188,15 @@ async fn negative_proxy_protocol_enabled_but_client_sends_tls_hello() {
|
|||
"198.51.100.3:55000".parse().unwrap(),
|
||||
config,
|
||||
stats.clone(),
|
||||
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
|
||||
Arc::new(UpstreamManager::new(
|
||||
vec![],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
)),
|
||||
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
|
||||
Arc::new(BufferPool::new()),
|
||||
Arc::new(SecureRandom::new()),
|
||||
|
|
@ -179,7 +208,10 @@ async fn negative_proxy_protocol_enabled_but_client_sends_tls_hello() {
|
|||
true,
|
||||
));
|
||||
|
||||
client_side.write_all(&[0x16, 0x03, 0x01, 0x02, 0x00]).await.unwrap();
|
||||
client_side
|
||||
.write_all(&[0x16, 0x03, 0x01, 0x02, 0x00])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let result = tokio::time::timeout(Duration::from_secs(2), handle)
|
||||
.await
|
||||
|
|
@ -202,7 +234,15 @@ async fn edge_client_stream_exactly_4_bytes_eof() {
|
|||
"198.51.100.4:55000".parse().unwrap(),
|
||||
config,
|
||||
stats.clone(),
|
||||
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
|
||||
Arc::new(UpstreamManager::new(
|
||||
vec![],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
)),
|
||||
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
|
||||
Arc::new(BufferPool::new()),
|
||||
Arc::new(SecureRandom::new()),
|
||||
|
|
@ -214,7 +254,10 @@ async fn edge_client_stream_exactly_4_bytes_eof() {
|
|||
false,
|
||||
));
|
||||
|
||||
client_side.write_all(&[0x16, 0x03, 0x01, 0x00]).await.unwrap();
|
||||
client_side
|
||||
.write_all(&[0x16, 0x03, 0x01, 0x00])
|
||||
.await
|
||||
.unwrap();
|
||||
client_side.shutdown().await.unwrap();
|
||||
|
||||
let _ = tokio::time::timeout(Duration::from_secs(2), handle).await;
|
||||
|
|
@ -234,7 +277,15 @@ async fn edge_client_stream_tls_header_valid_but_body_1_byte_short_eof() {
|
|||
"198.51.100.5:55000".parse().unwrap(),
|
||||
config,
|
||||
stats.clone(),
|
||||
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
|
||||
Arc::new(UpstreamManager::new(
|
||||
vec![],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
)),
|
||||
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
|
||||
Arc::new(BufferPool::new()),
|
||||
Arc::new(SecureRandom::new()),
|
||||
|
|
@ -246,7 +297,10 @@ async fn edge_client_stream_tls_header_valid_but_body_1_byte_short_eof() {
|
|||
false,
|
||||
));
|
||||
|
||||
client_side.write_all(&[0x16, 0x03, 0x01, 0x00, 100]).await.unwrap();
|
||||
client_side
|
||||
.write_all(&[0x16, 0x03, 0x01, 0x00, 100])
|
||||
.await
|
||||
.unwrap();
|
||||
client_side.write_all(&vec![0x41; 99]).await.unwrap();
|
||||
client_side.shutdown().await.unwrap();
|
||||
|
||||
|
|
@ -269,7 +323,15 @@ async fn integration_non_tls_modes_disabled_immediately_masks() {
|
|||
"198.51.100.6:55000".parse().unwrap(),
|
||||
config,
|
||||
stats.clone(),
|
||||
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
|
||||
Arc::new(UpstreamManager::new(
|
||||
vec![],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
)),
|
||||
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
|
||||
Arc::new(BufferPool::new()),
|
||||
Arc::new(SecureRandom::new()),
|
||||
|
|
@ -372,11 +434,7 @@ async fn stress_user_connection_reservation_concurrent_same_ip_exhaustion() {
|
|||
let ip_tracker = ip_tracker.clone();
|
||||
tasks.spawn(async move {
|
||||
RunningClientHandler::acquire_user_connection_reservation_static(
|
||||
user,
|
||||
&config,
|
||||
stats,
|
||||
peer,
|
||||
ip_tracker,
|
||||
user, &config, stats, peer, ip_tracker,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
|
|
|||
|
|
@ -7,6 +7,11 @@ use std::sync::Arc;
|
|||
use std::time::Duration;
|
||||
use tokio::io::{AsyncWriteExt, duplex};
|
||||
|
||||
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
|
||||
let user_stats = stats.get_or_create_user_stats_handle(user);
|
||||
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invariant_wrap_tls_application_record_exact_multiples() {
|
||||
let chunk_size = u16::MAX as usize;
|
||||
|
|
@ -37,7 +42,15 @@ async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking()
|
|||
"198.51.100.20:55000".parse().unwrap(),
|
||||
config,
|
||||
stats.clone(),
|
||||
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
|
||||
Arc::new(UpstreamManager::new(
|
||||
vec![],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
)),
|
||||
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
|
||||
Arc::new(BufferPool::new()),
|
||||
Arc::new(SecureRandom::new()),
|
||||
|
|
@ -60,7 +73,9 @@ async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking()
|
|||
.unwrap();
|
||||
client_side.shutdown().await.unwrap();
|
||||
|
||||
let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap();
|
||||
let _ = tokio::time::timeout(Duration::from_secs(2), handler)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(stats.get_connects_bad(), 1);
|
||||
}
|
||||
|
||||
|
|
@ -68,7 +83,10 @@ async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking()
|
|||
async fn invariant_acquire_reservation_ip_limit_rollback() {
|
||||
let user = "rollback-test-user";
|
||||
let mut config = ProxyConfig::default();
|
||||
config.access.user_max_tcp_conns.insert(user.to_string(), 10);
|
||||
config
|
||||
.access
|
||||
.user_max_tcp_conns
|
||||
.insert(user.to_string(), 10);
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
|
|
@ -114,7 +132,7 @@ async fn invariant_quota_exact_boundary_inclusive() {
|
|||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
let peer = "198.51.100.23:55000".parse().unwrap();
|
||||
|
||||
stats.add_user_octets_from(user, 999);
|
||||
preload_user_quota(stats.as_ref(), user, 999);
|
||||
let res1 = RunningClientHandler::acquire_user_connection_reservation_static(
|
||||
user,
|
||||
&config,
|
||||
|
|
@ -126,7 +144,7 @@ async fn invariant_quota_exact_boundary_inclusive() {
|
|||
assert!(res1.is_ok());
|
||||
res1.unwrap().release().await;
|
||||
|
||||
stats.add_user_octets_from(user, 1);
|
||||
preload_user_quota(stats.as_ref(), user, 1);
|
||||
let res2 = RunningClientHandler::acquire_user_connection_reservation_static(
|
||||
user,
|
||||
&config,
|
||||
|
|
@ -154,7 +172,15 @@ async fn invariant_direct_mode_partial_header_eof_is_error_not_bad_connect() {
|
|||
"198.51.100.25:55000".parse().unwrap(),
|
||||
config,
|
||||
stats.clone(),
|
||||
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
|
||||
Arc::new(UpstreamManager::new(
|
||||
vec![],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
)),
|
||||
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
|
||||
Arc::new(BufferPool::new()),
|
||||
Arc::new(SecureRandom::new()),
|
||||
|
|
|
|||
|
|
@ -100,14 +100,7 @@ async fn run_http2_fragment_case(split_at: usize, delay_ms: u64, peer: SocketAdd
|
|||
|
||||
#[tokio::test]
|
||||
async fn http2_preface_fragmentation_matrix_is_classified_and_forwarded() {
|
||||
let cases = [
|
||||
(2usize, 0u64),
|
||||
(3, 0),
|
||||
(4, 0),
|
||||
(2, 7),
|
||||
(3, 7),
|
||||
(8, 1),
|
||||
];
|
||||
let cases = [(2usize, 0u64), (3, 0), (4, 0), (2, 7), (3, 7), (8, 1)];
|
||||
|
||||
for (i, (split_at, delay_ms)) in cases.into_iter().enumerate() {
|
||||
let peer: SocketAddr = format!("198.51.100.{}:58{}", 140 + i, 100 + i)
|
||||
|
|
|
|||
|
|
@ -29,7 +29,10 @@ async fn configured_prefetch_budget_20ms_recovers_tail_delayed_15ms() {
|
|||
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
|
||||
.await
|
||||
.expect("tail bytes must be writable");
|
||||
writer.shutdown().await.expect("writer shutdown must succeed");
|
||||
writer
|
||||
.shutdown()
|
||||
.await
|
||||
.expect("writer shutdown must succeed");
|
||||
});
|
||||
|
||||
let mut initial_data = b"C".to_vec();
|
||||
|
|
@ -60,7 +63,10 @@ async fn configured_prefetch_budget_5ms_misses_tail_delayed_15ms() {
|
|||
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
|
||||
.await
|
||||
.expect("tail bytes must be writable");
|
||||
writer.shutdown().await.expect("writer shutdown must succeed");
|
||||
writer
|
||||
.shutdown()
|
||||
.await
|
||||
.expect("writer shutdown must succeed");
|
||||
});
|
||||
|
||||
let mut initial_data = b"C".to_vec();
|
||||
|
|
|
|||
|
|
@ -245,7 +245,10 @@ async fn blackhat_integration_empty_initial_data_path_is_byte_exact_and_eof_clea
|
|||
assert_eq!(head[0], 0x16);
|
||||
read_and_discard_tls_record_body(&mut client_side, head).await;
|
||||
|
||||
client_side.write_all(&invalid_mtproto_record).await.unwrap();
|
||||
client_side
|
||||
.write_all(&invalid_mtproto_record)
|
||||
.await
|
||||
.unwrap();
|
||||
client_side.write_all(&trailing_record).await.unwrap();
|
||||
client_side.shutdown().await.unwrap();
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,9 @@ async fn run_strict_prefetch_case(prefetch_ms: u64, tail_delay_ms: u64) -> Vec<u
|
|||
|
||||
let writer_task = tokio::spawn(async move {
|
||||
sleep(Duration::from_millis(tail_delay_ms)).await;
|
||||
let _ = writer.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n").await;
|
||||
let _ = writer
|
||||
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
|
||||
.await;
|
||||
let _ = writer.shutdown().await;
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -35,7 +35,10 @@ async fn run_prefetch_budget_case(prefetch_budget_ms: u64, delayed_tail_ms: u64)
|
|||
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
|
||||
.await
|
||||
.expect("tail bytes must be writable");
|
||||
writer.shutdown().await.expect("writer shutdown must succeed");
|
||||
writer
|
||||
.shutdown()
|
||||
.await
|
||||
.expect("writer shutdown must succeed");
|
||||
});
|
||||
|
||||
let mut initial_data = b"C".to_vec();
|
||||
|
|
|
|||
|
|
@ -67,9 +67,10 @@ async fn run_replay_candidate_session(
|
|||
cfg.censorship.mask_port = 1;
|
||||
cfg.censorship.mask_timing_normalization_enabled = false;
|
||||
cfg.access.ignore_time_skew = true;
|
||||
cfg.access
|
||||
.users
|
||||
.insert("user".to_string(), "abababababababababababababababab".to_string());
|
||||
cfg.access.users.insert(
|
||||
"user".to_string(),
|
||||
"abababababababababababababababab".to_string(),
|
||||
);
|
||||
|
||||
let config = Arc::new(cfg);
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
|
@ -99,7 +100,10 @@ async fn run_replay_candidate_session(
|
|||
|
||||
if drive_mtproto_fail {
|
||||
let mut server_hello_head = [0u8; 5];
|
||||
client_side.read_exact(&mut server_hello_head).await.unwrap();
|
||||
client_side
|
||||
.read_exact(&mut server_hello_head)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(server_hello_head[0], 0x16);
|
||||
let body_len = u16::from_be_bytes([server_hello_head[3], server_hello_head[4]]) as usize;
|
||||
let mut body = vec![0u8; body_len];
|
||||
|
|
@ -110,7 +114,10 @@ async fn run_replay_candidate_session(
|
|||
invalid_mtproto_record.extend_from_slice(&TLS_VERSION);
|
||||
invalid_mtproto_record.extend_from_slice(&(HANDSHAKE_LEN as u16).to_be_bytes());
|
||||
invalid_mtproto_record.extend_from_slice(&vec![0u8; HANDSHAKE_LEN]);
|
||||
client_side.write_all(&invalid_mtproto_record).await.unwrap();
|
||||
client_side
|
||||
.write_all(&invalid_mtproto_record)
|
||||
.await
|
||||
.unwrap();
|
||||
client_side
|
||||
.write_all(b"GET /replay-fallback HTTP/1.1\r\nHost: x\r\n\r\n")
|
||||
.await
|
||||
|
|
@ -154,8 +161,7 @@ async fn replay_reject_still_honors_masking_timing_budget() {
|
|||
.await;
|
||||
|
||||
assert!(
|
||||
replay_elapsed >= Duration::from_millis(40)
|
||||
&& replay_elapsed < Duration::from_millis(250),
|
||||
replay_elapsed >= Duration::from_millis(40) && replay_elapsed < Duration::from_millis(250),
|
||||
"replay rejection path must still satisfy masking timing budget without unbounded DB/CPU delay"
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,11 @@ use std::sync::Arc;
|
|||
use std::time::Duration;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
|
||||
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
|
||||
let user_stats = stats.get_or_create_user_stats_handle(user);
|
||||
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_mask_delay_bypassed_if_max_is_zero() {
|
||||
let mut config = ProxyConfig::default();
|
||||
|
|
@ -42,17 +47,13 @@ async fn boundary_user_data_quota_exact_match_rejects() {
|
|||
config.access.user_data_quota.insert(user.to_string(), 1024);
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
stats.add_user_octets_from(user, 1024);
|
||||
preload_user_quota(stats.as_ref(), user, 1024);
|
||||
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
let peer = "198.51.100.10:55000".parse().unwrap();
|
||||
|
||||
let result = RunningClientHandler::acquire_user_connection_reservation_static(
|
||||
user,
|
||||
&config,
|
||||
stats,
|
||||
peer,
|
||||
ip_tracker,
|
||||
user, &config, stats, peer, ip_tracker,
|
||||
)
|
||||
.await;
|
||||
|
||||
|
|
@ -74,11 +75,7 @@ async fn boundary_user_expiration_in_past_rejects() {
|
|||
let peer = "198.51.100.11:55000".parse().unwrap();
|
||||
|
||||
let result = RunningClientHandler::acquire_user_connection_reservation_static(
|
||||
user,
|
||||
&config,
|
||||
stats,
|
||||
peer,
|
||||
ip_tracker,
|
||||
user, &config, stats, peer, ip_tracker,
|
||||
)
|
||||
.await;
|
||||
|
||||
|
|
@ -98,7 +95,15 @@ async fn blackhat_proxy_protocol_massive_garbage_rejected_quickly() {
|
|||
"198.51.100.12:55000".parse().unwrap(),
|
||||
config,
|
||||
stats.clone(),
|
||||
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
|
||||
Arc::new(UpstreamManager::new(
|
||||
vec![],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
)),
|
||||
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
|
||||
Arc::new(BufferPool::new()),
|
||||
Arc::new(SecureRandom::new()),
|
||||
|
|
@ -136,7 +141,15 @@ async fn edge_tls_body_immediate_eof_triggers_masking_and_bad_connect() {
|
|||
"198.51.100.13:55000".parse().unwrap(),
|
||||
config,
|
||||
stats.clone(),
|
||||
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
|
||||
Arc::new(UpstreamManager::new(
|
||||
vec![],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
)),
|
||||
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
|
||||
Arc::new(BufferPool::new()),
|
||||
Arc::new(SecureRandom::new()),
|
||||
|
|
@ -148,10 +161,15 @@ async fn edge_tls_body_immediate_eof_triggers_masking_and_bad_connect() {
|
|||
false,
|
||||
));
|
||||
|
||||
client_side.write_all(&[0x16, 0x03, 0x01, 0x00, 100]).await.unwrap();
|
||||
client_side
|
||||
.write_all(&[0x16, 0x03, 0x01, 0x00, 100])
|
||||
.await
|
||||
.unwrap();
|
||||
client_side.shutdown().await.unwrap();
|
||||
|
||||
let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap();
|
||||
let _ = tokio::time::timeout(Duration::from_secs(2), handler)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(stats.get_connects_bad(), 1);
|
||||
}
|
||||
|
|
@ -172,7 +190,15 @@ async fn security_classic_mode_disabled_masks_valid_length_payload() {
|
|||
"198.51.100.15:55000".parse().unwrap(),
|
||||
config,
|
||||
stats.clone(),
|
||||
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
|
||||
Arc::new(UpstreamManager::new(
|
||||
vec![],
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
false,
|
||||
stats.clone(),
|
||||
)),
|
||||
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
|
||||
Arc::new(BufferPool::new()),
|
||||
Arc::new(SecureRandom::new()),
|
||||
|
|
@ -187,7 +213,9 @@ async fn security_classic_mode_disabled_masks_valid_length_payload() {
|
|||
client_side.write_all(&vec![0xEF; 64]).await.unwrap();
|
||||
client_side.shutdown().await.unwrap();
|
||||
|
||||
let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap();
|
||||
let _ = tokio::time::timeout(Duration::from_secs(2), handler)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(stats.get_connects_bad(), 1);
|
||||
}
|
||||
|
||||
|
|
@ -195,7 +223,10 @@ async fn security_classic_mode_disabled_masks_valid_length_payload() {
|
|||
async fn concurrency_ip_tracker_strict_limit_one_rapid_churn() {
|
||||
let user = "rapid-churn-user";
|
||||
let mut config = ProxyConfig::default();
|
||||
config.access.user_max_tcp_conns.insert(user.to_string(), 10);
|
||||
config
|
||||
.access
|
||||
.user_max_tcp_conns
|
||||
.insert(user.to_string(), 10);
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let ip_tracker = Arc::new(UserIpTracker::new());
|
||||
|
|
|
|||
|
|
@ -7,9 +7,9 @@ use crate::protocol::tls;
|
|||
use crate::proxy::handshake::HandshakeSuccess;
|
||||
use crate::stream::{CryptoReader, CryptoWriter};
|
||||
use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::Rng;
|
||||
use rand::SeedableRng;
|
||||
use rand::rngs::StdRng;
|
||||
use std::net::Ipv4Addr;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
|
@ -34,7 +34,10 @@ fn handshake_timeout_with_mask_grace_includes_mask_margin() {
|
|||
config.timeouts.client_handshake = 2;
|
||||
|
||||
config.censorship.mask = false;
|
||||
assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_secs(2));
|
||||
assert_eq!(
|
||||
handshake_timeout_with_mask_grace(&config),
|
||||
Duration::from_secs(2)
|
||||
);
|
||||
|
||||
config.censorship.mask = true;
|
||||
assert_eq!(
|
||||
|
|
@ -86,7 +89,10 @@ impl tokio::io::AsyncRead for ErrorReader {
|
|||
_cx: &mut std::task::Context<'_>,
|
||||
_buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
std::task::Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "fake error")))
|
||||
std::task::Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"fake error",
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -124,7 +130,10 @@ fn handshake_timeout_without_mask_is_exact_base() {
|
|||
config.timeouts.client_handshake = 7;
|
||||
config.censorship.mask = false;
|
||||
|
||||
assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_secs(7));
|
||||
assert_eq!(
|
||||
handshake_timeout_with_mask_grace(&config),
|
||||
Duration::from_secs(7)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -133,7 +142,10 @@ fn handshake_timeout_mask_enabled_adds_750ms() {
|
|||
config.timeouts.client_handshake = 3;
|
||||
config.censorship.mask = true;
|
||||
|
||||
assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_millis(3750));
|
||||
assert_eq!(
|
||||
handshake_timeout_with_mask_grace(&config),
|
||||
Duration::from_millis(3750)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -155,10 +167,12 @@ async fn read_with_progress_fragmented_io_works_over_multiple_calls() {
|
|||
let mut b = vec![0u8; chunk_size];
|
||||
let n = read_with_progress(&mut cursor, &mut b).await.unwrap();
|
||||
result.extend_from_slice(&b[..n]);
|
||||
if n == 0 { break; }
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(result, vec![1,2,3,4,5]);
|
||||
assert_eq!(result, vec![1, 2, 3, 4, 5]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -174,7 +188,9 @@ async fn read_with_progress_stress_randomized_chunk_sizes() {
|
|||
let mut b = vec![0u8; chunk];
|
||||
let read = read_with_progress(&mut cursor, &mut b).await.unwrap();
|
||||
collected.extend_from_slice(&b[..read]);
|
||||
if read == 0 { break; }
|
||||
if read == 0 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(collected, input);
|
||||
|
|
@ -215,10 +231,12 @@ fn wrap_tls_application_record_roundtrip_size_check() {
|
|||
let mut consumed = 0;
|
||||
while idx + 5 <= wrapped.len() {
|
||||
assert_eq!(wrapped[idx], 0x17);
|
||||
let len = u16::from_be_bytes([wrapped[idx+3], wrapped[idx+4]]) as usize;
|
||||
let len = u16::from_be_bytes([wrapped[idx + 3], wrapped[idx + 4]]) as usize;
|
||||
consumed += len;
|
||||
idx += 5 + len;
|
||||
if idx >= wrapped.len() { break; }
|
||||
if idx >= wrapped.len() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(consumed, payload_len);
|
||||
|
|
@ -242,6 +260,11 @@ where
|
|||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||
}
|
||||
|
||||
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
|
||||
let user_stats = stats.get_or_create_user_stats_handle(user);
|
||||
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() {
|
||||
let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new());
|
||||
|
|
@ -3040,7 +3063,7 @@ async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() {
|
|||
.insert("user".to_string(), 1024);
|
||||
|
||||
let stats = Stats::new();
|
||||
stats.add_user_octets_from("user", 1024);
|
||||
preload_user_quota(&stats, "user", 1024);
|
||||
|
||||
let ip_tracker = UserIpTracker::new();
|
||||
let peer_addr: SocketAddr = "203.0.113.211:50001".parse().unwrap();
|
||||
|
|
|
|||
|
|
@ -25,13 +25,26 @@ fn wrap_tls_application_record_oversized_payload_is_chunked_without_truncation()
|
|||
let len = u16::from_be_bytes([record[offset + 3], record[offset + 4]]) as usize;
|
||||
let body_start = offset + 5;
|
||||
let body_end = body_start + len;
|
||||
assert!(body_end <= record.len(), "declared TLS record length must be in-bounds");
|
||||
assert!(
|
||||
body_end <= record.len(),
|
||||
"declared TLS record length must be in-bounds"
|
||||
);
|
||||
recovered.extend_from_slice(&record[body_start..body_end]);
|
||||
offset = body_end;
|
||||
frames += 1;
|
||||
}
|
||||
|
||||
assert_eq!(offset, record.len(), "record parser must consume exact output size");
|
||||
assert_eq!(frames, 2, "oversized payload should split into exactly two records");
|
||||
assert_eq!(recovered, payload, "chunked records must preserve full payload");
|
||||
assert_eq!(
|
||||
offset,
|
||||
record.len(),
|
||||
"record parser must consume exact output size"
|
||||
);
|
||||
assert_eq!(
|
||||
frames, 2,
|
||||
"oversized payload should split into exactly two records"
|
||||
);
|
||||
assert_eq!(
|
||||
recovered, payload,
|
||||
"chunked records must preserve full payload"
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -773,8 +773,7 @@ fn anchored_open_nix_path_writes_expected_lines() {
|
|||
"target/telemt-unknown-dc-anchored-open-ok-{}/unknown-dc.log",
|
||||
std::process::id()
|
||||
);
|
||||
let sanitized =
|
||||
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
|
||||
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
|
||||
let _ = fs::remove_file(&sanitized.resolved_path);
|
||||
|
||||
let mut first = open_unknown_dc_log_append_anchored(&sanitized)
|
||||
|
|
@ -787,7 +786,10 @@ fn anchored_open_nix_path_writes_expected_lines() {
|
|||
|
||||
let content =
|
||||
fs::read_to_string(&sanitized.resolved_path).expect("anchored log file must be readable");
|
||||
let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect();
|
||||
let lines: Vec<&str> = content
|
||||
.lines()
|
||||
.filter(|line| !line.trim().is_empty())
|
||||
.collect();
|
||||
assert_eq!(lines.len(), 2, "expected one line per anchored append call");
|
||||
assert!(
|
||||
lines.contains(&"dc_idx=31200") && lines.contains(&"dc_idx=31201"),
|
||||
|
|
@ -811,8 +813,7 @@ fn anchored_open_parallel_appends_preserve_line_integrity() {
|
|||
"target/telemt-unknown-dc-anchored-open-parallel-{}/unknown-dc.log",
|
||||
std::process::id()
|
||||
);
|
||||
let sanitized =
|
||||
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
|
||||
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
|
||||
let _ = fs::remove_file(&sanitized.resolved_path);
|
||||
|
||||
let mut workers = Vec::new();
|
||||
|
|
@ -831,8 +832,15 @@ fn anchored_open_parallel_appends_preserve_line_integrity() {
|
|||
|
||||
let content =
|
||||
fs::read_to_string(&sanitized.resolved_path).expect("parallel log file must be readable");
|
||||
let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect();
|
||||
assert_eq!(lines.len(), 64, "expected one complete line per worker append");
|
||||
let lines: Vec<&str> = content
|
||||
.lines()
|
||||
.filter(|line| !line.trim().is_empty())
|
||||
.collect();
|
||||
assert_eq!(
|
||||
lines.len(),
|
||||
64,
|
||||
"expected one complete line per worker append"
|
||||
);
|
||||
for line in lines {
|
||||
assert!(
|
||||
line.starts_with("dc_idx="),
|
||||
|
|
@ -867,8 +875,7 @@ fn anchored_open_creates_private_0600_file_permissions() {
|
|||
"target/telemt-unknown-dc-anchored-perms-{}/unknown-dc.log",
|
||||
std::process::id()
|
||||
);
|
||||
let sanitized =
|
||||
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
|
||||
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
|
||||
let _ = fs::remove_file(&sanitized.resolved_path);
|
||||
|
||||
let mut file = open_unknown_dc_log_append_anchored(&sanitized)
|
||||
|
|
@ -905,8 +912,7 @@ fn anchored_open_rejects_existing_symlink_target() {
|
|||
"target/telemt-unknown-dc-anchored-symlink-target-{}/unknown-dc.log",
|
||||
std::process::id()
|
||||
);
|
||||
let sanitized =
|
||||
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
|
||||
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
|
||||
|
||||
let outside = std::env::temp_dir().join(format!(
|
||||
"telemt-unknown-dc-anchored-symlink-outside-{}.log",
|
||||
|
|
@ -943,8 +949,7 @@ fn anchored_open_high_contention_multi_write_preserves_complete_lines() {
|
|||
"target/telemt-unknown-dc-anchored-contention-{}/unknown-dc.log",
|
||||
std::process::id()
|
||||
);
|
||||
let sanitized =
|
||||
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
|
||||
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
|
||||
let _ = fs::remove_file(&sanitized.resolved_path);
|
||||
|
||||
let workers = 24usize;
|
||||
|
|
@ -970,7 +975,10 @@ fn anchored_open_high_contention_multi_write_preserves_complete_lines() {
|
|||
|
||||
let content = fs::read_to_string(&sanitized.resolved_path)
|
||||
.expect("contention output file must be readable");
|
||||
let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect();
|
||||
let lines: Vec<&str> = content
|
||||
.lines()
|
||||
.filter(|line| !line.trim().is_empty())
|
||||
.collect();
|
||||
assert_eq!(
|
||||
lines.len(),
|
||||
workers * rounds,
|
||||
|
|
@ -1014,8 +1022,7 @@ fn append_unknown_dc_line_returns_error_for_read_only_descriptor() {
|
|||
"target/telemt-unknown-dc-append-ro-{}/unknown-dc.log",
|
||||
std::process::id()
|
||||
);
|
||||
let sanitized =
|
||||
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
|
||||
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
|
||||
fs::write(&sanitized.resolved_path, "seed\n").expect("seed file must be writable");
|
||||
|
||||
let mut readonly = std::fs::OpenOptions::new()
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use super::*;
|
||||
use crate::crypto::{sha256, sha256_hmac, AesCtr};
|
||||
use crate::crypto::{AesCtr, sha256, sha256_hmac};
|
||||
use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES};
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
|
|
@ -175,7 +175,10 @@ async fn tls_minimum_viable_length_boundary() {
|
|||
None,
|
||||
)
|
||||
.await;
|
||||
assert!(matches!(res, HandshakeResult::Success(_)), "Exact minimum length TLS handshake must succeed");
|
||||
assert!(
|
||||
matches!(res, HandshakeResult::Success(_)),
|
||||
"Exact minimum length TLS handshake must succeed"
|
||||
);
|
||||
|
||||
let short_handshake = vec![0x42u8; min_len - 1];
|
||||
let res_short = handle_tls_handshake(
|
||||
|
|
@ -189,7 +192,10 @@ async fn tls_minimum_viable_length_boundary() {
|
|||
None,
|
||||
)
|
||||
.await;
|
||||
assert!(matches!(res_short, HandshakeResult::BadClient { .. }), "Handshake 1 byte shorter than minimum must fail closed");
|
||||
assert!(
|
||||
matches!(res_short, HandshakeResult::BadClient { .. }),
|
||||
"Handshake 1 byte shorter than minimum must fail closed"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -219,9 +225,16 @@ async fn mtproto_extreme_dc_index_serialization() {
|
|||
|
||||
match res {
|
||||
HandshakeResult::Success((_, _, success)) => {
|
||||
assert_eq!(success.dc_idx, extreme_dc, "Extreme DC index {} must serialize/deserialize perfectly", extreme_dc);
|
||||
assert_eq!(
|
||||
success.dc_idx, extreme_dc,
|
||||
"Extreme DC index {} must serialize/deserialize perfectly",
|
||||
extreme_dc
|
||||
);
|
||||
}
|
||||
_ => panic!("MTProto handshake with extreme DC index {} failed", extreme_dc),
|
||||
_ => panic!(
|
||||
"MTProto handshake with extreme DC index {} failed",
|
||||
extreme_dc
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -253,7 +266,11 @@ async fn alpn_strict_case_and_padding_rejection() {
|
|||
None,
|
||||
)
|
||||
.await;
|
||||
assert!(matches!(res, HandshakeResult::BadClient { .. }), "ALPN strict enforcement must reject {:?}", bad_alpn);
|
||||
assert!(
|
||||
matches!(res, HandshakeResult::BadClient { .. }),
|
||||
"ALPN strict enforcement must reject {:?}",
|
||||
bad_alpn
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -265,8 +282,15 @@ fn ipv4_mapped_ipv6_bucketing_anomaly() {
|
|||
let norm_1 = normalize_auth_probe_ip(ipv4_mapped_1);
|
||||
let norm_2 = normalize_auth_probe_ip(ipv4_mapped_2);
|
||||
|
||||
assert_eq!(norm_1, norm_2, "IPv4-mapped IPv6 addresses must collapse into the same /64 bucket (::0)");
|
||||
assert_eq!(norm_1, IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), "The bucket must be exactly ::0");
|
||||
assert_eq!(
|
||||
norm_1, norm_2,
|
||||
"IPv4-mapped IPv6 addresses must collapse into the same /64 bucket (::0)"
|
||||
);
|
||||
assert_eq!(
|
||||
norm_1,
|
||||
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)),
|
||||
"The bucket must be exactly ::0"
|
||||
);
|
||||
}
|
||||
|
||||
// --- Category 2: Adversarial & Black Hat ---
|
||||
|
|
@ -309,7 +333,10 @@ async fn mtproto_invalid_ciphertext_does_not_poison_replay_cache() {
|
|||
None,
|
||||
)
|
||||
.await;
|
||||
assert!(matches!(res_valid, HandshakeResult::Success(_)), "Invalid MTProto ciphertext must not poison the replay cache");
|
||||
assert!(
|
||||
matches!(res_valid, HandshakeResult::Success(_)),
|
||||
"Invalid MTProto ciphertext must not poison the replay cache"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -352,7 +379,10 @@ async fn tls_invalid_session_does_not_poison_replay_cache() {
|
|||
None,
|
||||
)
|
||||
.await;
|
||||
assert!(matches!(res_valid, HandshakeResult::Success(_)), "Invalid TLS payload must not poison the replay cache");
|
||||
assert!(
|
||||
matches!(res_valid, HandshakeResult::Success(_)),
|
||||
"Invalid TLS payload must not poison the replay cache"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -387,7 +417,10 @@ async fn server_hello_delay_timing_neutrality_on_hmac_failure() {
|
|||
let elapsed = start.elapsed();
|
||||
|
||||
assert!(matches!(res, HandshakeResult::BadClient { .. }));
|
||||
assert!(elapsed >= Duration::from_millis(45), "Invalid HMAC must still incur the configured ServerHello delay to prevent timing side-channels");
|
||||
assert!(
|
||||
elapsed >= Duration::from_millis(45),
|
||||
"Invalid HMAC must still incur the configured ServerHello delay to prevent timing side-channels"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -421,7 +454,10 @@ async fn server_hello_delay_inversion_resilience() {
|
|||
let elapsed = start.elapsed();
|
||||
|
||||
assert!(matches!(res, HandshakeResult::Success(_)));
|
||||
assert!(elapsed >= Duration::from_millis(90), "Delay logic must gracefully handle min > max inversions via max.max(min)");
|
||||
assert!(
|
||||
elapsed >= Duration::from_millis(90),
|
||||
"Delay logic must gracefully handle min > max inversions via max.max(min)"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -436,10 +472,16 @@ async fn mixed_valid_and_invalid_user_secrets_configuration() {
|
|||
|
||||
for i in 0..9 {
|
||||
let bad_secret = if i % 2 == 0 { "badhex!" } else { "1122" };
|
||||
config.access.users.insert(format!("bad_user_{}", i), bad_secret.to_string());
|
||||
config
|
||||
.access
|
||||
.users
|
||||
.insert(format!("bad_user_{}", i), bad_secret.to_string());
|
||||
}
|
||||
let valid_secret_hex = "99999999999999999999999999999999";
|
||||
config.access.users.insert("good_user".to_string(), valid_secret_hex.to_string());
|
||||
config
|
||||
.access
|
||||
.users
|
||||
.insert("good_user".to_string(), valid_secret_hex.to_string());
|
||||
config.general.modes.secure = true;
|
||||
config.general.modes.classic = true;
|
||||
config.general.modes.tls = true;
|
||||
|
|
@ -463,7 +505,10 @@ async fn mixed_valid_and_invalid_user_secrets_configuration() {
|
|||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(res, HandshakeResult::Success(_)), "Proxy must gracefully skip invalid secrets and authenticate the valid one");
|
||||
assert!(
|
||||
matches!(res, HandshakeResult::Success(_)),
|
||||
"Proxy must gracefully skip invalid secrets and authenticate the valid one"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -494,7 +539,10 @@ async fn tls_emulation_fallback_when_cache_missing() {
|
|||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(res, HandshakeResult::Success(_)), "TLS emulation must gracefully fall back to standard ServerHello if cache is missing");
|
||||
assert!(
|
||||
matches!(res, HandshakeResult::Success(_)),
|
||||
"TLS emulation must gracefully fall back to standard ServerHello if cache is missing"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -524,7 +572,10 @@ async fn classic_mode_over_tls_transport_protocol_confusion() {
|
|||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(res, HandshakeResult::Success(_)), "Intermediate tag over TLS must succeed if classic mode is enabled, locking in cross-transport behavior");
|
||||
assert!(
|
||||
matches!(res, HandshakeResult::Success(_)),
|
||||
"Intermediate tag over TLS must succeed if classic mode is enabled, locking in cross-transport behavior"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -543,9 +594,15 @@ fn generate_tg_nonce_never_emits_reserved_bytes() {
|
|||
false,
|
||||
);
|
||||
|
||||
assert!(!RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]), "Nonce must never start with reserved bytes");
|
||||
assert!(
|
||||
!RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]),
|
||||
"Nonce must never start with reserved bytes"
|
||||
);
|
||||
let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]];
|
||||
assert!(!RESERVED_NONCE_BEGINNINGS.contains(&first_four), "Nonce must never match reserved 4-byte beginnings");
|
||||
assert!(
|
||||
!RESERVED_NONCE_BEGINNINGS.contains(&first_four),
|
||||
"Nonce must never match reserved 4-byte beginnings"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -568,11 +625,18 @@ async fn dashmap_concurrent_saturation_stress() {
|
|||
}
|
||||
|
||||
for task in tasks {
|
||||
task.await.expect("Task panicked during concurrent DashMap stress");
|
||||
task.await
|
||||
.expect("Task panicked during concurrent DashMap stress");
|
||||
}
|
||||
|
||||
assert!(auth_probe_is_throttled_for_testing(ip_a), "IP A must be throttled after concurrent stress");
|
||||
assert!(auth_probe_is_throttled_for_testing(ip_b), "IP B must be throttled after concurrent stress");
|
||||
assert!(
|
||||
auth_probe_is_throttled_for_testing(ip_a),
|
||||
"IP A must be throttled after concurrent stress"
|
||||
);
|
||||
assert!(
|
||||
auth_probe_is_throttled_for_testing(ip_b),
|
||||
"IP B must be throttled after concurrent stress"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -586,7 +650,12 @@ fn prototag_invalid_bytes_fail_closed() {
|
|||
];
|
||||
|
||||
for tag in invalid_tags {
|
||||
assert_eq!(ProtoTag::from_bytes(tag), None, "Invalid ProtoTag bytes {:?} must fail closed", tag);
|
||||
assert_eq!(
|
||||
ProtoTag::from_bytes(tag),
|
||||
None,
|
||||
"Invalid ProtoTag bytes {:?} must fail closed",
|
||||
tag
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -603,7 +672,10 @@ fn auth_probe_eviction_hash_collision_stress() {
|
|||
auth_probe_record_failure_with_state(state, ip, now);
|
||||
}
|
||||
|
||||
assert!(state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, "Eviction logic must successfully bound the map size under heavy insertion stress");
|
||||
assert!(
|
||||
state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES,
|
||||
"Eviction logic must successfully bound the map size under heavy insertion stress"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -88,6 +88,9 @@ fn light_fuzz_offset_always_stays_inside_state_len() {
|
|||
let now = base + Duration::from_nanos(seed & 0x0fff);
|
||||
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
|
||||
|
||||
assert!(start < state_len, "scan offset must stay inside state length");
|
||||
assert!(
|
||||
start < state_len,
|
||||
"scan offset must stay inside state length"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
use super::*;
|
||||
use crate::crypto::{sha256, sha256_hmac, AesCtr};
|
||||
use crate::crypto::{AesCtr, sha256, sha256_hmac};
|
||||
use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES};
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use std::collections::HashSet;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
|
|
@ -223,7 +223,10 @@ fn auth_probe_backoff_extreme_fail_streak_clamps_safely() {
|
|||
assert_eq!(updated.fail_streak, u32::MAX);
|
||||
|
||||
let expected_blocked_until = now + Duration::from_millis(AUTH_PROBE_BACKOFF_MAX_MS);
|
||||
assert_eq!(updated.blocked_until, expected_blocked_until, "Extreme fail streak must clamp cleanly to AUTH_PROBE_BACKOFF_MAX_MS");
|
||||
assert_eq!(
|
||||
updated.blocked_until, expected_blocked_until,
|
||||
"Extreme fail streak must clamp cleanly to AUTH_PROBE_BACKOFF_MAX_MS"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -250,12 +253,19 @@ fn generate_tg_nonce_cryptographic_uniqueness_and_entropy() {
|
|||
total_set_bits += byte.count_ones() as usize;
|
||||
}
|
||||
|
||||
assert!(nonces.insert(nonce), "generate_tg_nonce emitted a duplicate nonce! RNG is stuck.");
|
||||
assert!(
|
||||
nonces.insert(nonce),
|
||||
"generate_tg_nonce emitted a duplicate nonce! RNG is stuck."
|
||||
);
|
||||
}
|
||||
|
||||
let total_bits = iterations * HANDSHAKE_LEN * 8;
|
||||
let ratio = (total_set_bits as f64) / (total_bits as f64);
|
||||
assert!(ratio > 0.48 && ratio < 0.52, "Nonce entropy is degraded. Set bit ratio: {}", ratio);
|
||||
assert!(
|
||||
ratio > 0.48 && ratio < 0.52,
|
||||
"Nonce entropy is degraded. Set bit ratio: {}",
|
||||
ratio
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -267,10 +277,19 @@ async fn mtproto_multi_user_decryption_isolation() {
|
|||
config.general.modes.secure = true;
|
||||
config.access.ignore_time_skew = true;
|
||||
|
||||
config.access.users.insert("user_a".to_string(), "11111111111111111111111111111111".to_string());
|
||||
config.access.users.insert("user_b".to_string(), "22222222222222222222222222222222".to_string());
|
||||
config.access.users.insert(
|
||||
"user_a".to_string(),
|
||||
"11111111111111111111111111111111".to_string(),
|
||||
);
|
||||
config.access.users.insert(
|
||||
"user_b".to_string(),
|
||||
"22222222222222222222222222222222".to_string(),
|
||||
);
|
||||
let good_secret_hex = "33333333333333333333333333333333";
|
||||
config.access.users.insert("user_c".to_string(), good_secret_hex.to_string());
|
||||
config
|
||||
.access
|
||||
.users
|
||||
.insert("user_c".to_string(), good_secret_hex.to_string());
|
||||
|
||||
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
|
||||
let peer: SocketAddr = "192.0.2.104:12345".parse().unwrap();
|
||||
|
|
@ -291,9 +310,14 @@ async fn mtproto_multi_user_decryption_isolation() {
|
|||
|
||||
match res {
|
||||
HandshakeResult::Success((_, _, success)) => {
|
||||
assert_eq!(success.user, "user_c", "Decryption attempts on previous users must not corrupt the handshake buffer for the valid user");
|
||||
assert_eq!(
|
||||
success.user, "user_c",
|
||||
"Decryption attempts on previous users must not corrupt the handshake buffer for the valid user"
|
||||
);
|
||||
}
|
||||
_ => panic!("Multi-user MTProto handshake failed. Decryption buffer might be mutating in place."),
|
||||
_ => panic!(
|
||||
"Multi-user MTProto handshake failed. Decryption buffer might be mutating in place."
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -325,7 +349,9 @@ async fn invalid_secret_warning_lock_contention_and_bound() {
|
|||
}
|
||||
|
||||
let warned = INVALID_SECRET_WARNED.get().unwrap();
|
||||
let guard = warned.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
let guard = warned
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner());
|
||||
|
||||
assert_eq!(
|
||||
guard.len(),
|
||||
|
|
@ -342,7 +368,11 @@ async fn mtproto_strict_concurrent_replay_race_condition() {
|
|||
let secret_hex = "4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A";
|
||||
let config = Arc::new(test_config_with_secret_hex(secret_hex));
|
||||
let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60)));
|
||||
let valid_handshake = Arc::new(make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1));
|
||||
let valid_handshake = Arc::new(make_valid_mtproto_handshake(
|
||||
secret_hex,
|
||||
ProtoTag::Secure,
|
||||
1,
|
||||
));
|
||||
|
||||
let tasks = 100;
|
||||
let barrier = Arc::new(Barrier::new(tasks));
|
||||
|
|
@ -355,7 +385,10 @@ async fn mtproto_strict_concurrent_replay_race_condition() {
|
|||
let hs = valid_handshake.clone();
|
||||
|
||||
handles.push(tokio::spawn(async move {
|
||||
let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 250) as u8)), 10000 + i as u16);
|
||||
let peer = SocketAddr::new(
|
||||
IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 250) as u8)),
|
||||
10000 + i as u16,
|
||||
);
|
||||
b.wait().await;
|
||||
handle_mtproto_handshake(
|
||||
&hs,
|
||||
|
|
@ -382,8 +415,15 @@ async fn mtproto_strict_concurrent_replay_race_condition() {
|
|||
}
|
||||
}
|
||||
|
||||
assert_eq!(successes, 1, "Replay cache race condition allowed multiple identical MTProto handshakes to succeed");
|
||||
assert_eq!(failures, tasks - 1, "Replay cache failed to forcefully reject concurrent duplicates");
|
||||
assert_eq!(
|
||||
successes, 1,
|
||||
"Replay cache race condition allowed multiple identical MTProto handshakes to succeed"
|
||||
);
|
||||
assert_eq!(
|
||||
failures,
|
||||
tasks - 1,
|
||||
"Replay cache failed to forcefully reject concurrent duplicates"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -398,7 +438,8 @@ async fn tls_alpn_zero_length_protocol_handled_safely() {
|
|||
let rng = SecureRandom::new();
|
||||
let peer: SocketAddr = "192.0.2.107:12345".parse().unwrap();
|
||||
|
||||
let handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b""]);
|
||||
let handshake =
|
||||
make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b""]);
|
||||
|
||||
let res = handle_tls_handshake(
|
||||
&handshake,
|
||||
|
|
@ -412,7 +453,10 @@ async fn tls_alpn_zero_length_protocol_handled_safely() {
|
|||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(res, HandshakeResult::BadClient { .. }), "0-length ALPN must be safely rejected without panicking");
|
||||
assert!(
|
||||
matches!(res, HandshakeResult::BadClient { .. }),
|
||||
"0-length ALPN must be safely rejected without panicking"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -427,7 +471,8 @@ async fn tls_sni_massive_hostname_does_not_panic() {
|
|||
let peer: SocketAddr = "192.0.2.108:12345".parse().unwrap();
|
||||
|
||||
let massive_hostname = String::from_utf8(vec![b'a'; 65000]).unwrap();
|
||||
let handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, &massive_hostname, &[]);
|
||||
let handshake =
|
||||
make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, &massive_hostname, &[]);
|
||||
|
||||
let res = handle_tls_handshake(
|
||||
&handshake,
|
||||
|
|
@ -441,7 +486,13 @@ async fn tls_sni_massive_hostname_does_not_panic() {
|
|||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(res, HandshakeResult::Success(_) | HandshakeResult::BadClient { .. }), "Massive SNI hostname must be processed or ignored without stack overflow or panic");
|
||||
assert!(
|
||||
matches!(
|
||||
res,
|
||||
HandshakeResult::Success(_) | HandshakeResult::BadClient { .. }
|
||||
),
|
||||
"Massive SNI hostname must be processed or ignored without stack overflow or panic"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -455,7 +506,8 @@ async fn tls_progressive_truncation_fuzzing_no_panics() {
|
|||
let rng = SecureRandom::new();
|
||||
let peer: SocketAddr = "192.0.2.109:12345".parse().unwrap();
|
||||
|
||||
let valid_handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b"h2"]);
|
||||
let valid_handshake =
|
||||
make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b"h2"]);
|
||||
let full_len = valid_handshake.len();
|
||||
|
||||
// Truncated corpus only: full_len is a valid baseline and should not be
|
||||
|
|
@ -473,7 +525,11 @@ async fn tls_progressive_truncation_fuzzing_no_panics() {
|
|||
None,
|
||||
)
|
||||
.await;
|
||||
assert!(matches!(res, HandshakeResult::BadClient { .. }), "Truncated TLS handshake at len {} must fail safely without panicking", i);
|
||||
assert!(
|
||||
matches!(res, HandshakeResult::BadClient { .. }),
|
||||
"Truncated TLS handshake at len {} must fail safely without panicking",
|
||||
i
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -504,7 +560,10 @@ async fn mtproto_pure_entropy_fuzzing_no_panics() {
|
|||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(res, HandshakeResult::BadClient { .. }), "Pure entropy MTProto payload must fail closed and never panic");
|
||||
assert!(
|
||||
matches!(res, HandshakeResult::BadClient { .. }),
|
||||
"Pure entropy MTProto payload must fail closed and never panic"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -517,10 +576,16 @@ fn decode_user_secret_odd_length_hex_rejection() {
|
|||
|
||||
let mut config = ProxyConfig::default();
|
||||
config.access.users.clear();
|
||||
config.access.users.insert("odd_user".to_string(), "1234567890123456789012345678901".to_string());
|
||||
config.access.users.insert(
|
||||
"odd_user".to_string(),
|
||||
"1234567890123456789012345678901".to_string(),
|
||||
);
|
||||
|
||||
let decoded = decode_user_secrets(&config, None);
|
||||
assert!(decoded.is_empty(), "Odd-length hex string must be gracefully rejected by hex::decode without unwrapping");
|
||||
assert!(
|
||||
decoded.is_empty(),
|
||||
"Odd-length hex string must be gracefully rejected by hex::decode without unwrapping"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -552,7 +617,10 @@ fn saturation_grace_pre_existing_high_fail_streak_immediate_throttle() {
|
|||
}
|
||||
|
||||
let is_throttled = auth_probe_should_apply_preauth_throttle(peer_ip, now);
|
||||
assert!(is_throttled, "A peer with a pre-existing high fail streak must be immediately throttled when saturation begins, receiving no unearned grace period");
|
||||
assert!(
|
||||
is_throttled,
|
||||
"A peer with a pre-existing high fail streak must be immediately throttled when saturation begins, receiving no unearned grace period"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -586,7 +654,11 @@ fn mtproto_classic_tags_rejected_when_only_secure_mode_enabled() {
|
|||
config.general.modes.tls = false;
|
||||
|
||||
assert!(!mode_enabled_for_proto(&config, ProtoTag::Abridged, false));
|
||||
assert!(!mode_enabled_for_proto(&config, ProtoTag::Intermediate, false));
|
||||
assert!(!mode_enabled_for_proto(
|
||||
&config,
|
||||
ProtoTag::Intermediate,
|
||||
false
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use super::*;
|
||||
use crate::crypto::{sha256, sha256_hmac, AesCtr, SecureRandom};
|
||||
use crate::crypto::{AesCtr, SecureRandom, sha256, sha256_hmac};
|
||||
use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION};
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
|
|
@ -80,8 +80,7 @@ fn make_valid_tls_client_hello_with_alpn(
|
|||
digest[28 + i] ^= ts[i];
|
||||
}
|
||||
|
||||
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
|
||||
.copy_from_slice(&digest);
|
||||
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
|
||||
|
||||
record
|
||||
}
|
||||
|
|
@ -331,7 +330,11 @@ async fn saturation_grace_exhaustion_under_concurrency_keeps_peer_throttled() {
|
|||
|
||||
let final_state = state.get(&peer_ip).expect("state must exist");
|
||||
assert!(
|
||||
final_state.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
|
||||
final_state.fail_streak
|
||||
>= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
|
||||
);
|
||||
assert!(auth_probe_should_apply_preauth_throttle(peer_ip, Instant::now()));
|
||||
assert!(auth_probe_should_apply_preauth_throttle(
|
||||
peer_ip,
|
||||
Instant::now()
|
||||
));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -956,6 +956,89 @@ async fn stress_tls_sni_preferred_user_hint_scales_to_large_user_set() {
|
|||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_unknown_sni_drop_policy_returns_hard_error() {
|
||||
let secret = [0x48u8; 16];
|
||||
let mut config = test_config_with_secret_hex("48484848484848484848484848484848");
|
||||
config.censorship.unknown_sni_action = UnknownSniAction::Drop;
|
||||
|
||||
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
|
||||
let rng = SecureRandom::new();
|
||||
let peer: SocketAddr = "198.51.100.190:44326".parse().unwrap();
|
||||
let handshake =
|
||||
make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "unknown.example", &[b"h2"]);
|
||||
|
||||
let result = handle_tls_handshake(
|
||||
&handshake,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
peer,
|
||||
&config,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(
|
||||
result,
|
||||
HandshakeResult::Error(ProxyError::UnknownTlsSni)
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_unknown_sni_mask_policy_falls_back_to_bad_client() {
|
||||
let secret = [0x49u8; 16];
|
||||
let mut config = test_config_with_secret_hex("49494949494949494949494949494949");
|
||||
config.censorship.unknown_sni_action = UnknownSniAction::Mask;
|
||||
|
||||
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
|
||||
let rng = SecureRandom::new();
|
||||
let peer: SocketAddr = "198.51.100.191:44326".parse().unwrap();
|
||||
let handshake =
|
||||
make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "unknown.example", &[b"h2"]);
|
||||
|
||||
let result = handle_tls_handshake(
|
||||
&handshake,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
peer,
|
||||
&config,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, HandshakeResult::BadClient { .. }));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tls_missing_sni_keeps_legacy_auth_path() {
|
||||
let secret = [0x4Au8; 16];
|
||||
let mut config = test_config_with_secret_hex("4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a");
|
||||
config.censorship.unknown_sni_action = UnknownSniAction::Drop;
|
||||
|
||||
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
|
||||
let rng = SecureRandom::new();
|
||||
let peer: SocketAddr = "198.51.100.192:44326".parse().unwrap();
|
||||
let handshake = make_valid_tls_handshake(&secret, 0);
|
||||
|
||||
let result = handle_tls_handshake(
|
||||
&handshake,
|
||||
tokio::io::empty(),
|
||||
tokio::io::sink(),
|
||||
peer,
|
||||
&config,
|
||||
&replay_checker,
|
||||
&rng,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, HandshakeResult::Success(_)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn alpn_enforce_rejects_unsupported_client_alpn() {
|
||||
let secret = [0x33u8; 16];
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use super::*;
|
||||
use crate::crypto::{sha256, sha256_hmac, AesCtr, SecureRandom};
|
||||
use crate::crypto::{AesCtr, SecureRandom, sha256, sha256_hmac};
|
||||
use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION};
|
||||
use std::net::SocketAddr;
|
||||
use std::time::{Duration, Instant};
|
||||
|
|
@ -169,10 +169,10 @@ async fn mtproto_user_scan_timing_manual_benchmark() {
|
|||
);
|
||||
}
|
||||
|
||||
config.access.users.insert(
|
||||
preferred_user.to_string(),
|
||||
target_secret_hex.to_string(),
|
||||
);
|
||||
config
|
||||
.access
|
||||
.users
|
||||
.insert(preferred_user.to_string(), target_secret_hex.to_string());
|
||||
|
||||
let replay_checker_preferred = ReplayChecker::new(65_536, Duration::from_secs(60));
|
||||
let replay_checker_full_scan = ReplayChecker::new(65_536, Duration::from_secs(60));
|
||||
|
|
|
|||
|
|
@ -544,7 +544,6 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
|
|||
if hardened_acc + 0.05 <= baseline_acc {
|
||||
meaningful_improvement_seen = true;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
assert!(
|
||||
|
|
|
|||
|
|
@ -78,7 +78,11 @@ fn timing_normalization_zero_floor_safety_net_defaults_to_mask_timeout() {
|
|||
config.censorship.mask_timing_normalization_ceiling_ms = 0;
|
||||
|
||||
let budget = mask_outcome_target_budget(&config);
|
||||
assert_eq!(budget, MASK_TIMEOUT);
|
||||
assert_eq!(
|
||||
budget,
|
||||
Duration::from_millis(0),
|
||||
"zero floor/ceiling must produce zero extra normalization budget"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
|
|||
|
|
@ -85,7 +85,10 @@ async fn aggressive_mode_shapes_backend_silent_non_eof_path() {
|
|||
let legacy = capture_forwarded_len_with_mode(body_sent, false, false, false, 0).await;
|
||||
let aggressive = capture_forwarded_len_with_mode(body_sent, false, true, false, 0).await;
|
||||
|
||||
assert!(legacy < floor, "legacy mode should keep timeout path unshaped");
|
||||
assert!(
|
||||
legacy < floor,
|
||||
"legacy mode should keep timeout path unshaped"
|
||||
);
|
||||
assert!(
|
||||
aggressive >= floor,
|
||||
"aggressive mode must shape backend-silent non-EOF paths (aggressive={aggressive}, floor={floor})"
|
||||
|
|
|
|||
|
|
@ -52,7 +52,10 @@ async fn run_connect_failure_case(
|
|||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(n, 0, "connect-failure path must close client-visible writer");
|
||||
assert_eq!(
|
||||
n, 0,
|
||||
"connect-failure path must close client-visible writer"
|
||||
);
|
||||
|
||||
started.elapsed()
|
||||
}
|
||||
|
|
@ -67,13 +70,9 @@ async fn connect_failure_refusal_close_behavior_matrix() {
|
|||
let peer: SocketAddr = format!("203.0.113.210:{}", 54100 + idx as u16)
|
||||
.parse()
|
||||
.unwrap();
|
||||
let elapsed = run_connect_failure_case(
|
||||
"127.0.0.1",
|
||||
unused_port,
|
||||
timing_normalization_enabled,
|
||||
peer,
|
||||
)
|
||||
.await;
|
||||
let elapsed =
|
||||
run_connect_failure_case("127.0.0.1", unused_port, timing_normalization_enabled, peer)
|
||||
.await;
|
||||
|
||||
if timing_normalization_enabled {
|
||||
assert!(
|
||||
|
|
|
|||
|
|
@ -79,7 +79,10 @@ async fn io_error_terminates_cleanly() {
|
|||
}
|
||||
}
|
||||
|
||||
tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(ErrReader, usize::MAX))
|
||||
.await
|
||||
.expect("consume_client_data did not return on I/O error");
|
||||
tokio::time::timeout(
|
||||
MASK_RELAY_TIMEOUT,
|
||||
consume_client_data(ErrReader, usize::MAX),
|
||||
)
|
||||
.await
|
||||
.expect("consume_client_data did not return on I/O error");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,8 +32,16 @@ async fn run_self_target_refusal(
|
|||
let (mut client, server) = duplex(1024);
|
||||
let started = Instant::now();
|
||||
let task = tokio::spawn(async move {
|
||||
handle_bad_client(server, tokio::io::sink(), initial, peer, local_addr, &config, &beobachten)
|
||||
.await;
|
||||
handle_bad_client(
|
||||
server,
|
||||
tokio::io::sink(),
|
||||
initial,
|
||||
peer,
|
||||
local_addr,
|
||||
&config,
|
||||
&beobachten,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
|
||||
client
|
||||
|
|
|
|||
|
|
@ -2,7 +2,13 @@ use super::*;
|
|||
|
||||
#[test]
|
||||
fn exact_four_byte_http_tokens_are_classified() {
|
||||
for token in [b"GET ".as_ref(), b"POST".as_ref(), b"HEAD".as_ref(), b"PUT ".as_ref(), b"PRI ".as_ref()] {
|
||||
for token in [
|
||||
b"GET ".as_ref(),
|
||||
b"POST".as_ref(),
|
||||
b"HEAD".as_ref(),
|
||||
b"PUT ".as_ref(),
|
||||
b"PRI ".as_ref(),
|
||||
] {
|
||||
assert!(
|
||||
is_http_probe(token),
|
||||
"exact 4-byte token must be classified as HTTP probe: {:?}",
|
||||
|
|
|
|||
|
|
@ -37,7 +37,10 @@ async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() {
|
|||
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
|
||||
let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await;
|
||||
|
||||
assert!(!is_local, "different port must not be treated as local listener");
|
||||
assert!(
|
||||
!is_local,
|
||||
"different port must not be treated as local listener"
|
||||
);
|
||||
assert_eq!(
|
||||
local_interface_enumerations_for_tests(),
|
||||
0,
|
||||
|
|
|
|||
|
|
@ -63,17 +63,11 @@ impl AsyncWrite for CountingWriter {
|
|||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use super::*;
|
||||
use std::net::TcpListener as StdTcpListener;
|
||||
use std::net::SocketAddr;
|
||||
use std::net::TcpListener as StdTcpListener;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::net::TcpListener;
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
|
|
@ -15,74 +15,38 @@ fn closed_local_port() -> u16 {
|
|||
#[tokio::test]
|
||||
async fn self_target_detection_matches_literal_ipv4_listener() {
|
||||
let local: SocketAddr = "198.51.100.40:443".parse().unwrap();
|
||||
assert!(is_mask_target_local_listener_async(
|
||||
"198.51.100.40",
|
||||
443,
|
||||
local,
|
||||
None,
|
||||
)
|
||||
.await);
|
||||
assert!(is_mask_target_local_listener_async("198.51.100.40", 443, local, None,).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_matches_bracketed_ipv6_listener() {
|
||||
let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap();
|
||||
assert!(is_mask_target_local_listener_async(
|
||||
"[2001:db8::44]",
|
||||
8443,
|
||||
local,
|
||||
None,
|
||||
)
|
||||
.await);
|
||||
assert!(is_mask_target_local_listener_async("[2001:db8::44]", 8443, local, None,).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_keeps_same_ip_different_port_forwardable() {
|
||||
let local: SocketAddr = "203.0.113.44:443".parse().unwrap();
|
||||
assert!(!is_mask_target_local_listener_async(
|
||||
"203.0.113.44",
|
||||
8443,
|
||||
local,
|
||||
None,
|
||||
)
|
||||
.await);
|
||||
assert!(!is_mask_target_local_listener_async("203.0.113.44", 8443, local, None,).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() {
|
||||
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
|
||||
assert!(is_mask_target_local_listener_async(
|
||||
"::ffff:127.0.0.1",
|
||||
443,
|
||||
local,
|
||||
None,
|
||||
)
|
||||
.await);
|
||||
assert!(is_mask_target_local_listener_async("::ffff:127.0.0.1", 443, local, None,).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_unspecified_bind_blocks_loopback_target() {
|
||||
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
|
||||
assert!(is_mask_target_local_listener_async(
|
||||
"127.0.0.1",
|
||||
443,
|
||||
local,
|
||||
None,
|
||||
)
|
||||
.await);
|
||||
assert!(is_mask_target_local_listener_async("127.0.0.1", 443, local, None,).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() {
|
||||
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
|
||||
let remote: SocketAddr = "198.51.100.44:443".parse().unwrap();
|
||||
assert!(!is_mask_target_local_listener_async(
|
||||
"mask.example",
|
||||
443,
|
||||
local,
|
||||
Some(remote),
|
||||
)
|
||||
.await);
|
||||
assert!(!is_mask_target_local_listener_async("mask.example", 443, local, Some(remote),).await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -306,7 +270,10 @@ async fn offline_mask_target_refusal_respects_timing_normalization_budget() {
|
|||
});
|
||||
|
||||
client.shutdown().await.unwrap();
|
||||
timeout(Duration::from_secs(2), task).await.unwrap().unwrap();
|
||||
timeout(Duration::from_secs(2), task)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let elapsed = started.elapsed();
|
||||
|
||||
assert!(
|
||||
|
|
@ -350,7 +317,10 @@ async fn offline_mask_target_refusal_with_idle_client_is_bounded_by_consume_time
|
|||
.await
|
||||
.expect("connection should still be open before consume timeout expires");
|
||||
|
||||
timeout(Duration::from_secs(2), task).await.unwrap().unwrap();
|
||||
timeout(Duration::from_secs(2), task)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let elapsed = started.elapsed();
|
||||
|
||||
assert!(
|
||||
|
|
|
|||
|
|
@ -40,7 +40,10 @@ async fn adversarial_delayed_interface_lookup_does_not_consume_outcome_floor_bud
|
|||
|
||||
tokio::time::sleep(Duration::from_millis(80)).await;
|
||||
drop(held_refresh_guard);
|
||||
client.shutdown().await.expect("client shutdown must succeed");
|
||||
client
|
||||
.shutdown()
|
||||
.await
|
||||
.expect("client shutdown must succeed");
|
||||
|
||||
timeout(Duration::from_secs(2), task)
|
||||
.await
|
||||
|
|
|
|||
|
|
@ -0,0 +1,189 @@
|
|||
use super::*;
|
||||
use crate::crypto::AesCtr;
|
||||
use bytes::Bytes;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::AsyncWrite;
|
||||
|
||||
struct CountedWriter {
|
||||
write_calls: Arc<AtomicUsize>,
|
||||
fail_writes: bool,
|
||||
}
|
||||
|
||||
impl CountedWriter {
|
||||
fn new(write_calls: Arc<AtomicUsize>, fail_writes: bool) -> Self {
|
||||
Self {
|
||||
write_calls,
|
||||
fail_writes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for CountedWriter {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let this = self.get_mut();
|
||||
this.write_calls.fetch_add(1, Ordering::Relaxed);
|
||||
if this.fail_writes {
|
||||
Poll::Ready(Err(io::Error::new(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
"forced write failure",
|
||||
)))
|
||||
} else {
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
fn make_crypto_writer(inner: CountedWriter) -> CryptoWriter<CountedWriter> {
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoWriter::new(inner, AesCtr::new(&key, iv), 8 * 1024)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() {
|
||||
let stats = Stats::new();
|
||||
let user = "middle-me-writer-no-rollback-user";
|
||||
let user_stats = stats.get_or_create_user_stats_handle(user);
|
||||
let write_calls = Arc::new(AtomicUsize::new(0));
|
||||
let mut writer = make_crypto_writer(CountedWriter::new(write_calls.clone(), true));
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
let payload = Bytes::from_static(&[0x11, 0x22, 0x33, 0x44, 0x55]);
|
||||
|
||||
let result = process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: payload.clone(),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
user,
|
||||
Some(user_stats.as_ref()),
|
||||
Some(64),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
11,
|
||||
true,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
matches!(result, Err(ProxyError::Io(_))),
|
||||
"write failure must propagate as I/O error"
|
||||
);
|
||||
assert!(
|
||||
write_calls.load(Ordering::Relaxed) > 0,
|
||||
"writer must be attempted after successful quota reservation"
|
||||
);
|
||||
assert_eq!(
|
||||
stats.get_user_quota_used(user),
|
||||
payload.len() as u64,
|
||||
"reserved quota must not roll back on write failure"
|
||||
);
|
||||
assert_eq!(
|
||||
stats.get_quota_write_fail_bytes_total(),
|
||||
payload.len() as u64,
|
||||
"write-fail byte metric must include failed payload size"
|
||||
);
|
||||
assert_eq!(
|
||||
stats.get_quota_write_fail_events_total(),
|
||||
1,
|
||||
"write-fail events metric must increment once"
|
||||
);
|
||||
assert_eq!(
|
||||
stats.get_user_total_octets(user),
|
||||
0,
|
||||
"telemetry octets_to should not advance when write fails"
|
||||
);
|
||||
assert_eq!(
|
||||
bytes_me2c.load(Ordering::Relaxed),
|
||||
0,
|
||||
"ME->C committed byte counter must not advance on write failure"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn me_writer_pre_write_quota_reject_happens_before_writer_poll() {
|
||||
let stats = Stats::new();
|
||||
let user = "middle-me-writer-precheck-user";
|
||||
let limit = 8u64;
|
||||
let user_stats = stats.get_or_create_user_stats_handle(user);
|
||||
stats.quota_charge_post_write(user_stats.as_ref(), limit);
|
||||
|
||||
let write_calls = Arc::new(AtomicUsize::new(0));
|
||||
let mut writer = make_crypto_writer(CountedWriter::new(write_calls.clone(), false));
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let result = process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xAA, 0xBB, 0xCC]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
user,
|
||||
Some(user_stats.as_ref()),
|
||||
Some(limit),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
12,
|
||||
true,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
matches!(result, Err(ProxyError::DataQuotaExceeded { .. })),
|
||||
"pre-write quota rejection must return typed quota error"
|
||||
);
|
||||
assert_eq!(
|
||||
write_calls.load(Ordering::Relaxed),
|
||||
0,
|
||||
"writer must not be polled when pre-write quota reservation fails"
|
||||
);
|
||||
assert_eq!(
|
||||
stats.get_me_d2c_quota_reject_pre_write_total(),
|
||||
1,
|
||||
"pre-write quota reject metric must increment"
|
||||
);
|
||||
assert_eq!(
|
||||
stats.get_user_quota_used(user),
|
||||
limit,
|
||||
"failed pre-write reservation must keep previous quota usage unchanged"
|
||||
);
|
||||
assert_eq!(
|
||||
stats.get_quota_write_fail_bytes_total(),
|
||||
0,
|
||||
"write-fail bytes metric must stay unchanged on pre-write reject"
|
||||
);
|
||||
assert_eq!(
|
||||
stats.get_quota_write_fail_events_total(),
|
||||
0,
|
||||
"write-fail events metric must stay unchanged on pre-write reject"
|
||||
);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
|
||||
}
|
||||
|
|
@ -1,113 +0,0 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use tokio::sync::Barrier;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn blackhat_campaign_saturation_quota_race_with_queue_pressure_stays_fail_closed() {
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
let _pressure_guard = super::relay_idle_pressure_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!(
|
||||
"middle-blackhat-held-{}-{idx}",
|
||||
std::process::id()
|
||||
)));
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
map.len(),
|
||||
QUOTA_USER_LOCKS_MAX,
|
||||
"precondition: bounded lock cache must be saturated"
|
||||
);
|
||||
|
||||
let (tx, _rx) = mpsc::channel::<C2MeCommand>(1);
|
||||
tx.send(C2MeCommand::Close)
|
||||
.await
|
||||
.expect("queue prefill should succeed");
|
||||
|
||||
let pressure_seq_before = relay_pressure_event_seq();
|
||||
let pressure_errors = Arc::new(AtomicUsize::new(0));
|
||||
let mut pressure_workers = Vec::new();
|
||||
for _ in 0..16 {
|
||||
let tx = tx.clone();
|
||||
let pressure_errors = Arc::clone(&pressure_errors);
|
||||
pressure_workers.push(tokio::spawn(async move {
|
||||
if enqueue_c2me_command(&tx, C2MeCommand::Close).await.is_err() {
|
||||
pressure_errors.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("middle-blackhat-quota-race-{}", std::process::id());
|
||||
let gate = Arc::new(Barrier::new(16));
|
||||
|
||||
let mut quota_workers = Vec::new();
|
||||
for _ in 0..16u8 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let gate = Arc::clone(&gate);
|
||||
quota_workers.push(tokio::spawn(async move {
|
||||
gate.wait().await;
|
||||
let user_lock = quota_user_lock(&user);
|
||||
let _quota_guard = user_lock.lock().await;
|
||||
|
||||
if quota_would_be_exceeded_for_user(&stats, &user, Some(1), 1) {
|
||||
return false;
|
||||
}
|
||||
stats.add_user_octets_to(&user, 1);
|
||||
true
|
||||
}));
|
||||
}
|
||||
|
||||
let mut ok_count = 0usize;
|
||||
let mut denied_count = 0usize;
|
||||
for worker in quota_workers {
|
||||
let result = timeout(Duration::from_secs(2), worker)
|
||||
.await
|
||||
.expect("quota worker must finish")
|
||||
.expect("quota worker must not panic");
|
||||
if result {
|
||||
ok_count += 1;
|
||||
} else {
|
||||
denied_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
for worker in pressure_workers {
|
||||
timeout(Duration::from_secs(2), worker)
|
||||
.await
|
||||
.expect("pressure worker must finish")
|
||||
.expect("pressure worker must not panic");
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
stats.get_user_total_octets(&user),
|
||||
1,
|
||||
"black-hat campaign must not overshoot same-user quota under saturation"
|
||||
);
|
||||
assert!(ok_count <= 1, "at most one quota contender may succeed");
|
||||
assert!(
|
||||
denied_count >= 15,
|
||||
"all remaining contenders must be quota-denied"
|
||||
);
|
||||
|
||||
let pressure_seq_after = relay_pressure_event_seq();
|
||||
assert!(
|
||||
pressure_seq_after > pressure_seq_before,
|
||||
"queue pressure leg must trigger pressure accounting"
|
||||
);
|
||||
assert!(
|
||||
pressure_errors.load(Ordering::Relaxed) >= 1,
|
||||
"at least one pressure worker should fail from persistent backpressure"
|
||||
);
|
||||
|
||||
drop(retained);
|
||||
}
|
||||
|
|
@ -1,777 +0,0 @@
|
|||
use super::*;
|
||||
use crate::crypto::AesCtr;
|
||||
use crate::crypto::SecureRandom;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::{BufferPool, PooledBuffer};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::duplex;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::{Duration as TokioDuration, timeout};
|
||||
|
||||
fn make_pooled_payload(data: &[u8]) -> PooledBuffer {
|
||||
let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4));
|
||||
let mut payload = pool.get();
|
||||
payload.resize(data.len(), 0);
|
||||
payload[..data.len()].copy_from_slice(data);
|
||||
payload
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_client_payload_abridged_short_quickack_sets_flag_and_preserves_payload() {
|
||||
let (mut read_side, write_side) = duplex(4096);
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
|
||||
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
|
||||
let mut decryptor = AesCtr::new(&key, iv);
|
||||
let rng = SecureRandom::new();
|
||||
let mut frame_buf = Vec::new();
|
||||
let payload = vec![0xA1, 0xB2, 0xC3, 0xD4, 0x10, 0x20, 0x30, 0x40];
|
||||
|
||||
write_client_payload(
|
||||
&mut writer,
|
||||
ProtoTag::Abridged,
|
||||
RPC_FLAG_QUICKACK,
|
||||
&payload,
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
)
|
||||
.await
|
||||
.expect("abridged quickack payload should serialize");
|
||||
writer.flush().await.expect("flush must succeed");
|
||||
|
||||
let mut encrypted = vec![0u8; 1 + payload.len()];
|
||||
read_side
|
||||
.read_exact(&mut encrypted)
|
||||
.await
|
||||
.expect("must read serialized abridged frame");
|
||||
let plaintext = decryptor.decrypt(&encrypted);
|
||||
|
||||
assert_eq!(plaintext[0], 0x80 | ((payload.len() / 4) as u8));
|
||||
assert_eq!(&plaintext[1..], payload.as_slice());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_client_payload_abridged_extended_header_is_encoded_correctly() {
|
||||
let (mut read_side, write_side) = duplex(16 * 1024);
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
|
||||
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
|
||||
let mut decryptor = AesCtr::new(&key, iv);
|
||||
let rng = SecureRandom::new();
|
||||
let mut frame_buf = Vec::new();
|
||||
|
||||
// Boundary where abridged switches to extended length encoding.
|
||||
let payload = vec![0x5Au8; 0x7f * 4];
|
||||
|
||||
write_client_payload(
|
||||
&mut writer,
|
||||
ProtoTag::Abridged,
|
||||
RPC_FLAG_QUICKACK,
|
||||
&payload,
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
)
|
||||
.await
|
||||
.expect("extended abridged payload should serialize");
|
||||
writer.flush().await.expect("flush must succeed");
|
||||
|
||||
let mut encrypted = vec![0u8; 4 + payload.len()];
|
||||
read_side
|
||||
.read_exact(&mut encrypted)
|
||||
.await
|
||||
.expect("must read serialized extended abridged frame");
|
||||
let plaintext = decryptor.decrypt(&encrypted);
|
||||
|
||||
assert_eq!(plaintext[0], 0xff, "0x7f with quickack bit must be set");
|
||||
assert_eq!(&plaintext[1..4], &[0x7f, 0x00, 0x00]);
|
||||
assert_eq!(&plaintext[4..], payload.as_slice());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_client_payload_abridged_misaligned_is_rejected_fail_closed() {
|
||||
let (_read_side, write_side) = duplex(1024);
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
|
||||
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
|
||||
let rng = SecureRandom::new();
|
||||
let mut frame_buf = Vec::new();
|
||||
|
||||
let err = write_client_payload(
|
||||
&mut writer,
|
||||
ProtoTag::Abridged,
|
||||
0,
|
||||
&[1, 2, 3],
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
)
|
||||
.await
|
||||
.expect_err("misaligned abridged payload must be rejected");
|
||||
|
||||
let msg = format!("{err}");
|
||||
assert!(
|
||||
msg.contains("4-byte aligned"),
|
||||
"error should explain alignment contract, got: {msg}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_client_payload_secure_misaligned_is_rejected_fail_closed() {
|
||||
let (_read_side, write_side) = duplex(1024);
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
|
||||
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
|
||||
let rng = SecureRandom::new();
|
||||
let mut frame_buf = Vec::new();
|
||||
|
||||
let err = write_client_payload(
|
||||
&mut writer,
|
||||
ProtoTag::Secure,
|
||||
0,
|
||||
&[9, 8, 7, 6, 5],
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
)
|
||||
.await
|
||||
.expect_err("misaligned secure payload must be rejected");
|
||||
|
||||
let msg = format!("{err}");
|
||||
assert!(
|
||||
msg.contains("Secure payload must be 4-byte aligned"),
|
||||
"error should be explicit for fail-closed triage, got: {msg}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_client_payload_intermediate_quickack_sets_length_msb() {
|
||||
let (mut read_side, write_side) = duplex(4096);
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
|
||||
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
|
||||
let mut decryptor = AesCtr::new(&key, iv);
|
||||
let rng = SecureRandom::new();
|
||||
let mut frame_buf = Vec::new();
|
||||
let payload = b"hello-middle-relay";
|
||||
|
||||
write_client_payload(
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
RPC_FLAG_QUICKACK,
|
||||
payload,
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
)
|
||||
.await
|
||||
.expect("intermediate quickack payload should serialize");
|
||||
writer.flush().await.expect("flush must succeed");
|
||||
|
||||
let mut encrypted = vec![0u8; 4 + payload.len()];
|
||||
read_side
|
||||
.read_exact(&mut encrypted)
|
||||
.await
|
||||
.expect("must read intermediate frame");
|
||||
let plaintext = decryptor.decrypt(&encrypted);
|
||||
|
||||
let mut len_bytes = [0u8; 4];
|
||||
len_bytes.copy_from_slice(&plaintext[..4]);
|
||||
let len_with_flags = u32::from_le_bytes(len_bytes);
|
||||
assert_ne!(len_with_flags & 0x8000_0000, 0, "quickack bit must be set");
|
||||
assert_eq!((len_with_flags & 0x7fff_ffff) as usize, payload.len());
|
||||
assert_eq!(&plaintext[4..], payload);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_client_payload_secure_quickack_prefix_and_padding_bounds_hold() {
|
||||
let (mut read_side, write_side) = duplex(4096);
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
|
||||
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
|
||||
let mut decryptor = AesCtr::new(&key, iv);
|
||||
let rng = SecureRandom::new();
|
||||
let mut frame_buf = Vec::new();
|
||||
let payload = vec![0x33u8; 100]; // 4-byte aligned as required by secure mode.
|
||||
|
||||
write_client_payload(
|
||||
&mut writer,
|
||||
ProtoTag::Secure,
|
||||
RPC_FLAG_QUICKACK,
|
||||
&payload,
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
)
|
||||
.await
|
||||
.expect("secure quickack payload should serialize");
|
||||
writer.flush().await.expect("flush must succeed");
|
||||
|
||||
// Secure mode adds 1..=3 bytes of randomized tail padding.
|
||||
let mut encrypted_header = [0u8; 4];
|
||||
read_side
|
||||
.read_exact(&mut encrypted_header)
|
||||
.await
|
||||
.expect("must read secure header");
|
||||
let decrypted_header = decryptor.decrypt(&encrypted_header);
|
||||
let header: [u8; 4] = decrypted_header
|
||||
.try_into()
|
||||
.expect("decrypted secure header must be 4 bytes");
|
||||
let wire_len_raw = u32::from_le_bytes(header);
|
||||
|
||||
assert_ne!(
|
||||
wire_len_raw & 0x8000_0000,
|
||||
0,
|
||||
"secure quickack bit must be set"
|
||||
);
|
||||
|
||||
let wire_len = (wire_len_raw & 0x7fff_ffff) as usize;
|
||||
assert!(wire_len >= payload.len());
|
||||
let padding_len = wire_len - payload.len();
|
||||
assert!(
|
||||
(1..=3).contains(&padding_len),
|
||||
"secure writer must add bounded random tail padding, got {padding_len}"
|
||||
);
|
||||
|
||||
let mut encrypted_body = vec![0u8; wire_len];
|
||||
read_side
|
||||
.read_exact(&mut encrypted_body)
|
||||
.await
|
||||
.expect("must read secure body");
|
||||
let decrypted_body = decryptor.decrypt(&encrypted_body);
|
||||
assert_eq!(&decrypted_body[..payload.len()], payload.as_slice());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "heavy: allocates >64MiB to validate abridged too-large fail-closed branch"]
|
||||
async fn write_client_payload_abridged_too_large_is_rejected_fail_closed() {
|
||||
let (_read_side, write_side) = duplex(1024);
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
|
||||
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
|
||||
let rng = SecureRandom::new();
|
||||
let mut frame_buf = Vec::new();
|
||||
|
||||
// Exactly one 4-byte word above the encodable 24-bit abridged length range.
|
||||
let payload = vec![0x00u8; (1 << 24) * 4];
|
||||
let err = write_client_payload(
|
||||
&mut writer,
|
||||
ProtoTag::Abridged,
|
||||
0,
|
||||
&payload,
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
)
|
||||
.await
|
||||
.expect_err("oversized abridged payload must be rejected");
|
||||
|
||||
let msg = format!("{err}");
|
||||
assert!(
|
||||
msg.contains("Abridged frame too large"),
|
||||
"error must clearly indicate oversize fail-close path, got: {msg}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_client_ack_intermediate_is_little_endian() {
|
||||
let (mut read_side, write_side) = duplex(1024);
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
|
||||
let mut decryptor = AesCtr::new(&key, iv);
|
||||
|
||||
write_client_ack(&mut writer, ProtoTag::Intermediate, 0x11_22_33_44)
|
||||
.await
|
||||
.expect("ack serialization should succeed");
|
||||
writer.flush().await.expect("flush must succeed");
|
||||
|
||||
let mut encrypted = [0u8; 4];
|
||||
read_side
|
||||
.read_exact(&mut encrypted)
|
||||
.await
|
||||
.expect("must read ack bytes");
|
||||
let plain = decryptor.decrypt(&encrypted);
|
||||
assert_eq!(plain.as_slice(), &0x11_22_33_44u32.to_le_bytes());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_client_ack_abridged_is_big_endian() {
|
||||
let (mut read_side, write_side) = duplex(1024);
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
|
||||
let mut decryptor = AesCtr::new(&key, iv);
|
||||
|
||||
write_client_ack(&mut writer, ProtoTag::Abridged, 0xDE_AD_BE_EF)
|
||||
.await
|
||||
.expect("ack serialization should succeed");
|
||||
writer.flush().await.expect("flush must succeed");
|
||||
|
||||
let mut encrypted = [0u8; 4];
|
||||
read_side
|
||||
.read_exact(&mut encrypted)
|
||||
.await
|
||||
.expect("must read ack bytes");
|
||||
let plain = decryptor.decrypt(&encrypted);
|
||||
assert_eq!(plain.as_slice(), &0xDE_AD_BE_EFu32.to_be_bytes());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_client_payload_abridged_short_boundary_0x7e_is_single_byte_header() {
|
||||
let (mut read_side, write_side) = duplex(1024 * 1024);
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
|
||||
let mut decryptor = AesCtr::new(&key, iv);
|
||||
let rng = SecureRandom::new();
|
||||
let mut frame_buf = Vec::new();
|
||||
let payload = vec![0xABu8; 0x7e * 4];
|
||||
|
||||
write_client_payload(
|
||||
&mut writer,
|
||||
ProtoTag::Abridged,
|
||||
0,
|
||||
&payload,
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
)
|
||||
.await
|
||||
.expect("boundary payload should serialize");
|
||||
writer.flush().await.expect("flush must succeed");
|
||||
|
||||
let mut encrypted = vec![0u8; 1 + payload.len()];
|
||||
read_side.read_exact(&mut encrypted).await.unwrap();
|
||||
let plain = decryptor.decrypt(&encrypted);
|
||||
assert_eq!(plain[0], 0x7e);
|
||||
assert_eq!(&plain[1..], payload.as_slice());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_client_payload_abridged_extended_without_quickack_has_clean_prefix() {
|
||||
let (mut read_side, write_side) = duplex(16 * 1024);
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
|
||||
let mut decryptor = AesCtr::new(&key, iv);
|
||||
let rng = SecureRandom::new();
|
||||
let mut frame_buf = Vec::new();
|
||||
let payload = vec![0x42u8; 0x80 * 4];
|
||||
|
||||
write_client_payload(
|
||||
&mut writer,
|
||||
ProtoTag::Abridged,
|
||||
0,
|
||||
&payload,
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
)
|
||||
.await
|
||||
.expect("extended payload should serialize");
|
||||
writer.flush().await.expect("flush must succeed");
|
||||
|
||||
let mut encrypted = vec![0u8; 4 + payload.len()];
|
||||
read_side.read_exact(&mut encrypted).await.unwrap();
|
||||
let plain = decryptor.decrypt(&encrypted);
|
||||
assert_eq!(plain[0], 0x7f);
|
||||
assert_eq!(&plain[1..4], &[0x80, 0x00, 0x00]);
|
||||
assert_eq!(&plain[4..], payload.as_slice());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_client_payload_intermediate_zero_length_emits_header_only() {
|
||||
let (mut read_side, write_side) = duplex(1024);
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
|
||||
let mut decryptor = AesCtr::new(&key, iv);
|
||||
let rng = SecureRandom::new();
|
||||
let mut frame_buf = Vec::new();
|
||||
|
||||
write_client_payload(
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
0,
|
||||
&[],
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
)
|
||||
.await
|
||||
.expect("zero-length intermediate payload should serialize");
|
||||
writer.flush().await.expect("flush must succeed");
|
||||
|
||||
let mut encrypted = [0u8; 4];
|
||||
read_side.read_exact(&mut encrypted).await.unwrap();
|
||||
let plain = decryptor.decrypt(&encrypted);
|
||||
assert_eq!(plain.as_slice(), &[0, 0, 0, 0]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_client_payload_intermediate_ignores_unrelated_flags() {
|
||||
let (mut read_side, write_side) = duplex(1024);
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
|
||||
let mut decryptor = AesCtr::new(&key, iv);
|
||||
let rng = SecureRandom::new();
|
||||
let mut frame_buf = Vec::new();
|
||||
let payload = [7u8; 12];
|
||||
|
||||
write_client_payload(
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
0x4000_0000,
|
||||
&payload,
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
)
|
||||
.await
|
||||
.expect("payload should serialize");
|
||||
writer.flush().await.expect("flush must succeed");
|
||||
|
||||
let mut encrypted = [0u8; 16];
|
||||
read_side.read_exact(&mut encrypted).await.unwrap();
|
||||
let plain = decryptor.decrypt(&encrypted);
|
||||
let len = u32::from_le_bytes(plain[0..4].try_into().unwrap());
|
||||
assert_eq!(len, payload.len() as u32, "only quickack bit may affect header");
|
||||
assert_eq!(&plain[4..], payload.as_slice());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_client_payload_secure_without_quickack_keeps_msb_clear() {
|
||||
let (mut read_side, write_side) = duplex(4096);
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
|
||||
let mut decryptor = AesCtr::new(&key, iv);
|
||||
let rng = SecureRandom::new();
|
||||
let mut frame_buf = Vec::new();
|
||||
let payload = [0x1Du8; 64];
|
||||
|
||||
write_client_payload(
|
||||
&mut writer,
|
||||
ProtoTag::Secure,
|
||||
0,
|
||||
&payload,
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
)
|
||||
.await
|
||||
.expect("payload should serialize");
|
||||
writer.flush().await.expect("flush must succeed");
|
||||
|
||||
let mut encrypted_header = [0u8; 4];
|
||||
read_side.read_exact(&mut encrypted_header).await.unwrap();
|
||||
let plain_header = decryptor.decrypt(&encrypted_header);
|
||||
let h: [u8; 4] = plain_header.as_slice().try_into().unwrap();
|
||||
let wire_len_raw = u32::from_le_bytes(h);
|
||||
assert_eq!(wire_len_raw & 0x8000_0000, 0, "quickack bit must stay clear");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn secure_padding_light_fuzz_distribution_has_multiple_outcomes() {
|
||||
let (mut read_side, write_side) = duplex(256 * 1024);
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
|
||||
let mut decryptor = AesCtr::new(&key, iv);
|
||||
let rng = SecureRandom::new();
|
||||
let mut frame_buf = Vec::new();
|
||||
let payload = [0x55u8; 100];
|
||||
let mut seen = [false; 4];
|
||||
|
||||
for _ in 0..96 {
|
||||
write_client_payload(
|
||||
&mut writer,
|
||||
ProtoTag::Secure,
|
||||
0,
|
||||
&payload,
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
)
|
||||
.await
|
||||
.expect("secure payload should serialize");
|
||||
writer.flush().await.expect("flush must succeed");
|
||||
|
||||
let mut encrypted_header = [0u8; 4];
|
||||
read_side.read_exact(&mut encrypted_header).await.unwrap();
|
||||
let plain_header = decryptor.decrypt(&encrypted_header);
|
||||
let h: [u8; 4] = plain_header.as_slice().try_into().unwrap();
|
||||
let wire_len = (u32::from_le_bytes(h) & 0x7fff_ffff) as usize;
|
||||
let padding_len = wire_len - payload.len();
|
||||
assert!((1..=3).contains(&padding_len));
|
||||
seen[padding_len] = true;
|
||||
|
||||
let mut encrypted_body = vec![0u8; wire_len];
|
||||
read_side.read_exact(&mut encrypted_body).await.unwrap();
|
||||
let _ = decryptor.decrypt(&encrypted_body);
|
||||
}
|
||||
|
||||
let distinct = (1..=3).filter(|idx| seen[*idx]).count();
|
||||
assert!(
|
||||
distinct >= 2,
|
||||
"padding generator should not collapse to a single outcome under campaign"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn write_client_payload_mixed_proto_sequence_preserves_stream_sync() {
|
||||
let (mut read_side, write_side) = duplex(128 * 1024);
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
|
||||
let mut decryptor = AesCtr::new(&key, iv);
|
||||
let rng = SecureRandom::new();
|
||||
let mut frame_buf = Vec::new();
|
||||
|
||||
let p1 = vec![1u8; 8];
|
||||
let p2 = vec![2u8; 16];
|
||||
let p3 = vec![3u8; 20];
|
||||
|
||||
write_client_payload(&mut writer, ProtoTag::Abridged, 0, &p1, &rng, &mut frame_buf)
|
||||
.await
|
||||
.unwrap();
|
||||
write_client_payload(
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
RPC_FLAG_QUICKACK,
|
||||
&p2,
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
write_client_payload(&mut writer, ProtoTag::Secure, 0, &p3, &rng, &mut frame_buf)
|
||||
.await
|
||||
.unwrap();
|
||||
writer.flush().await.unwrap();
|
||||
|
||||
// Frame 1: abridged short.
|
||||
let mut e1 = vec![0u8; 1 + p1.len()];
|
||||
read_side.read_exact(&mut e1).await.unwrap();
|
||||
let d1 = decryptor.decrypt(&e1);
|
||||
assert_eq!(d1[0], (p1.len() / 4) as u8);
|
||||
assert_eq!(&d1[1..], p1.as_slice());
|
||||
|
||||
// Frame 2: intermediate with quickack.
|
||||
let mut e2 = vec![0u8; 4 + p2.len()];
|
||||
read_side.read_exact(&mut e2).await.unwrap();
|
||||
let d2 = decryptor.decrypt(&e2);
|
||||
let l2 = u32::from_le_bytes(d2[0..4].try_into().unwrap());
|
||||
assert_ne!(l2 & 0x8000_0000, 0);
|
||||
assert_eq!((l2 & 0x7fff_ffff) as usize, p2.len());
|
||||
assert_eq!(&d2[4..], p2.as_slice());
|
||||
|
||||
// Frame 3: secure with bounded tail.
|
||||
let mut e3h = [0u8; 4];
|
||||
read_side.read_exact(&mut e3h).await.unwrap();
|
||||
let d3h = decryptor.decrypt(&e3h);
|
||||
let l3 = (u32::from_le_bytes(d3h.as_slice().try_into().unwrap()) & 0x7fff_ffff) as usize;
|
||||
assert!(l3 >= p3.len());
|
||||
assert!((1..=3).contains(&(l3 - p3.len())));
|
||||
let mut e3b = vec![0u8; l3];
|
||||
read_side.read_exact(&mut e3b).await.unwrap();
|
||||
let d3b = decryptor.decrypt(&e3b);
|
||||
assert_eq!(&d3b[..p3.len()], p3.as_slice());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_yield_sender_boundary_matrix_blackhat() {
|
||||
assert!(!should_yield_c2me_sender(0, false));
|
||||
assert!(!should_yield_c2me_sender(0, true));
|
||||
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true));
|
||||
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false));
|
||||
assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true));
|
||||
assert!(should_yield_c2me_sender(
|
||||
C2ME_SENDER_FAIRNESS_BUDGET.saturating_add(1024),
|
||||
true
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_yield_sender_light_fuzz_matches_oracle() {
|
||||
let mut s: u64 = 0xD00D_BAAD_F00D_CAFE;
|
||||
for _ in 0..5000 {
|
||||
s ^= s << 7;
|
||||
s ^= s >> 9;
|
||||
s ^= s << 8;
|
||||
let sent = (s as usize) & 0x1fff;
|
||||
let backlog = (s & 1) != 0;
|
||||
|
||||
let expected = backlog && sent >= C2ME_SENDER_FAIRNESS_BUDGET;
|
||||
assert_eq!(should_yield_c2me_sender(sent, backlog), expected);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_would_be_exceeded_exact_remaining_one_byte() {
|
||||
let stats = Stats::new();
|
||||
let user = "quota-edge";
|
||||
let quota = 100u64;
|
||||
stats.add_user_octets_to(user, 99);
|
||||
|
||||
assert!(
|
||||
!quota_would_be_exceeded_for_user(&stats, user, Some(quota), 1),
|
||||
"exactly remaining budget should be allowed"
|
||||
);
|
||||
assert!(
|
||||
quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2),
|
||||
"one byte beyond remaining budget must be rejected"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_would_be_exceeded_saturating_edge_remains_fail_closed() {
|
||||
let stats = Stats::new();
|
||||
let user = "quota-saturating-edge";
|
||||
let quota = u64::MAX - 3;
|
||||
stats.add_user_octets_to(user, u64::MAX - 4);
|
||||
|
||||
assert!(
|
||||
quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2),
|
||||
"saturating arithmetic edge must stay fail-closed"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_exceeded_boundary_is_inclusive() {
|
||||
let stats = Stats::new();
|
||||
let user = "quota-inclusive-boundary";
|
||||
stats.add_user_octets_to(user, 50);
|
||||
|
||||
assert!(quota_exceeded_for_user(&stats, user, Some(50)));
|
||||
assert!(!quota_exceeded_for_user(&stats, user, Some(51)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_soft_helper_matches_capped_generic_helper_matrix() {
|
||||
let stats = Stats::new();
|
||||
let user = "quota-soft-parity";
|
||||
|
||||
for used in [0u64, 1, 7, 63, 127, 255] {
|
||||
stats.sub_user_octets_to(user, stats.get_user_total_octets(user));
|
||||
stats.add_user_octets_to(user, used);
|
||||
|
||||
for quota in [8u64, 64, 128, 256] {
|
||||
for overshoot in [0u64, 1, 5, 32] {
|
||||
for bytes in [0u64, 1, 2, 7, 31, 64] {
|
||||
let soft = quota_would_be_exceeded_for_user_soft(
|
||||
&stats,
|
||||
user,
|
||||
Some(quota),
|
||||
bytes,
|
||||
overshoot,
|
||||
);
|
||||
let capped = quota_would_be_exceeded_for_user(
|
||||
&stats,
|
||||
user,
|
||||
Some(quota_soft_cap(quota, overshoot)),
|
||||
bytes,
|
||||
);
|
||||
assert_eq!(
|
||||
soft, capped,
|
||||
"soft helper parity mismatch: used={used} quota={quota} overshoot={overshoot} bytes={bytes}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_soft_helper_none_limit_never_rejects() {
|
||||
let stats = Stats::new();
|
||||
let user = "quota-soft-none";
|
||||
stats.add_user_octets_to(user, u64::MAX);
|
||||
|
||||
assert!(!quota_would_be_exceeded_for_user_soft(
|
||||
&stats,
|
||||
user,
|
||||
None,
|
||||
u64::MAX,
|
||||
u64::MAX,
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_soft_cap_saturates_and_stays_fail_closed() {
|
||||
let stats = Stats::new();
|
||||
let user = "quota-soft-saturating";
|
||||
let quota = u64::MAX - 2;
|
||||
let overshoot = 100;
|
||||
|
||||
assert_eq!(quota_soft_cap(quota, overshoot), u64::MAX);
|
||||
|
||||
stats.add_user_octets_to(user, u64::MAX - 1);
|
||||
assert!(quota_would_be_exceeded_for_user_soft(
|
||||
&stats,
|
||||
user,
|
||||
Some(quota),
|
||||
2,
|
||||
overshoot,
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enqueue_c2me_close_fast_path_succeeds_without_backpressure() {
|
||||
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(4);
|
||||
enqueue_c2me_command(&tx, C2MeCommand::Close)
|
||||
.await
|
||||
.expect("close should enqueue on fast path");
|
||||
|
||||
let recv = timeout(TokioDuration::from_millis(50), rx.recv())
|
||||
.await
|
||||
.expect("must receive close command")
|
||||
.expect("close command should be present");
|
||||
assert!(matches!(recv, C2MeCommand::Close));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn enqueue_c2me_data_full_then_drain_preserves_order() {
|
||||
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
|
||||
tx.send(C2MeCommand::Data {
|
||||
payload: make_pooled_payload(&[1]),
|
||||
flags: 10,
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tx2 = tx.clone();
|
||||
let producer = tokio::spawn(async move {
|
||||
enqueue_c2me_command(
|
||||
&tx2,
|
||||
C2MeCommand::Data {
|
||||
payload: make_pooled_payload(&[2, 2]),
|
||||
flags: 20,
|
||||
},
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
tokio::time::sleep(TokioDuration::from_millis(10)).await;
|
||||
|
||||
let first = rx.recv().await.expect("first item should exist");
|
||||
match first {
|
||||
C2MeCommand::Data { payload, flags } => {
|
||||
assert_eq!(payload.as_ref(), &[1]);
|
||||
assert_eq!(flags, 10);
|
||||
}
|
||||
C2MeCommand::Close => panic!("unexpected close as first item"),
|
||||
}
|
||||
|
||||
producer.await.unwrap().expect("producer should complete");
|
||||
|
||||
let second = timeout(TokioDuration::from_millis(100), rx.recv())
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("second item should exist");
|
||||
match second {
|
||||
C2MeCommand::Data { payload, flags } => {
|
||||
assert_eq!(payload.as_ref(), &[2, 2]);
|
||||
assert_eq!(flags, 20);
|
||||
}
|
||||
C2MeCommand::Close => panic!("unexpected close as second item"),
|
||||
}
|
||||
}
|
||||
|
|
@ -1,295 +0,0 @@
|
|||
use super::*;
|
||||
use crate::crypto::{AesCtr, SecureRandom};
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::CryptoWriter;
|
||||
use bytes::Bytes;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::task::{Context, Poll, Waker};
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct BlockingWriteState {
|
||||
write_entered: AtomicBool,
|
||||
released: AtomicBool,
|
||||
write_waker: Mutex<Option<Waker>>,
|
||||
write_entered_notify: Notify,
|
||||
}
|
||||
|
||||
struct BlockingWrite {
|
||||
state: Arc<BlockingWriteState>,
|
||||
}
|
||||
|
||||
impl BlockingWrite {
|
||||
fn new(state: Arc<BlockingWriteState>) -> Self {
|
||||
Self { state }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for BlockingWrite {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
self.state.write_entered.store(true, Ordering::Release);
|
||||
self.state.write_entered_notify.notify_waiters();
|
||||
|
||||
if self.state.released.load(Ordering::Acquire) {
|
||||
return Poll::Ready(Ok(buf.len()));
|
||||
}
|
||||
|
||||
if let Ok(mut slot) = self.state.write_waker.lock() {
|
||||
*slot = Some(cx.waker().clone());
|
||||
}
|
||||
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_until_blocking_write_entered(state: &Arc<BlockingWriteState>) {
|
||||
for _ in 0..8 {
|
||||
if state.write_entered.load(Ordering::Acquire) {
|
||||
return;
|
||||
}
|
||||
let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await;
|
||||
}
|
||||
|
||||
panic!("blocking writer did not enter poll_write in bounded time");
|
||||
}
|
||||
|
||||
fn release_blocking_write(state: &Arc<BlockingWriteState>) {
|
||||
state.released.store(true, Ordering::Release);
|
||||
if let Ok(mut slot) = state.write_waker.lock()
|
||||
&& let Some(waker) = slot.take()
|
||||
{
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn adversarial_blocked_write_releases_cross_mode_lock_and_preserves_fail_closed_quota() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("middle-cross-release-regression-{}", std::process::id());
|
||||
let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user));
|
||||
let bytes_me2c = Arc::new(AtomicU64::new(0));
|
||||
let writer_state = Arc::new(BlockingWriteState::default());
|
||||
|
||||
let first = {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let cross_mode_lock = Arc::clone(&cross_mode_lock);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
let writer_state = Arc::clone(&writer_state);
|
||||
tokio::spawn(async move {
|
||||
let mut writer = make_crypto_writer(BlockingWrite::new(writer_state));
|
||||
let mut frame_buf = Vec::new();
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xAA, 0xBB, 0xCC, 0xDD]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(4),
|
||||
0,
|
||||
Some(&cross_mode_lock),
|
||||
bytes_me2c.as_ref(),
|
||||
41_000,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
})
|
||||
};
|
||||
|
||||
wait_until_blocking_write_entered(&writer_state).await;
|
||||
|
||||
let guard = timeout(Duration::from_millis(40), cross_mode_lock.lock())
|
||||
.await
|
||||
.expect("cross-mode lock must be released while first write is pending");
|
||||
drop(guard);
|
||||
|
||||
let second = {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let cross_mode_lock = Arc::clone(&cross_mode_lock);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
tokio::spawn(async move {
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
timeout(
|
||||
Duration::from_millis(150),
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xEE]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(4),
|
||||
0,
|
||||
Some(&cross_mode_lock),
|
||||
bytes_me2c.as_ref(),
|
||||
41_001,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await
|
||||
})
|
||||
};
|
||||
|
||||
let second_result = second
|
||||
.await
|
||||
.expect("second task must not panic")
|
||||
.expect("second write must not block on cross-mode lock");
|
||||
assert!(
|
||||
matches!(second_result, Err(ProxyError::DataQuotaExceeded { .. })),
|
||||
"second write must fail closed due to first write reservation"
|
||||
);
|
||||
|
||||
release_blocking_write(&writer_state);
|
||||
|
||||
let first_result = timeout(Duration::from_millis(300), first)
|
||||
.await
|
||||
.expect("first task timed out")
|
||||
.expect("first task must not panic");
|
||||
assert!(first_result.is_ok());
|
||||
|
||||
assert_eq!(stats.get_user_total_octets(&user), 4);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_pending_write_does_not_starve_same_user_waiters_after_quota_boundary() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("middle-cross-release-stress-{}", std::process::id());
|
||||
let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user));
|
||||
let bytes_me2c = Arc::new(AtomicU64::new(0));
|
||||
let writer_state = Arc::new(BlockingWriteState::default());
|
||||
|
||||
let first = {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let cross_mode_lock = Arc::clone(&cross_mode_lock);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
let writer_state = Arc::clone(&writer_state);
|
||||
tokio::spawn(async move {
|
||||
let mut writer = make_crypto_writer(BlockingWrite::new(writer_state));
|
||||
let mut frame_buf = Vec::new();
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x01, 0x02]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(3),
|
||||
0,
|
||||
Some(&cross_mode_lock),
|
||||
bytes_me2c.as_ref(),
|
||||
41_100,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
})
|
||||
};
|
||||
|
||||
wait_until_blocking_write_entered(&writer_state).await;
|
||||
|
||||
let mut set = JoinSet::new();
|
||||
for idx in 0..48u64 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let cross_mode_lock = Arc::clone(&cross_mode_lock);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
set.spawn(async move {
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
timeout(
|
||||
Duration::from_millis(200),
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x10]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(3),
|
||||
0,
|
||||
Some(&cross_mode_lock),
|
||||
bytes_me2c.as_ref(),
|
||||
41_200 + idx,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await
|
||||
});
|
||||
}
|
||||
|
||||
let mut ok = 0usize;
|
||||
let mut quota_exceeded = 0usize;
|
||||
while let Some(done) = set.join_next().await {
|
||||
let timed = done.expect("waiter task must not panic");
|
||||
let result = timed.expect("waiter must not block behind pending first write");
|
||||
match result {
|
||||
Ok(_) => ok += 1,
|
||||
Err(ProxyError::DataQuotaExceeded { .. }) => quota_exceeded += 1,
|
||||
Err(other) => panic!("unexpected error in waiter: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(ok, 1, "exactly one waiter should consume remaining one-byte quota");
|
||||
assert_eq!(quota_exceeded, 47);
|
||||
|
||||
release_blocking_write(&writer_state);
|
||||
|
||||
let first_result = timeout(Duration::from_millis(300), first)
|
||||
.await
|
||||
.expect("first task timed out")
|
||||
.expect("first task must not panic");
|
||||
assert!(first_result.is_ok());
|
||||
|
||||
assert_eq!(stats.get_user_total_octets(&user), 3);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3);
|
||||
}
|
||||
|
|
@ -1,116 +0,0 @@
|
|||
use super::*;
|
||||
use crate::crypto::{AesCtr, SecureRandom};
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::CryptoWriter;
|
||||
use bytes::Bytes;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||
}
|
||||
|
||||
fn lookup_counter_test_lock() -> &'static Mutex<()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tdd_prefetched_cross_mode_lock_avoids_per_frame_registry_lookup_in_me_to_client_writer() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Stats::new();
|
||||
let user = format!("middle-cross-mode-lookup-{}", std::process::id());
|
||||
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
|
||||
crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
for idx in 0..8u64 {
|
||||
let outcome = process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xAB]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
Some(&cross_mode_lock),
|
||||
&bytes_me2c,
|
||||
20_000 + idx,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(outcome.is_ok());
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
|
||||
0,
|
||||
"prefetched lock path must not re-query lock registry per frame"
|
||||
);
|
||||
assert_eq!(stats.get_user_total_octets(&user), 8);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 8);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn control_without_prefetched_lock_still_uses_registry_lookup_path() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Stats::new();
|
||||
let user = format!("middle-cross-mode-lookup-control-{}", std::process::id());
|
||||
|
||||
crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let outcome = process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xCD]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
None,
|
||||
&bytes_me2c,
|
||||
20_100,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(outcome.is_ok());
|
||||
assert_eq!(
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
|
||||
1,
|
||||
"fallback path without prefetched lock should perform a registry lookup"
|
||||
);
|
||||
}
|
||||
|
|
@ -1,376 +0,0 @@
|
|||
use super::*;
|
||||
use crate::crypto::{AesCtr, SecureRandom};
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::CryptoWriter;
|
||||
use bytes::Bytes;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_quota_limited_me_to_client_write_updates_counters_exactly_once() {
|
||||
let stats = Stats::new();
|
||||
let user = format!("middle-cross-matrix-positive-{}", std::process::id());
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let result = process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[1, 2, 3, 4]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(128),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
10_001,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(stats.get_user_total_octets(&user), 4);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn negative_held_cross_mode_lock_blocks_quota_limited_me_to_client_path() {
|
||||
let stats = Stats::new();
|
||||
let user = format!("middle-cross-matrix-negative-{}", std::process::id());
|
||||
let held = cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock before ME->C call");
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let blocked = timeout(
|
||||
Duration::from_millis(25),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x41]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(256),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
10_002,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(blocked.is_err());
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_quota_none_bypasses_cross_mode_lock_guard_in_me_to_client_path() {
|
||||
let stats = Stats::new();
|
||||
let user = format!("middle-cross-matrix-edge-none-{}", std::process::id());
|
||||
let held = cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock while quota is disabled");
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let outcome = timeout(
|
||||
Duration::from_millis(80),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x11, 0x22]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
None,
|
||||
0,
|
||||
&bytes_me2c,
|
||||
10_003,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("quota-none path must not wait on cross-mode lock");
|
||||
|
||||
assert!(outcome.is_ok());
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn adversarial_same_user_parallel_quota_limited_writes_stay_hard_capped() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("middle-cross-matrix-adversarial-{}", std::process::id());
|
||||
let limit = 64u64;
|
||||
let bytes_me2c = Arc::new(AtomicU64::new(0));
|
||||
let mut tasks = Vec::new();
|
||||
|
||||
for idx in 0..256u64 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
let user = user.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xEE]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(limit),
|
||||
0,
|
||||
bytes_me2c.as_ref(),
|
||||
11_000 + idx,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
}));
|
||||
}
|
||||
|
||||
let mut ok = 0usize;
|
||||
for task in tasks {
|
||||
match task.await.expect("task must not panic") {
|
||||
Ok(_) => ok += 1,
|
||||
Err(ProxyError::DataQuotaExceeded { .. }) => {}
|
||||
Err(other) => panic!("unexpected error in adversarial parallel case: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(ok, limit as usize);
|
||||
assert_eq!(stats.get_user_total_octets(&user), limit);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), limit);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_shared_lock_blocks_direct_relay_and_middle_relay_for_same_user() {
|
||||
let user = format!("middle-cross-matrix-integration-{}", std::process::id());
|
||||
let relay_lock = crate::proxy::relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let middle_lock = cross_mode_quota_user_lock_for_tests(&user);
|
||||
assert!(
|
||||
Arc::ptr_eq(&relay_lock, &middle_lock),
|
||||
"relay and middle-relay must share the same cross-mode lock identity"
|
||||
);
|
||||
|
||||
let held_guard = relay_lock
|
||||
.try_lock()
|
||||
.expect("test must hold shared cross-mode lock");
|
||||
|
||||
let stats = Stats::new();
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let middle_blocked = timeout(
|
||||
Duration::from_millis(25),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x92]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
12_001,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
assert!(middle_blocked.is_err());
|
||||
|
||||
drop(held_guard);
|
||||
|
||||
let middle_ready = timeout(
|
||||
Duration::from_millis(250),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x94]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
12_002,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("middle path must complete after release");
|
||||
|
||||
assert!(middle_ready.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn light_fuzz_mixed_payload_sizes_with_periodic_lock_holds_keeps_accounting_consistent() {
|
||||
let stats = Stats::new();
|
||||
let user = format!("middle-cross-matrix-fuzz-{}", std::process::id());
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
let mut seed = 0xC0DE_1234_55AA_9988u64;
|
||||
|
||||
for case in 0..96u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let hold = (seed & 0x03) == 0;
|
||||
let mut held_lock = None;
|
||||
let maybe_guard = if hold {
|
||||
held_lock = Some(cross_mode_quota_user_lock_for_tests(&user));
|
||||
Some(
|
||||
held_lock
|
||||
.as_ref()
|
||||
.expect("held lock should be present")
|
||||
.try_lock()
|
||||
.expect("cross-mode lock should be acquirable in fuzz round"),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let payload_len = ((seed >> 8) as usize % 8) + 1;
|
||||
let payload = vec![(seed & 0xff) as u8; payload_len];
|
||||
let before = stats.get_user_total_octets(&user);
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
|
||||
let timed = timeout(
|
||||
Duration::from_millis(20),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from(payload),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
13_000 + case as u64,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
if hold {
|
||||
assert!(timed.is_err(), "held-lock fuzz round must block within timeout");
|
||||
assert_eq!(stats.get_user_total_octets(&user), before);
|
||||
} else {
|
||||
let done = timed.expect("unheld fuzz round must complete in time");
|
||||
assert!(done.is_ok());
|
||||
}
|
||||
|
||||
drop(maybe_guard);
|
||||
drop(held_lock);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), stats.get_user_total_octets(&user));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_held_user_lock_does_not_block_other_users_me_to_client_writes() {
|
||||
let held_user = format!("middle-cross-matrix-stress-held-{}", std::process::id());
|
||||
let free_user = format!("middle-cross-matrix-stress-free-{}", std::process::id());
|
||||
|
||||
let held = cross_mode_quota_user_lock_for_tests(&held_user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock for blocked user");
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
for idx in 0..64u64 {
|
||||
let user = free_user.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let stats = Stats::new();
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xA0]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
14_000 + idx,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
}));
|
||||
}
|
||||
|
||||
timeout(Duration::from_secs(2), async {
|
||||
for task in tasks {
|
||||
let done = task.await.expect("free-user task must not panic");
|
||||
assert!(done.is_ok());
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("free-user tasks should complete without waiting for held user's lock");
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
|
@ -1,254 +0,0 @@
|
|||
use super::*;
|
||||
use crate::crypto::{AesCtr, SecureRandom};
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::CryptoWriter;
|
||||
use bytes::Bytes;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::task::{Context, Poll, Waker};
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct BlockingWriteState {
|
||||
write_entered: AtomicBool,
|
||||
released: AtomicBool,
|
||||
write_waker: Mutex<Option<Waker>>,
|
||||
write_entered_notify: Notify,
|
||||
}
|
||||
|
||||
struct BlockingWrite {
|
||||
state: Arc<BlockingWriteState>,
|
||||
}
|
||||
|
||||
impl BlockingWrite {
|
||||
fn new(state: Arc<BlockingWriteState>) -> Self {
|
||||
Self { state }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for BlockingWrite {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
self.state.write_entered.store(true, Ordering::Release);
|
||||
self.state.write_entered_notify.notify_waiters();
|
||||
|
||||
if self.state.released.load(Ordering::Acquire) {
|
||||
return Poll::Ready(Ok(buf.len()));
|
||||
}
|
||||
|
||||
if let Ok(mut slot) = self.state.write_waker.lock() {
|
||||
*slot = Some(cx.waker().clone());
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_until_blocking_write_entered(state: &Arc<BlockingWriteState>) {
|
||||
for _ in 0..8 {
|
||||
if state.write_entered.load(Ordering::Acquire) {
|
||||
return;
|
||||
}
|
||||
let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await;
|
||||
}
|
||||
|
||||
panic!("blocking writer did not enter poll_write in bounded time");
|
||||
}
|
||||
|
||||
fn release_blocking_write(state: &Arc<BlockingWriteState>) {
|
||||
state.released.store(true, Ordering::Release);
|
||||
if let Ok(mut slot) = state.write_waker.lock()
|
||||
&& let Some(waker) = slot.take()
|
||||
{
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_held_cross_mode_lock_blocks_me_to_client_quota_reservation_path() {
|
||||
let stats = Stats::new();
|
||||
let user = format!("middle-me2c-cross-mode-held-{}", std::process::id());
|
||||
let held = cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold shared cross-mode lock before ME->C write path");
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let blocked = timeout(
|
||||
Duration::from_millis(25),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x41]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
9901,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
blocked.is_err(),
|
||||
"ME->C quota reservation path must be serialized by held shared cross-mode lock"
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
|
||||
let released = timeout(
|
||||
Duration::from_millis(250),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x42]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
9902,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("ME->C write must complete after cross-mode lock release");
|
||||
|
||||
assert!(released.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn business_uncontended_cross_mode_lock_allows_me_to_client_quota_reservation() {
|
||||
let stats = Stats::new();
|
||||
let user = format!("middle-me2c-cross-mode-free-{}", std::process::id());
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let outcome = timeout(
|
||||
Duration::from_millis(250),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x55, 0x66]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
9903,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("uncontended ME->C path should not stall");
|
||||
|
||||
assert!(outcome.is_ok());
|
||||
assert_eq!(stats.get_user_total_octets(&user), 2);
|
||||
assert_eq!(bytes_me2c.load(std::sync::atomic::Ordering::Relaxed), 2);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn adversarial_cross_mode_lock_is_released_before_me_to_client_write_await() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("middle-me2c-lock-drop-before-write-{}", std::process::id());
|
||||
let cross_mode_lock = cross_mode_quota_user_lock_for_tests(&user);
|
||||
let bytes_me2c = Arc::new(AtomicU64::new(0));
|
||||
let writer_state = Arc::new(BlockingWriteState::default());
|
||||
|
||||
let worker = {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let cross_mode_lock = Arc::clone(&cross_mode_lock);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
let writer_state = Arc::clone(&writer_state);
|
||||
tokio::spawn(async move {
|
||||
let mut writer = make_crypto_writer(BlockingWrite::new(writer_state));
|
||||
let mut frame_buf = Vec::new();
|
||||
let rng = SecureRandom::new();
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xDE, 0xAD, 0xBE, 0xEF]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(1024),
|
||||
0,
|
||||
Some(&cross_mode_lock),
|
||||
bytes_me2c.as_ref(),
|
||||
9910,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
})
|
||||
};
|
||||
|
||||
wait_until_blocking_write_entered(&writer_state).await;
|
||||
|
||||
let acquired_guard = timeout(Duration::from_millis(40), cross_mode_lock.lock())
|
||||
.await
|
||||
.expect("cross-mode lock must be free while ME->C write is pending");
|
||||
drop(acquired_guard);
|
||||
|
||||
release_blocking_write(&writer_state);
|
||||
|
||||
let result = timeout(Duration::from_millis(300), worker)
|
||||
.await
|
||||
.expect("ME->C worker timed out after releasing blocking writer")
|
||||
.expect("ME->C worker must not panic");
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(stats.get_user_total_octets(&user), 4);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4);
|
||||
}
|
||||
|
|
@ -1,232 +0,0 @@
|
|||
use super::*;
|
||||
use crate::crypto::{AesCtr, SecureRandom};
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::CryptoWriter;
|
||||
use bytes::Bytes;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::task::{Context, Poll, Waker};
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct GateState {
|
||||
open: AtomicBool,
|
||||
parked_waker: std::sync::Mutex<Option<Waker>>,
|
||||
}
|
||||
|
||||
impl GateState {
|
||||
fn open(&self) {
|
||||
self.open.store(true, Ordering::Relaxed);
|
||||
if let Ok(mut guard) = self.parked_waker.lock()
|
||||
&& let Some(w) = guard.take()
|
||||
{
|
||||
w.wake();
|
||||
}
|
||||
}
|
||||
|
||||
fn has_waiter(&self) -> bool {
|
||||
self.parked_waker
|
||||
.lock()
|
||||
.map(|guard| guard.is_some())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct GateWriter {
|
||||
gate: Arc<GateState>,
|
||||
}
|
||||
|
||||
impl GateWriter {
|
||||
fn new(gate: Arc<GateState>) -> Self {
|
||||
Self { gate }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for GateWriter {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
if self.gate.open.load(Ordering::Relaxed) {
|
||||
return Poll::Ready(Ok(buf.len()));
|
||||
}
|
||||
|
||||
if let Ok(mut guard) = self.gate.parked_waker.lock() {
|
||||
*guard = Some(cx.waker().clone());
|
||||
}
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
struct FailingWriter;
|
||||
|
||||
impl AsyncWrite for FailingWriter {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
_buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Poll::Ready(Err(io::Error::new(
|
||||
io::ErrorKind::BrokenPipe,
|
||||
"injected writer failure",
|
||||
)))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn adversarial_same_user_slow_writer_must_not_hol_block_peer_connection() {
|
||||
let stats = Stats::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
let rng = SecureRandom::new();
|
||||
let quota_limit = Some(1024);
|
||||
let user = "hol-quota-user";
|
||||
|
||||
let gate = Arc::new(GateState::default());
|
||||
|
||||
let mut blocked_writer = make_crypto_writer(GateWriter::new(Arc::clone(&gate)));
|
||||
let slow_task = tokio::spawn(async move {
|
||||
let mut frame_buf = Vec::new();
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x10, 0x20, 0x30, 0x40]),
|
||||
},
|
||||
&mut blocked_writer,
|
||||
ProtoTag::Intermediate,
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
user,
|
||||
quota_limit,
|
||||
0,
|
||||
&bytes_me2c,
|
||||
7001,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
timeout(Duration::from_millis(100), async {
|
||||
loop {
|
||||
if gate.has_waiter() {
|
||||
break;
|
||||
}
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("first writer must reach backpressure and park");
|
||||
|
||||
let stats_fast = Stats::new();
|
||||
let bytes_fast = AtomicU64::new(0);
|
||||
let rng_fast = SecureRandom::new();
|
||||
let mut fast_writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf_fast = Vec::new();
|
||||
|
||||
timeout(
|
||||
Duration::from_millis(50),
|
||||
process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x41]),
|
||||
},
|
||||
&mut fast_writer,
|
||||
ProtoTag::Intermediate,
|
||||
&rng_fast,
|
||||
&mut frame_buf_fast,
|
||||
&stats_fast,
|
||||
user,
|
||||
quota_limit,
|
||||
0,
|
||||
&bytes_fast,
|
||||
7002,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("peer connection must not be blocked by same-user stalled write")
|
||||
.expect("fast peer write must succeed");
|
||||
|
||||
gate.open();
|
||||
let slow_result = timeout(Duration::from_secs(1), slow_task)
|
||||
.await
|
||||
.expect("stalled task must complete once gate opens")
|
||||
.expect("stalled task must not panic");
|
||||
assert!(slow_result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn negative_write_failure_rolls_back_pre_accounted_quota_and_forensics_bytes() {
|
||||
let stats = Stats::new();
|
||||
let user = "rollback-user";
|
||||
stats.add_user_octets_from(user, 7);
|
||||
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
let rng = SecureRandom::new();
|
||||
let mut writer = make_crypto_writer(FailingWriter);
|
||||
let mut frame_buf = Vec::new();
|
||||
|
||||
let result = process_me_writer_response(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[1, 2, 3, 4]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&rng,
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
user,
|
||||
Some(64),
|
||||
0,
|
||||
&bytes_me2c,
|
||||
7003,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, Err(ProxyError::Io(_))));
|
||||
assert_eq!(
|
||||
stats.get_user_total_octets(user),
|
||||
7,
|
||||
"failed client write must not overcharge user quota accounting"
|
||||
);
|
||||
assert_eq!(
|
||||
bytes_me2c.load(Ordering::Relaxed),
|
||||
0,
|
||||
"failed client write must not inflate ME->C forensic byte counter"
|
||||
);
|
||||
}
|
||||
|
|
@ -2,8 +2,8 @@ use super::*;
|
|||
use crate::crypto::AesCtr;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::{BufferPool, CryptoReader};
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::duplex;
|
||||
use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout};
|
||||
|
|
|
|||
|
|
@ -29,7 +29,10 @@ fn blackhat_registry_poison_recovers_with_fail_closed_reset_and_pressure_account
|
|||
let before = relay_pressure_event_seq();
|
||||
note_relay_pressure_event();
|
||||
let after = relay_pressure_event_seq();
|
||||
assert!(after > before, "pressure accounting must still advance after poison");
|
||||
assert!(
|
||||
after > before,
|
||||
"pressure accounting must still advance after poison"
|
||||
);
|
||||
|
||||
clear_relay_idle_pressure_state_for_testing();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,372 +0,0 @@
|
|||
use super::*;
|
||||
use crate::crypto::{AesCtr, SecureRandom};
|
||||
use crate::error::ProxyError;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::CryptoWriter;
|
||||
use bytes::Bytes;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Arc, OnceLock, Mutex};
|
||||
use tokio::sync::Mutex as AsyncMutex;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||
}
|
||||
|
||||
fn lookup_test_lock() -> &'static Mutex<()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_me2c_quota_counts_bytes_exactly_once() {
|
||||
let _guard = lookup_test_lock().lock().unwrap();
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-middle-ext-positive-{}", std::process::id());
|
||||
let lock = Arc::new(AsyncMutex::new(()));
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let result = process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[1, 2, 3, 4, 5]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(64),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
70_001,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(stats.get_user_total_octets(&user), 5);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 5);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn negative_held_crossmode_lock_blocks_me2c_write() {
|
||||
let _guard = lookup_test_lock().lock().unwrap();
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-middle-ext-negative-{}", std::process::id());
|
||||
|
||||
let lock = Arc::new(AsyncMutex::new(()));
|
||||
let _held = lock.try_lock().expect("lock must be held");
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let blocked = timeout(
|
||||
Duration::from_millis(25),
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xFE]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(16),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
70_101,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(blocked.is_err());
|
||||
assert_eq!(stats.get_user_total_octets(&user), 0);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_zero_quota_zero_payload_is_fail_closed() {
|
||||
let _guard = lookup_test_lock().lock().unwrap();
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-middle-ext-edge-{}", std::process::id());
|
||||
|
||||
let lock = Arc::new(AsyncMutex::new(()));
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let result = process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::new(),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(0),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
70_201,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
assert_eq!(stats.get_user_total_octets(&user), 0);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn adversarial_parallel_me2c_race_falls_back_to_quota_error() {
|
||||
let _guard = lookup_test_lock().lock().unwrap();
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("quota-middle-ext-blackhat-{}", std::process::id());
|
||||
let quota = 64u64;
|
||||
let lock = Arc::new(AsyncMutex::new(()));
|
||||
let bytes_me2c = Arc::new(AtomicU64::new(0));
|
||||
|
||||
let mut set = JoinSet::new();
|
||||
for i in 0..256u64 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let lock = Arc::clone(&lock);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
|
||||
set.spawn(async move {
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let payload = vec![((i & 0xFF) as u8); (i % 4 + 1) as usize];
|
||||
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from(payload),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(quota),
|
||||
0,
|
||||
Some(&lock),
|
||||
bytes_me2c.as_ref(),
|
||||
70_301 + i,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
});
|
||||
}
|
||||
|
||||
let mut succeeded = 0usize;
|
||||
while let Some(done) = set.join_next().await {
|
||||
match done.expect("task must not panic") {
|
||||
Ok(_) => succeeded += 1,
|
||||
Err(ProxyError::DataQuotaExceeded { .. }) => {}
|
||||
Err(other) => panic!("unexpected error {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(stats.get_user_total_octets(&user), bytes_me2c.load(Ordering::Relaxed));
|
||||
assert!(stats.get_user_total_octets(&user) <= quota);
|
||||
assert!(succeeded <= quota as usize);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration_shared_prefetched_lock_blocks_then_releases_writer() {
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-middle-ext-integration-{}", std::process::id());
|
||||
let lock = Arc::new(AsyncMutex::new(()));
|
||||
let held = lock
|
||||
.try_lock()
|
||||
.expect("integration test must hold prefetched lock first");
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let blocked = timeout(
|
||||
Duration::from_millis(25),
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xA1]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(8),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
70_360,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
assert!(blocked.is_err());
|
||||
|
||||
drop(held);
|
||||
|
||||
let after_release = timeout(
|
||||
Duration::from_millis(150),
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xA2]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(8),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
70_361,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await
|
||||
.expect("writer should progress once the shared lock is released");
|
||||
|
||||
assert!(after_release.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn light_fuzz_small_payloads_toggle_lock_state_stays_consistent() {
|
||||
let _guard = lookup_test_lock().lock().unwrap();
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-middle-ext-fuzz-{}", std::process::id());
|
||||
let mut seed = 0xCAFE_BABE_1234u64;
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
for case in 0..48u32 {
|
||||
seed ^= seed << 5;
|
||||
seed ^= seed >> 12;
|
||||
seed ^= seed << 13;
|
||||
let hold = (seed & 0x1) == 0;
|
||||
|
||||
let lock = Arc::new(AsyncMutex::new(()));
|
||||
let maybe_guard = if hold {
|
||||
Some(lock.try_lock().unwrap())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
|
||||
let result = timeout(
|
||||
Duration::from_millis(30),
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from(vec![(seed & 0xFF) as u8; ((seed as usize % 5) + 1)]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(128),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
70_401 + case as u64,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
if hold {
|
||||
assert!(result.is_err());
|
||||
} else {
|
||||
assert!(result.unwrap().is_ok());
|
||||
}
|
||||
|
||||
drop(maybe_guard);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_parallel_free_users_during_held_user_lock_maintains_liveness() {
|
||||
let _guard = lookup_test_lock().lock().unwrap();
|
||||
let held = Arc::new(AsyncMutex::new(()));
|
||||
let _held_guard = held.try_lock().unwrap();
|
||||
|
||||
let mut set = JoinSet::new();
|
||||
for i in 0..48u64 {
|
||||
set.spawn(async move {
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-middle-ext-stress-free-{i}");
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
let free_lock = Arc::new(AsyncMutex::new(()));
|
||||
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xEE]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(1),
|
||||
0,
|
||||
Some(&free_lock),
|
||||
&bytes_me2c,
|
||||
70_500 + i,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
});
|
||||
}
|
||||
|
||||
timeout(Duration::from_secs(2), async {
|
||||
while let Some(task) = set.join_next().await {
|
||||
task.unwrap().unwrap();
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
|
@ -1,131 +0,0 @@
|
|||
use super::*;
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
fn saturation_uses_stable_overflow_lock_without_cache_growth() {
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let prefix = format!("middle-quota-held-{}", std::process::id());
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
|
||||
}
|
||||
|
||||
assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX);
|
||||
|
||||
let user = format!("middle-quota-overflow-{}", std::process::id());
|
||||
let first = quota_user_lock(&user);
|
||||
let second = quota_user_lock(&user);
|
||||
|
||||
assert!(
|
||||
Arc::ptr_eq(&first, &second),
|
||||
"overflow user must get deterministic same lock while cache is saturated"
|
||||
);
|
||||
assert_eq!(
|
||||
map.len(),
|
||||
QUOTA_USER_LOCKS_MAX,
|
||||
"overflow path must not grow bounded lock map"
|
||||
);
|
||||
assert!(
|
||||
map.get(&user).is_none(),
|
||||
"overflow user should stay outside bounded lock map under saturation"
|
||||
);
|
||||
|
||||
drop(retained);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn overflow_striping_keeps_different_users_distributed() {
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let prefix = format!("middle-quota-dist-held-{}", std::process::id());
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
|
||||
}
|
||||
|
||||
let a = quota_user_lock("middle-overflow-user-a");
|
||||
let b = quota_user_lock("middle-overflow-user-b");
|
||||
let c = quota_user_lock("middle-overflow-user-c");
|
||||
|
||||
let distinct = [
|
||||
Arc::as_ptr(&a) as usize,
|
||||
Arc::as_ptr(&b) as usize,
|
||||
Arc::as_ptr(&c) as usize,
|
||||
]
|
||||
.iter()
|
||||
.copied()
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
.len();
|
||||
|
||||
assert!(
|
||||
distinct >= 2,
|
||||
"striped overflow lock set should avoid collapsing all users to one lock"
|
||||
);
|
||||
|
||||
drop(retained);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reclaim_path_caches_new_user_after_stale_entries_drop() {
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let prefix = format!("middle-quota-reclaim-held-{}", std::process::id());
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
|
||||
}
|
||||
|
||||
drop(retained);
|
||||
|
||||
let user = format!("middle-quota-reclaim-user-{}", std::process::id());
|
||||
let got = quota_user_lock(&user);
|
||||
assert!(map.get(&user).is_some());
|
||||
assert!(
|
||||
Arc::strong_count(&got) >= 2,
|
||||
"after reclaim, lock should be held both by caller and map"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn overflow_path_same_user_is_stable_across_parallel_threads() {
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!(
|
||||
"middle-quota-thread-held-{}-{idx}",
|
||||
std::process::id()
|
||||
)));
|
||||
}
|
||||
|
||||
let user = format!("middle-quota-overflow-thread-user-{}", std::process::id());
|
||||
let mut workers = Vec::new();
|
||||
for _ in 0..32 {
|
||||
let user = user.clone();
|
||||
workers.push(std::thread::spawn(move || quota_user_lock(&user)));
|
||||
}
|
||||
|
||||
let first = workers
|
||||
.remove(0)
|
||||
.join()
|
||||
.expect("thread must return lock handle");
|
||||
for worker in workers {
|
||||
let got = worker.join().expect("thread must return lock handle");
|
||||
assert!(
|
||||
Arc::ptr_eq(&first, &got),
|
||||
"same overflow user should resolve to one striped lock even under contention"
|
||||
);
|
||||
}
|
||||
|
||||
drop(retained);
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,399 +0,0 @@
|
|||
use super::*;
|
||||
use crate::crypto::{AesCtr, SecureRandom};
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::CryptoWriter;
|
||||
use bytes::Bytes;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use tokio::sync::Mutex as AsyncMutex;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn make_crypto_writer<W>(writer: W) -> CryptoWriter<W>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let key = [0u8; 32];
|
||||
let iv = 0u128;
|
||||
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
|
||||
}
|
||||
|
||||
fn lookup_counter_test_lock() -> &'static Mutex<()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_prefetched_cross_mode_lock_multi_frame_accounting_is_exact() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-extreme-positive-{}", std::process::id());
|
||||
let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
for idx in 0..12u64 {
|
||||
let payload = vec![0x5A; ((idx % 4) + 1) as usize];
|
||||
let result = process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from(payload),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(512),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
31_000 + idx,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
|
||||
0,
|
||||
"prefetched lock path must avoid hot-path registry lookups"
|
||||
);
|
||||
assert_eq!(
|
||||
stats.get_user_total_octets(&user),
|
||||
bytes_me2c.load(Ordering::Relaxed),
|
||||
"forensics and quota accounting must remain synchronized"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn negative_held_prefetched_lock_blocks_writer_without_accounting_mutation() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-extreme-negative-{}", std::process::id());
|
||||
let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold lock before calling ME->C writer");
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let blocked = timeout(
|
||||
Duration::from_millis(25),
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[1, 2, 3]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(64),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
31_100,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(blocked.is_err());
|
||||
assert_eq!(stats.get_user_total_octets(&user), 0);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_zero_quota_and_zero_payload_is_fail_closed() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-extreme-edge-{}", std::process::id());
|
||||
let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
let result = process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::new(),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(0),
|
||||
0,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
31_200,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
assert_eq!(stats.get_user_total_octets(&user), 0);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn adversarial_blackhat_parallel_quota_race_never_overshoots_soft_cap() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("quota-extreme-blackhat-{}", std::process::id());
|
||||
let quota = 80u64;
|
||||
let overshoot = 7u64;
|
||||
let soft_limit = quota + overshoot;
|
||||
let lock = Arc::new(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user));
|
||||
let bytes_me2c = Arc::new(AtomicU64::new(0));
|
||||
|
||||
let mut set = JoinSet::new();
|
||||
for idx in 0..256u64 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let lock = Arc::clone(&lock);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
|
||||
set.spawn(async move {
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let len = ((idx % 5) + 1) as usize;
|
||||
let payload = vec![0xAA; len];
|
||||
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from(payload),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(quota),
|
||||
overshoot,
|
||||
Some(&lock),
|
||||
bytes_me2c.as_ref(),
|
||||
31_300 + idx,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
});
|
||||
}
|
||||
|
||||
while let Some(done) = set.join_next().await {
|
||||
match done.expect("task must not panic") {
|
||||
Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }) => {}
|
||||
Err(other) => panic!("unexpected error variant under black-hat race: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
let total = stats.get_user_total_octets(&user);
|
||||
assert!(
|
||||
total <= soft_limit,
|
||||
"parallel adversarial race must stay under soft cap"
|
||||
);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), total);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration_without_prefetched_lock_uses_registry_lookup_path() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-extreme-integration-{}", std::process::id());
|
||||
crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
|
||||
for idx in 0..3u64 {
|
||||
let result = process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0x41]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(16),
|
||||
0,
|
||||
None,
|
||||
&bytes_me2c,
|
||||
31_400 + idx,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
|
||||
3,
|
||||
"control path should perform one lock-registry lookup per call"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn light_fuzz_quota_matrix_preserves_fail_closed_accounting() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Stats::new();
|
||||
let user = format!("quota-extreme-fuzz-{}", std::process::id());
|
||||
let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let bytes_me2c = AtomicU64::new(0);
|
||||
let mut seed = 0xA11C_55EE_2026_0323u64;
|
||||
|
||||
for idx in 0..512u64 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let quota = 24 + (seed & 0x3f);
|
||||
let overshoot = (seed >> 13) & 0x0f;
|
||||
let len = ((seed >> 19) & 0x07) + 1;
|
||||
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
let before = stats.get_user_total_octets(&user);
|
||||
|
||||
let result = process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from(vec![0x11; len as usize]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
&stats,
|
||||
&user,
|
||||
Some(quota),
|
||||
overshoot,
|
||||
Some(&lock),
|
||||
&bytes_me2c,
|
||||
31_500 + idx,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
|
||||
let after = stats.get_user_total_octets(&user);
|
||||
if result.is_ok() {
|
||||
assert!(after >= before);
|
||||
} else {
|
||||
assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
assert_eq!(after, before);
|
||||
}
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), after);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_prefetched_lock_high_fanout_exact_quota_success_count() {
|
||||
let _guard = lookup_counter_test_lock()
|
||||
.lock()
|
||||
.unwrap_or_else(|poison| poison.into_inner());
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("quota-extreme-stress-{}", std::process::id());
|
||||
let quota = 96u64;
|
||||
let lock: Arc<AsyncMutex<()>> = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let bytes_me2c = Arc::new(AtomicU64::new(0));
|
||||
|
||||
crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests();
|
||||
|
||||
let mut set = JoinSet::new();
|
||||
for idx in 0..384u64 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let lock = Arc::clone(&lock);
|
||||
let bytes_me2c = Arc::clone(&bytes_me2c);
|
||||
|
||||
set.spawn(async move {
|
||||
let mut writer = make_crypto_writer(tokio::io::sink());
|
||||
let mut frame_buf = Vec::new();
|
||||
process_me_writer_response_with_cross_mode_lock(
|
||||
MeResponse::Data {
|
||||
flags: 0,
|
||||
data: Bytes::from_static(&[0xFF]),
|
||||
},
|
||||
&mut writer,
|
||||
ProtoTag::Intermediate,
|
||||
&SecureRandom::new(),
|
||||
&mut frame_buf,
|
||||
stats.as_ref(),
|
||||
&user,
|
||||
Some(quota),
|
||||
0,
|
||||
Some(&lock),
|
||||
bytes_me2c.as_ref(),
|
||||
31_600 + idx,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
});
|
||||
}
|
||||
|
||||
let mut success = 0usize;
|
||||
while let Some(done) = set.join_next().await {
|
||||
match done.expect("task must not panic") {
|
||||
Ok(_) => success += 1,
|
||||
Err(ProxyError::DataQuotaExceeded { .. }) => {}
|
||||
Err(other) => panic!("unexpected error variant in stress fanout: {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(success, quota as usize);
|
||||
assert_eq!(stats.get_user_total_octets(&user), quota);
|
||||
assert_eq!(bytes_me2c.load(Ordering::Relaxed), quota);
|
||||
assert_eq!(
|
||||
crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user),
|
||||
0,
|
||||
"stress prefetched path must not use lock registry lookups"
|
||||
);
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -217,7 +217,9 @@ async fn adversarial_lockstep_alternating_attack_under_jitter_closes() {
|
|||
}
|
||||
}
|
||||
|
||||
writer_task.await.expect("writer jitter task must not panic");
|
||||
writer_task
|
||||
.await
|
||||
.expect("writer jitter task must not panic");
|
||||
assert!(closed, "alternating attack must close before EOF");
|
||||
});
|
||||
}
|
||||
|
|
@ -247,7 +249,10 @@ async fn integration_mixed_population_attackers_close_benign_survive() {
|
|||
plaintext.push(0x01);
|
||||
plaintext.extend_from_slice(&[n, n, n, n]);
|
||||
}
|
||||
writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap();
|
||||
writer
|
||||
.write_all(&encrypt_for_reader(&plaintext))
|
||||
.await
|
||||
.unwrap();
|
||||
drop(writer);
|
||||
|
||||
let mut closed = false;
|
||||
|
|
@ -279,7 +284,10 @@ async fn integration_mixed_population_attackers_close_benign_survive() {
|
|||
}
|
||||
plaintext.push(0x01);
|
||||
plaintext.extend_from_slice(&payload);
|
||||
writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap();
|
||||
writer
|
||||
.write_all(&encrypt_for_reader(&plaintext))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let got = read_once(
|
||||
&mut crypto_reader,
|
||||
|
|
@ -329,7 +337,10 @@ async fn light_fuzz_parallel_patterns_no_hang_or_panic() {
|
|||
}
|
||||
}
|
||||
|
||||
writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap();
|
||||
writer
|
||||
.write_all(&encrypt_for_reader(&plaintext))
|
||||
.await
|
||||
.unwrap();
|
||||
drop(writer);
|
||||
|
||||
for _ in 0..320 {
|
||||
|
|
|
|||
|
|
@ -51,7 +51,9 @@ fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
|
|||
fn append_tiny_frame(plaintext: &mut Vec<u8>, proto: ProtoTag) {
|
||||
match proto {
|
||||
ProtoTag::Abridged => plaintext.push(0x00),
|
||||
ProtoTag::Intermediate | ProtoTag::Secure => plaintext.extend_from_slice(&0u32.to_le_bytes()),
|
||||
ProtoTag::Intermediate | ProtoTag::Secure => {
|
||||
plaintext.extend_from_slice(&0u32.to_le_bytes())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -206,7 +208,11 @@ async fn intermediate_chunked_alternating_attack_closes_before_eof() {
|
|||
let mut plaintext = Vec::with_capacity(8 * 200);
|
||||
for n in 0..180u8 {
|
||||
append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
|
||||
append_real_frame(&mut plaintext, ProtoTag::Intermediate, [n, n ^ 1, n ^ 2, n ^ 3]);
|
||||
append_real_frame(
|
||||
&mut plaintext,
|
||||
ProtoTag::Intermediate,
|
||||
[n, n ^ 1, n ^ 2, n ^ 3],
|
||||
);
|
||||
}
|
||||
let encrypted = encrypt_for_reader(&plaintext);
|
||||
|
||||
|
|
@ -240,7 +246,9 @@ async fn intermediate_chunked_alternating_attack_closes_before_eof() {
|
|||
}
|
||||
}
|
||||
|
||||
writer_task.await.expect("intermediate writer task must not panic");
|
||||
writer_task
|
||||
.await
|
||||
.expect("intermediate writer task must not panic");
|
||||
assert!(closed, "intermediate alternating attack must fail closed");
|
||||
}
|
||||
|
||||
|
|
@ -290,7 +298,9 @@ async fn secure_chunked_alternating_attack_closes_before_eof() {
|
|||
}
|
||||
}
|
||||
|
||||
writer_task.await.expect("secure writer task must not panic");
|
||||
writer_task
|
||||
.await
|
||||
.expect("secure writer task must not panic");
|
||||
assert!(closed, "secure alternating attack must fail closed");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ use super::*;
|
|||
use crate::crypto::AesCtr;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::{BufferPool, CryptoReader};
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::time::Instant;
|
||||
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
|
||||
|
||||
|
|
@ -156,7 +156,10 @@ fn alternating_one_to_one_closes_with_bounded_real_frame_count() {
|
|||
}
|
||||
let (closed_at, _, reals) = simulate_tiny_debt_pattern(&pattern, pattern.len());
|
||||
assert!(closed_at.is_some());
|
||||
assert!(reals <= 80, "expected bounded real frames before close, got {reals}");
|
||||
assert!(
|
||||
reals <= 80,
|
||||
"expected bounded real frames before close, got {reals}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -183,7 +186,10 @@ fn alternating_one_to_seven_eventually_closes() {
|
|||
}
|
||||
}
|
||||
let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
|
||||
assert!(closed_at.is_some(), "1:7 tiny-to-real must eventually close");
|
||||
assert!(
|
||||
closed_at.is_some(),
|
||||
"1:7 tiny-to-real must eventually close"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@ use super::*;
|
|||
use crate::crypto::AesCtr;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::{BufferPool, CryptoReader};
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::time::Instant;
|
||||
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
|
||||
|
||||
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
|
||||
where
|
||||
|
|
|
|||
|
|
@ -1,108 +0,0 @@
|
|||
use super::*;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
fn cross_mode_lock_test_guard() -> std::sync::MutexGuard<'static, ()> {
|
||||
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
TEST_LOCK
|
||||
.get_or_init(|| Mutex::new(()))
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn same_user_returns_same_lock_identity() {
|
||||
let _guard = cross_mode_lock_test_guard();
|
||||
let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
locks.clear();
|
||||
|
||||
let a = cross_mode_quota_user_lock("cross-mode-same-user");
|
||||
let b = cross_mode_quota_user_lock("cross-mode-same-user");
|
||||
|
||||
assert!(
|
||||
Arc::ptr_eq(&a, &b),
|
||||
"same user must reuse a stable lock identity"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn saturation_overflow_path_returns_stable_striped_lock_without_cache_growth() {
|
||||
let _guard = cross_mode_lock_test_guard();
|
||||
let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
locks.clear();
|
||||
|
||||
let prefix = format!("cross-mode-saturated-{}", std::process::id());
|
||||
let mut retained = Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..CROSS_MODE_QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}")));
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
locks.len(),
|
||||
CROSS_MODE_QUOTA_USER_LOCKS_MAX,
|
||||
"lock cache must be saturated for overflow check"
|
||||
);
|
||||
|
||||
let overflow_user = format!("cross-mode-overflow-{}", std::process::id());
|
||||
let overflow_a = cross_mode_quota_user_lock(&overflow_user);
|
||||
let overflow_b = cross_mode_quota_user_lock(&overflow_user);
|
||||
|
||||
assert_eq!(
|
||||
locks.len(),
|
||||
CROSS_MODE_QUOTA_USER_LOCKS_MAX,
|
||||
"overflow path must not grow bounded lock cache"
|
||||
);
|
||||
assert!(
|
||||
locks.get(&overflow_user).is_none(),
|
||||
"overflow user must stay on striped fallback while cache is saturated"
|
||||
);
|
||||
assert!(
|
||||
Arc::ptr_eq(&overflow_a, &overflow_b),
|
||||
"overflow user must receive a stable striped lock across repeated lookups"
|
||||
);
|
||||
|
||||
drop(retained);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reclaim_drops_stale_entries_but_preserves_active_user_lock_identity() {
|
||||
let _guard = cross_mode_lock_test_guard();
|
||||
let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
locks.clear();
|
||||
|
||||
let prefix = format!("cross-mode-reclaim-{}", std::process::id());
|
||||
let protected_user = format!("{prefix}-protected");
|
||||
|
||||
let protected_lock = cross_mode_quota_user_lock(&protected_user);
|
||||
let mut retained = Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1));
|
||||
for idx in 0..(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1)) {
|
||||
retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}")));
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
locks.len(),
|
||||
CROSS_MODE_QUOTA_USER_LOCKS_MAX,
|
||||
"fixture must saturate lock cache before reclaim path is exercised"
|
||||
);
|
||||
|
||||
drop(retained);
|
||||
|
||||
let newcomer_user = format!("{prefix}-newcomer");
|
||||
let _newcomer = cross_mode_quota_user_lock(&newcomer_user);
|
||||
|
||||
assert!(
|
||||
locks.get(&protected_user).is_some(),
|
||||
"active protected user must remain cache-resident after reclaim"
|
||||
);
|
||||
let locked = locks
|
||||
.get(&protected_user)
|
||||
.expect("protected user must remain in map after reclaim");
|
||||
assert!(
|
||||
Arc::ptr_eq(locked.value(), &protected_lock),
|
||||
"reclaim must not swap active user lock identity"
|
||||
);
|
||||
assert!(
|
||||
locks.get(&newcomer_user).is_some(),
|
||||
"newcomer should become cacheable after stale entries are reclaimed"
|
||||
);
|
||||
}
|
||||
|
|
@ -78,7 +78,8 @@ async fn relay_hol_blocking_prevention_regression() {
|
|||
async fn relay_quota_mid_session_cutoff() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-mid-user";
|
||||
let quota = 5000;
|
||||
let quota = 5000u64;
|
||||
let c2s_buf_size = 1024usize;
|
||||
|
||||
let (client_peer, relay_client) = duplex(8192);
|
||||
let (relay_server, server_peer) = duplex(8192);
|
||||
|
|
@ -93,7 +94,7 @@ async fn relay_quota_mid_session_cutoff() {
|
|||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
1024,
|
||||
c2s_buf_size,
|
||||
1024,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
|
|
@ -120,9 +121,25 @@ async fn relay_quota_mid_session_cutoff() {
|
|||
other => panic!("Expected DataQuotaExceeded error, got: {:?}", other),
|
||||
}
|
||||
|
||||
let mut small_buf = [0u8; 1];
|
||||
let n = sp_reader.read(&mut small_buf).await.unwrap();
|
||||
assert_eq!(n, 0, "Server must see EOF after quota reached");
|
||||
let mut overshoot_bytes = 0usize;
|
||||
let mut buf = [0u8; 256];
|
||||
loop {
|
||||
match timeout(Duration::from_millis(20), sp_reader.read(&mut buf)).await {
|
||||
Ok(Ok(0)) => break,
|
||||
Ok(Ok(n)) => overshoot_bytes = overshoot_bytes.saturating_add(n),
|
||||
Ok(Err(e)) => panic!("server read must not fail after relay cutoff: {e}"),
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
overshoot_bytes <= c2s_buf_size,
|
||||
"post-write cutoff may leak at most one C->S chunk after boundary, got {overshoot_bytes}"
|
||||
);
|
||||
assert!(
|
||||
stats.get_user_quota_used(user) <= quota.saturating_add(c2s_buf_size as u64),
|
||||
"accounted quota must remain bounded by one in-flight chunk overshoot"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,243 @@
|
|||
use super::*;
|
||||
use std::collections::VecDeque;
|
||||
use std::io;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncWrite, AsyncWriteExt};
|
||||
use tokio::time::Instant;
|
||||
|
||||
struct ScriptedWriter {
|
||||
scripted_writes: Arc<Mutex<VecDeque<usize>>>,
|
||||
write_calls: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl ScriptedWriter {
|
||||
fn new(script: &[usize], write_calls: Arc<AtomicUsize>) -> Self {
|
||||
Self {
|
||||
scripted_writes: Arc::new(Mutex::new(script.iter().copied().collect())),
|
||||
write_calls,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for ScriptedWriter {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let this = self.get_mut();
|
||||
this.write_calls.fetch_add(1, Ordering::Relaxed);
|
||||
let planned = this
|
||||
.scripted_writes
|
||||
.lock()
|
||||
.unwrap_or_else(|poisoned| poisoned.into_inner())
|
||||
.pop_front()
|
||||
.unwrap_or(buf.len());
|
||||
Poll::Ready(Ok(planned.min(buf.len())))
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
||||
fn make_stats_io_with_script(
|
||||
user: &str,
|
||||
quota_limit: u64,
|
||||
precharged_quota: u64,
|
||||
script: &[usize],
|
||||
) -> (
|
||||
StatsIo<ScriptedWriter>,
|
||||
Arc<Stats>,
|
||||
Arc<AtomicUsize>,
|
||||
Arc<AtomicBool>,
|
||||
) {
|
||||
let stats = Arc::new(Stats::new());
|
||||
if precharged_quota > 0 {
|
||||
let user_stats = stats.get_or_create_user_stats_handle(user);
|
||||
stats.quota_charge_post_write(user_stats.as_ref(), precharged_quota);
|
||||
}
|
||||
|
||||
let write_calls = Arc::new(AtomicUsize::new(0));
|
||||
let quota_exceeded = Arc::new(AtomicBool::new(false));
|
||||
let io = StatsIo::new(
|
||||
ScriptedWriter::new(script, write_calls.clone()),
|
||||
Arc::new(SharedCounters::new()),
|
||||
stats.clone(),
|
||||
user.to_string(),
|
||||
Some(quota_limit),
|
||||
quota_exceeded.clone(),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
(io, stats, write_calls, quota_exceeded)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn direct_partial_write_charges_only_committed_bytes_without_double_charge() {
|
||||
let user = "direct-partial-charge-user";
|
||||
let (mut io, stats, write_calls, quota_exceeded) =
|
||||
make_stats_io_with_script(user, 1_048_576, 0, &[8 * 1024, 8 * 1024, 48 * 1024]);
|
||||
let payload = vec![0xAB; 64 * 1024];
|
||||
|
||||
let n1 = io
|
||||
.write(&payload)
|
||||
.await
|
||||
.expect("first partial write must succeed");
|
||||
let n2 = io
|
||||
.write(&payload)
|
||||
.await
|
||||
.expect("second partial write must succeed");
|
||||
let n3 = io.write(&payload).await.expect("tail write must succeed");
|
||||
|
||||
assert_eq!(n1, 8 * 1024);
|
||||
assert_eq!(n2, 8 * 1024);
|
||||
assert_eq!(n3, 48 * 1024);
|
||||
assert_eq!(write_calls.load(Ordering::Relaxed), 3);
|
||||
assert_eq!(
|
||||
stats.get_user_quota_used(user),
|
||||
(n1 + n2 + n3) as u64,
|
||||
"quota accounting must follow committed bytes only"
|
||||
);
|
||||
assert_eq!(
|
||||
stats.get_user_total_octets(user),
|
||||
(n1 + n2 + n3) as u64,
|
||||
"telemetry octets should match committed bytes on successful writes"
|
||||
);
|
||||
assert!(
|
||||
!quota_exceeded.load(Ordering::Acquire),
|
||||
"quota flag should stay false under large remaining budget"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn direct_hybrid_branch_selection_matches_contract() {
|
||||
let near_limit = 256 * 1024u64;
|
||||
let near_remaining = 32 * 1024u64;
|
||||
let (mut near_io, _stats, _calls, _flag) = make_stats_io_with_script(
|
||||
"direct-near-limit-hard-check-user",
|
||||
near_limit,
|
||||
near_limit - near_remaining,
|
||||
&[4 * 1024],
|
||||
);
|
||||
let near_payload = vec![0x11; 4 * 1024];
|
||||
let near_written = near_io
|
||||
.write(&near_payload)
|
||||
.await
|
||||
.expect("near-limit write must succeed");
|
||||
assert_eq!(near_written, 4 * 1024);
|
||||
assert_eq!(
|
||||
near_io.quota_bytes_since_check, 0,
|
||||
"near-limit branch must go through immediate hard check"
|
||||
);
|
||||
|
||||
let (mut far_small_io, _stats, _calls, _flag) =
|
||||
make_stats_io_with_script("direct-far-small-amortized-user", 1_048_576, 0, &[4 * 1024]);
|
||||
let far_small_payload = vec![0x22; 4 * 1024];
|
||||
let far_small_written = far_small_io
|
||||
.write(&far_small_payload)
|
||||
.await
|
||||
.expect("small far-from-limit write must succeed");
|
||||
assert_eq!(far_small_written, 4 * 1024);
|
||||
assert_eq!(
|
||||
far_small_io.quota_bytes_since_check,
|
||||
4 * 1024,
|
||||
"small far-from-limit write must go through amortized path"
|
||||
);
|
||||
|
||||
let (mut far_large_io, _stats, _calls, _flag) = make_stats_io_with_script(
|
||||
"direct-far-large-hard-check-user",
|
||||
1_048_576,
|
||||
0,
|
||||
&[32 * 1024],
|
||||
);
|
||||
let far_large_payload = vec![0x33; 32 * 1024];
|
||||
let far_large_written = far_large_io
|
||||
.write(&far_large_payload)
|
||||
.await
|
||||
.expect("large write must succeed");
|
||||
assert_eq!(far_large_written, 32 * 1024);
|
||||
assert_eq!(
|
||||
far_large_io.quota_bytes_since_check, 0,
|
||||
"large write must force immediate hard check even far from limit"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn remaining_before_zero_rejects_without_calling_inner_writer() {
|
||||
let user = "direct-zero-remaining-user";
|
||||
let limit = 8u64;
|
||||
let (mut io, stats, write_calls, quota_exceeded) =
|
||||
make_stats_io_with_script(user, limit, limit, &[1]);
|
||||
|
||||
let err = io
|
||||
.write(&[0x44])
|
||||
.await
|
||||
.expect_err("write must fail when remaining quota is zero");
|
||||
|
||||
assert!(
|
||||
is_quota_io_error(&err),
|
||||
"zero-remaining gate must return typed quota I/O error"
|
||||
);
|
||||
assert_eq!(
|
||||
write_calls.load(Ordering::Relaxed),
|
||||
0,
|
||||
"inner poll_write must not be called when remaining quota is zero"
|
||||
);
|
||||
assert!(
|
||||
quota_exceeded.load(Ordering::Acquire),
|
||||
"zero-remaining gate must set exceeded flag"
|
||||
);
|
||||
assert_eq!(stats.get_user_quota_used(user), limit);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn exceeded_flag_blocks_following_poll_before_inner_write() {
|
||||
let user = "direct-exceeded-visibility-user";
|
||||
let (mut io, stats, write_calls, quota_exceeded) =
|
||||
make_stats_io_with_script(user, 1, 0, &[1, 1]);
|
||||
|
||||
let first = io
|
||||
.write(&[0x55])
|
||||
.await
|
||||
.expect("first byte should consume remaining quota");
|
||||
assert_eq!(first, 1);
|
||||
assert!(
|
||||
quota_exceeded.load(Ordering::Acquire),
|
||||
"hard check should store quota_exceeded after boundary hit"
|
||||
);
|
||||
|
||||
let second = io
|
||||
.write(&[0x66])
|
||||
.await
|
||||
.expect_err("next write must be rejected by early exceeded gate");
|
||||
assert!(
|
||||
is_quota_io_error(&second),
|
||||
"following write must fail with typed quota error"
|
||||
);
|
||||
assert_eq!(
|
||||
write_calls.load(Ordering::Relaxed),
|
||||
1,
|
||||
"second write must be cut before touching inner writer"
|
||||
);
|
||||
assert_eq!(stats.get_user_quota_used(user), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn adaptive_interval_clamp_matches_contract() {
|
||||
assert_eq!(quota_adaptive_interval_bytes(0), 4 * 1024);
|
||||
assert_eq!(quota_adaptive_interval_bytes(2 * 1024), 4 * 1024);
|
||||
assert_eq!(quota_adaptive_interval_bytes(32 * 1024), 16 * 1024);
|
||||
assert_eq!(quota_adaptive_interval_bytes(256 * 1024), 64 * 1024);
|
||||
|
||||
assert!(should_immediate_quota_check(32 * 1024, 4 * 1024));
|
||||
assert!(should_immediate_quota_check(1_048_576, 32 * 1024));
|
||||
assert!(!should_immediate_quota_check(1_048_576, 4 * 1024));
|
||||
}
|
||||
|
|
@ -1,267 +0,0 @@
|
|||
use super::relay_bidirectional;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::BufferPool;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn negative_same_user_pipeline_stalls_while_middle_lock_is_held() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("relay-pipeline-stall-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold shared cross-mode lock");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let (mut client_peer, relay_client) = duplex(1024);
|
||||
let (relay_server, mut server_peer) = duplex(1024);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_user = user.clone();
|
||||
let relay_stats = Arc::clone(&stats);
|
||||
let relay_task = tokio::spawn(async move {
|
||||
relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
256,
|
||||
256,
|
||||
&relay_user,
|
||||
relay_stats,
|
||||
Some(1024),
|
||||
Arc::new(BufferPool::new()),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
server_peer
|
||||
.write_all(&[0xA1])
|
||||
.await
|
||||
.expect("server write should enqueue while relay is stalled");
|
||||
|
||||
let mut one = [0u8; 1];
|
||||
let blocked_read = timeout(Duration::from_millis(40), client_peer.read_exact(&mut one)).await;
|
||||
assert!(
|
||||
blocked_read.is_err(),
|
||||
"same-user relay must remain blocked while cross-mode lock is held"
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_millis(400), client_peer.read_exact(&mut one))
|
||||
.await
|
||||
.expect("blocked relay must resume after cross-mode lock release")
|
||||
.expect("resumed relay must deliver queued byte");
|
||||
assert_eq!(one, [0xA1]);
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(1), relay_task)
|
||||
.await
|
||||
.expect("relay task must complete")
|
||||
.expect("relay task must not panic");
|
||||
assert!(relay_result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_other_user_pipeline_progresses_while_blocked_user_is_stalled() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let blocked_user = format!("relay-pipeline-blocked-{}", std::process::id());
|
||||
let free_user = format!("relay-pipeline-free-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold blocked user's shared cross-mode lock");
|
||||
|
||||
let stats_blocked = Arc::new(Stats::new());
|
||||
let stats_free = Arc::new(Stats::new());
|
||||
|
||||
let (mut blocked_client, blocked_relay_client) = duplex(1024);
|
||||
let (blocked_relay_server, mut blocked_server) = duplex(1024);
|
||||
let (blocked_client_reader, blocked_client_writer) = tokio::io::split(blocked_relay_client);
|
||||
let (blocked_server_reader, blocked_server_writer) = tokio::io::split(blocked_relay_server);
|
||||
|
||||
let (mut free_client, free_relay_client) = duplex(1024);
|
||||
let (free_relay_server, mut free_server) = duplex(1024);
|
||||
let (free_client_reader, free_client_writer) = tokio::io::split(free_relay_client);
|
||||
let (free_server_reader, free_server_writer) = tokio::io::split(free_relay_server);
|
||||
|
||||
let blocked_task = {
|
||||
let user = blocked_user.clone();
|
||||
let stats = Arc::clone(&stats_blocked);
|
||||
tokio::spawn(async move {
|
||||
relay_bidirectional(
|
||||
blocked_client_reader,
|
||||
blocked_client_writer,
|
||||
blocked_server_reader,
|
||||
blocked_server_writer,
|
||||
256,
|
||||
256,
|
||||
&user,
|
||||
stats,
|
||||
Some(1024),
|
||||
Arc::new(BufferPool::new()),
|
||||
)
|
||||
.await
|
||||
})
|
||||
};
|
||||
|
||||
let free_task = {
|
||||
let user = free_user.clone();
|
||||
let stats = Arc::clone(&stats_free);
|
||||
tokio::spawn(async move {
|
||||
relay_bidirectional(
|
||||
free_client_reader,
|
||||
free_client_writer,
|
||||
free_server_reader,
|
||||
free_server_writer,
|
||||
256,
|
||||
256,
|
||||
&user,
|
||||
stats,
|
||||
Some(1024),
|
||||
Arc::new(BufferPool::new()),
|
||||
)
|
||||
.await
|
||||
})
|
||||
};
|
||||
|
||||
blocked_server
|
||||
.write_all(&[0xB1])
|
||||
.await
|
||||
.expect("blocked user server write should queue");
|
||||
free_server
|
||||
.write_all(&[0xC1])
|
||||
.await
|
||||
.expect("free user server write should queue");
|
||||
|
||||
let mut blocked_buf = [0u8; 1];
|
||||
let mut free_buf = [0u8; 1];
|
||||
|
||||
let blocked_stalled = timeout(
|
||||
Duration::from_millis(40),
|
||||
blocked_client.read_exact(&mut blocked_buf),
|
||||
)
|
||||
.await;
|
||||
assert!(
|
||||
blocked_stalled.is_err(),
|
||||
"blocked user must remain stalled while its lock is held"
|
||||
);
|
||||
|
||||
timeout(Duration::from_millis(250), free_client.read_exact(&mut free_buf))
|
||||
.await
|
||||
.expect("free user must make progress while other user is blocked")
|
||||
.expect("free user read must succeed");
|
||||
assert_eq!(free_buf, [0xC1]);
|
||||
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_millis(400), blocked_client.read_exact(&mut blocked_buf))
|
||||
.await
|
||||
.expect("blocked user must resume after release")
|
||||
.expect("blocked user resumed read must succeed");
|
||||
assert_eq!(blocked_buf, [0xB1]);
|
||||
|
||||
drop(blocked_client);
|
||||
drop(blocked_server);
|
||||
drop(free_client);
|
||||
drop(free_server);
|
||||
|
||||
assert!(
|
||||
timeout(Duration::from_secs(1), blocked_task)
|
||||
.await
|
||||
.expect("blocked relay task must complete")
|
||||
.expect("blocked relay task must not panic")
|
||||
.is_ok()
|
||||
);
|
||||
assert!(
|
||||
timeout(Duration::from_secs(1), free_task)
|
||||
.await
|
||||
.expect("free relay task must complete")
|
||||
.expect("free relay task must not panic")
|
||||
.is_ok()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_jittered_hold_release_cycles_preserve_pipeline_liveness() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let mut seed = 0x5EED_C0DE_2026_0323u64;
|
||||
for round in 0..24u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let hold_ms = 2 + (seed % 10);
|
||||
let user = format!("relay-pipeline-fuzz-{}-{round}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock during fuzz round");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let (mut client_peer, relay_client) = duplex(1024);
|
||||
let (relay_server, mut server_peer) = duplex(1024);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_user = user.clone();
|
||||
let relay_stats = Arc::clone(&stats);
|
||||
let relay_task = tokio::spawn(async move {
|
||||
relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
256,
|
||||
256,
|
||||
&relay_user,
|
||||
relay_stats,
|
||||
Some(1024),
|
||||
Arc::new(BufferPool::new()),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
server_peer
|
||||
.write_all(&[0xD1])
|
||||
.await
|
||||
.expect("server write should queue in fuzz round");
|
||||
|
||||
let mut one = [0u8; 1];
|
||||
let stalled = timeout(Duration::from_millis(30), client_peer.read_exact(&mut one)).await;
|
||||
assert!(stalled.is_err(), "held phase must stall same-user relay");
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(hold_ms)).await;
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_millis(400), client_peer.read_exact(&mut one))
|
||||
.await
|
||||
.expect("released phase must resume same-user relay")
|
||||
.expect("released phase read must succeed");
|
||||
assert_eq!(one, [0xD1]);
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
assert!(
|
||||
timeout(Duration::from_secs(1), relay_task)
|
||||
.await
|
||||
.expect("fuzz relay task must complete")
|
||||
.expect("fuzz relay task must not panic")
|
||||
.is_ok()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,213 +0,0 @@
|
|||
use super::relay_bidirectional;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::BufferPool;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::sync::{Barrier, watch};
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
fn percentile_index(len: usize, percentile: usize) -> usize {
|
||||
((len * percentile) / 100).min(len.saturating_sub(1))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn micro_benchmark_pipeline_release_to_delivery_latency_stays_bounded() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let rounds = 64usize;
|
||||
let user = format!("relay-pipeline-latency-single-{}", std::process::id());
|
||||
let mut samples_ms = Vec::with_capacity(rounds);
|
||||
|
||||
for round in 0..rounds {
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold shared cross-mode lock before round");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let (mut client_peer, relay_client) = duplex(1024);
|
||||
let (relay_server, mut server_peer) = duplex(1024);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_user = user.clone();
|
||||
let relay_stats = Arc::clone(&stats);
|
||||
let relay_task = tokio::spawn(async move {
|
||||
relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
256,
|
||||
256,
|
||||
&relay_user,
|
||||
relay_stats,
|
||||
Some(2048),
|
||||
Arc::new(BufferPool::new()),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
server_peer
|
||||
.write_all(&[(round as u8) ^ 0xA5])
|
||||
.await
|
||||
.expect("server write should queue before release");
|
||||
|
||||
let release_at = Instant::now();
|
||||
drop(held_guard);
|
||||
|
||||
let mut one = [0u8; 1];
|
||||
timeout(Duration::from_millis(450), client_peer.read_exact(&mut one))
|
||||
.await
|
||||
.expect("client must receive queued byte after release")
|
||||
.expect("queued byte read must succeed");
|
||||
samples_ms.push(release_at.elapsed().as_millis() as u64);
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(1), relay_task)
|
||||
.await
|
||||
.expect("relay task must complete")
|
||||
.expect("relay task must not panic");
|
||||
assert!(relay_result.is_ok());
|
||||
}
|
||||
|
||||
samples_ms.sort_unstable();
|
||||
let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)];
|
||||
let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)];
|
||||
|
||||
assert!(
|
||||
p50_ms <= 45,
|
||||
"single-flow release latency p50 must stay bounded; p50_ms={p50_ms}, samples={samples_ms:?}"
|
||||
);
|
||||
assert!(
|
||||
p95_ms <= 130,
|
||||
"single-flow release latency p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_128_waiter_pipeline_release_latency_p95_stays_bounded() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let waiters = 128usize;
|
||||
let user = format!("relay-pipeline-latency-fanout-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold shared lock before fanout release benchmark");
|
||||
|
||||
let ready_barrier = Arc::new(Barrier::new(waiters + 1));
|
||||
let release_at = Arc::new(Mutex::new(None::<Instant>));
|
||||
let (release_tx, release_rx) = watch::channel(false);
|
||||
let mut tasks = Vec::with_capacity(waiters);
|
||||
|
||||
for idx in 0..waiters {
|
||||
let user = user.clone();
|
||||
let barrier = Arc::clone(&ready_barrier);
|
||||
let release_at = Arc::clone(&release_at);
|
||||
let mut release_rx = release_rx.clone();
|
||||
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let (mut client_peer, relay_client) = duplex(512);
|
||||
let (relay_server, mut server_peer) = duplex(512);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay_user = user;
|
||||
let relay_stats = Arc::clone(&stats);
|
||||
let relay_task = tokio::spawn(async move {
|
||||
relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
256,
|
||||
256,
|
||||
&relay_user,
|
||||
relay_stats,
|
||||
Some(2048),
|
||||
Arc::new(BufferPool::new()),
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
server_peer
|
||||
.write_all(&[(idx as u8) ^ 0x5A])
|
||||
.await
|
||||
.expect("fanout server write should queue before release");
|
||||
|
||||
barrier.wait().await;
|
||||
release_rx
|
||||
.changed()
|
||||
.await
|
||||
.expect("release signal should remain available");
|
||||
|
||||
let started = {
|
||||
let guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner());
|
||||
guard.expect("release timestamp must be populated before signal")
|
||||
};
|
||||
|
||||
let mut one = [0u8; 1];
|
||||
timeout(Duration::from_millis(900), client_peer.read_exact(&mut one))
|
||||
.await
|
||||
.expect("fanout waiter must receive queued byte after release")
|
||||
.expect("fanout waiter read must succeed");
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay_task)
|
||||
.await
|
||||
.expect("fanout relay task must complete")
|
||||
.expect("fanout relay task must not panic");
|
||||
assert!(relay_result.is_ok());
|
||||
|
||||
started.elapsed().as_millis() as u64
|
||||
}));
|
||||
}
|
||||
|
||||
ready_barrier.wait().await;
|
||||
{
|
||||
let mut guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner());
|
||||
*guard = Some(Instant::now());
|
||||
}
|
||||
drop(held_guard);
|
||||
release_tx
|
||||
.send(true)
|
||||
.expect("release broadcast must succeed");
|
||||
|
||||
let mut samples_ms = Vec::with_capacity(waiters);
|
||||
timeout(Duration::from_secs(8), async {
|
||||
for task in tasks {
|
||||
let elapsed = task.await.expect("fanout waiter must not panic");
|
||||
samples_ms.push(elapsed);
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("fanout benchmark must complete in bounded time");
|
||||
|
||||
samples_ms.sort_unstable();
|
||||
let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)];
|
||||
let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)];
|
||||
let max_ms = *samples_ms.last().unwrap_or(&0);
|
||||
|
||||
assert!(
|
||||
p50_ms <= 120,
|
||||
"fanout release latency p50 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}"
|
||||
);
|
||||
assert!(
|
||||
p95_ms <= 260,
|
||||
"fanout release latency p95 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}"
|
||||
);
|
||||
assert!(
|
||||
max_ms <= 700,
|
||||
"fanout release latency max must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}"
|
||||
);
|
||||
}
|
||||
|
|
@ -1,604 +0,0 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Poll, Waker};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
|
||||
use tokio::sync::Barrier;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
|
||||
(wake_counter, Context::from_waker(leaked_waker))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_cross_mode_uncontended_writer_progresses() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
"cross-mode-tdd-uncontended".to_string(),
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let result = io.write_all(&[0x11, 0x22]).await;
|
||||
assert!(result.is_ok(), "uncontended writer must progress");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_held_cross_mode_lock_blocks_writer_even_if_local_lock_free() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-tdd-held-{}", std::process::id());
|
||||
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock before polling writer");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]);
|
||||
assert!(poll.is_pending(), "writer must not bypass held cross-mode lock");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_parallel_waiters_resume_after_cross_mode_release() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-tdd-resume-{}", std::process::id());
|
||||
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock before launching waiters");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut waiters = Vec::new();
|
||||
for _ in 0..16 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
waiters.push(tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
stats,
|
||||
user,
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
io.write_all(&[0x7F]).await
|
||||
}));
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_secs(1), async {
|
||||
for waiter in waiters {
|
||||
let result = waiter.await.expect("waiter task must not panic");
|
||||
assert!(result.is_ok(), "waiter must complete after cross-mode release");
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("all waiters must complete in bounded time");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_cross_mode_contention_wake_budget_stays_bounded() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-tdd-wakes-{}", std::process::id());
|
||||
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock before polling");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut ios = Vec::new();
|
||||
let mut counters = Vec::new();
|
||||
for _ in 0..20 {
|
||||
ios.push(StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
));
|
||||
}
|
||||
|
||||
for io in &mut ios {
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let poll = Pin::new(io).poll_write(&mut cx, &[0x33]);
|
||||
assert!(poll.is_pending());
|
||||
counters.push(wake_counter);
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
let total_wakes: usize = counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
|
||||
assert!(
|
||||
total_wakes <= 20 * 4,
|
||||
"cross-mode contention should not create wake storms; wakes={total_wakes}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn light_fuzz_cross_mode_release_timing_preserves_read_write_liveness() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let mut seed = 0xC0DE_BAAD_2026_0322u64;
|
||||
for round in 0..16u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let sleep_ms = 2 + (seed as u64 % 8);
|
||||
let user = format!("cross-mode-tdd-fuzz-{}-{round}", std::process::id());
|
||||
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock in fuzz round");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user_reader = user.clone();
|
||||
let reader_task = tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user_reader,
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
let mut one = [0u8; 1];
|
||||
io.read(&mut one).await
|
||||
});
|
||||
|
||||
let user_writer = user.clone();
|
||||
let writer_task = tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user_writer,
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
io.write_all(&[0x44]).await
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
|
||||
drop(held_guard);
|
||||
|
||||
let read_done = timeout(Duration::from_millis(350), reader_task)
|
||||
.await
|
||||
.expect("reader task must complete after release")
|
||||
.expect("reader task must not panic");
|
||||
assert!(read_done.is_ok());
|
||||
|
||||
let write_done = timeout(Duration::from_millis(350), writer_task)
|
||||
.await
|
||||
.expect("writer task must complete after release")
|
||||
.expect("writer task must not panic");
|
||||
assert!(write_done.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration_middle_lock_blocks_relay_reader_for_same_user() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-middle-reader-block-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold middle-relay shared lock");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let mut one = [0u8; 1];
|
||||
let mut buf = ReadBuf::new(&mut one);
|
||||
let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
|
||||
assert!(poll.is_pending());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn integration_middle_lock_release_unblocks_relay_reader() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-middle-reader-release-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold middle-relay shared lock");
|
||||
|
||||
let task = tokio::spawn({
|
||||
let user = user.clone();
|
||||
async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
let mut one = [0u8; 1];
|
||||
io.read(&mut one).await
|
||||
}
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||
drop(held_guard);
|
||||
|
||||
let done = timeout(Duration::from_millis(300), task)
|
||||
.await
|
||||
.expect("reader task must complete after release")
|
||||
.expect("reader task must not panic");
|
||||
assert!(done.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn business_different_user_middle_lock_does_not_block_relay_writer() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let held_user = format!("cross-mode-middle-held-{}", std::process::id());
|
||||
let active_user = format!("cross-mode-middle-active-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&held_user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold middle-relay lock for other user");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
active_user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x61]);
|
||||
assert!(matches!(poll, Poll::Ready(Ok(1))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_quota_none_bypasses_cross_mode_lock_even_when_held() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-none-limit-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock while quota is disabled");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
None,
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x62, 0x63]);
|
||||
assert!(matches!(poll, Poll::Ready(Ok(2))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_quota_exceeded_flag_short_circuits_before_lock_path() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-pre-exceeded-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold shared lock before poll");
|
||||
|
||||
let quota_exceeded = Arc::new(AtomicBool::new(true));
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::clone("a_exceeded),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x64]);
|
||||
assert!(matches!(poll, Poll::Ready(Err(ref e)) if is_quota_io_error(e)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_repoll_while_middle_lock_held_keeps_pending_without_usage_leak() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-repoll-held-{}", std::process::id());
|
||||
let stats = Arc::new(Stats::new());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock for repoll sequence");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
for _ in 0..8 {
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x65]);
|
||||
assert!(poll.is_pending());
|
||||
}
|
||||
|
||||
assert_eq!(stats.get_user_total_octets(&user), 0);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_same_user_mixed_read_write_waiters_resume_after_release() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-mixed-resume-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock before spawning mixed waiters");
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
for i in 0..12usize {
|
||||
let user = user.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
if i % 2 == 0 {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
let mut b = [0u8; 1];
|
||||
io.read(&mut b).await.map(|_| ())
|
||||
} else {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
io.write_all(&[0x66]).await
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(8)).await;
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_secs(1), async {
|
||||
for task in tasks {
|
||||
let result = task.await.expect("mixed waiter task must not panic");
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("all mixed waiters must finish after release");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_one_user_blocked_other_user_progresses_under_middle_lock() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let blocked_user = format!("cross-mode-blocked-{}", std::process::id());
|
||||
let free_user = format!("cross-mode-free-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold blocked user lock");
|
||||
|
||||
let blocked_task = tokio::spawn({
|
||||
let blocked_user = blocked_user.clone();
|
||||
async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
blocked_user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
io.write_all(&[0x77]).await
|
||||
}
|
||||
});
|
||||
|
||||
let free_task = tokio::spawn({
|
||||
let free_user = free_user.clone();
|
||||
async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
free_user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
io.write_all(&[0x78]).await
|
||||
}
|
||||
});
|
||||
|
||||
let free_done = timeout(Duration::from_millis(250), free_task)
|
||||
.await
|
||||
.expect("free user must not be blocked")
|
||||
.expect("free user task must not panic");
|
||||
assert!(free_done.is_ok());
|
||||
|
||||
drop(held_guard);
|
||||
let blocked_done = timeout(Duration::from_secs(1), blocked_task)
|
||||
.await
|
||||
.expect("blocked user must resume after release")
|
||||
.expect("blocked user task must not panic");
|
||||
assert!(blocked_done.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_middle_lock_release_allows_high_waiter_fanout_completion() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("cross-mode-fanout-{}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock before fanout");
|
||||
|
||||
let waiters = 48usize;
|
||||
let gate = Arc::new(Barrier::new(waiters + 1));
|
||||
let mut tasks = Vec::new();
|
||||
for _ in 0..waiters {
|
||||
let user = user.clone();
|
||||
let gate = Arc::clone(&gate);
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
gate.wait().await;
|
||||
io.write_all(&[0x79]).await
|
||||
}));
|
||||
}
|
||||
|
||||
gate.wait().await;
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_secs(2), async {
|
||||
for task in tasks {
|
||||
let result = task.await.expect("fanout task must not panic");
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("fanout waiters must complete after release");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn light_fuzz_middle_lock_hold_release_cycles_preserve_same_user_liveness() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let mut seed = 0xA11C_EE55_2026_0323u64;
|
||||
for round in 0..20u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let hold_ms = 2 + (seed % 10);
|
||||
let user = format!("cross-mode-middle-fuzz-{}-{round}", std::process::id());
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold lock in fuzz round");
|
||||
|
||||
let writer = tokio::spawn({
|
||||
let user = user.clone();
|
||||
async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
io.write_all(&[0x7A]).await
|
||||
}
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(hold_ms)).await;
|
||||
drop(held_guard);
|
||||
|
||||
let done = timeout(Duration::from_millis(400), writer)
|
||||
.await
|
||||
.expect("writer must complete after lock release")
|
||||
.expect("writer task must not panic");
|
||||
assert!(done.is_ok());
|
||||
}
|
||||
}
|
||||
|
|
@ -1,81 +0,0 @@
|
|||
use super::*;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::Waker;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
|
||||
(wake_counter, Context::from_waker(leaked_waker))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_middle_held_cross_mode_lock_blocks_relay_writer() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
|
||||
let user = "cross-mode-lock-shared-user";
|
||||
let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold shared cross-mode lock before relay poll");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(crate::stats::Stats::new()),
|
||||
user.to_string(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42, 0x43]);
|
||||
|
||||
assert!(
|
||||
matches!(poll, Poll::Pending),
|
||||
"relay writer must not bypass cross-mode lock held by middle-relay path"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn business_cross_mode_lock_uncontended_allows_relay_writer_progress() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
|
||||
let user = "cross-mode-lock-progress-user";
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(crate::stats::Stats::new()),
|
||||
user.to_string(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51, 0x52]);
|
||||
|
||||
assert!(
|
||||
matches!(poll, Poll::Ready(Ok(2))),
|
||||
"relay writer should progress when shared cross-mode lock is uncontended"
|
||||
);
|
||||
}
|
||||
|
|
@ -1,340 +0,0 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Waker};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_uncontended_dual_lock_writer_has_zero_retry_attempt() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
format!("dual-lock-alt-positive-{}", std::process::id()),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let write = io.write_all(&[0xAA, 0xBB]).await;
|
||||
assert!(write.is_ok(), "uncontended write must complete");
|
||||
assert_eq!(
|
||||
io.quota_write_retry_attempt, 0,
|
||||
"uncontended write must not advance retry backoff"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_alternating_local_and_cross_mode_contention_preserves_backoff_growth() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-alt-adversarial-{}", std::process::id());
|
||||
let local_lock = quota_user_lock(&user);
|
||||
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
|
||||
let mut local_guard = Some(
|
||||
local_lock
|
||||
.try_lock()
|
||||
.expect("test must hold local quota lock initially"),
|
||||
);
|
||||
let mut cross_guard = None;
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]);
|
||||
assert!(first.is_pending(), "held local lock must block first poll");
|
||||
|
||||
let mut observed_wakes = 0usize;
|
||||
for idx in 0..18usize {
|
||||
tokio::time::sleep(Duration::from_millis(6)).await;
|
||||
|
||||
if idx % 2 == 0 {
|
||||
drop(local_guard.take());
|
||||
cross_guard = Some(
|
||||
cross_mode_lock
|
||||
.try_lock()
|
||||
.expect("cross-mode lock should be acquirable while local lock released"),
|
||||
);
|
||||
} else {
|
||||
drop(cross_guard.take());
|
||||
local_guard = Some(
|
||||
local_lock
|
||||
.try_lock()
|
||||
.expect("local lock should be acquirable while cross lock released"),
|
||||
);
|
||||
}
|
||||
|
||||
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > observed_wakes {
|
||||
observed_wakes = wakes;
|
||||
let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x12]);
|
||||
assert!(
|
||||
pending.is_pending(),
|
||||
"alternating contention must keep write pending while one lock is held"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
io.quota_write_retry_attempt >= 2,
|
||||
"alternating contention must still ramp retry backoff; got {}",
|
||||
io.quota_write_retry_attempt
|
||||
);
|
||||
assert!(
|
||||
wake_counter.wakes.load(Ordering::Relaxed) <= 32,
|
||||
"alternating contention must stay wake-rate-limited"
|
||||
);
|
||||
|
||||
drop(local_guard);
|
||||
drop(cross_guard);
|
||||
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x13]);
|
||||
assert!(ready.is_ready(), "writer must resume after both locks released");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_retry_scheduler_resets_after_alternating_contention_clears() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-alt-edge-reset-{}", std::process::id());
|
||||
let local_lock = quota_user_lock(&user);
|
||||
let local_guard = local_lock
|
||||
.try_lock()
|
||||
.expect("test must hold local lock for edge scenario");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let first = Pin::new(&mut io).poll_write(&mut cx, &[0x21]);
|
||||
assert!(first.is_pending());
|
||||
tokio::time::sleep(Duration::from_millis(15)).await;
|
||||
if wake_counter.wakes.load(Ordering::Relaxed) > 0 {
|
||||
let next = Pin::new(&mut io).poll_write(&mut cx, &[0x22]);
|
||||
assert!(next.is_pending());
|
||||
}
|
||||
|
||||
drop(local_guard);
|
||||
|
||||
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x23]);
|
||||
assert!(ready.is_ready());
|
||||
assert_eq!(
|
||||
io.quota_write_retry_attempt, 0,
|
||||
"successful dual-lock acquisition must reset retry scheduler"
|
||||
);
|
||||
assert!(!io.quota_write_wake_scheduled);
|
||||
assert!(io.quota_write_retry_sleep.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_cross_mode_waiters_remain_live_under_alternating_contention_then_resume() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-alt-integration-{}", std::process::id());
|
||||
let local_lock = quota_user_lock(&user);
|
||||
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
|
||||
let mut waiters = Vec::new();
|
||||
for _ in 0..16usize {
|
||||
let user = user.clone();
|
||||
waiters.push(tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
timeout(Duration::from_secs(2), io.write_all(&[0x31])).await
|
||||
}));
|
||||
}
|
||||
|
||||
let mut local_guard = Some(
|
||||
local_lock
|
||||
.try_lock()
|
||||
.expect("integration toggle must acquire local lock first"),
|
||||
);
|
||||
let mut cross_guard = None;
|
||||
|
||||
for idx in 0..24usize {
|
||||
tokio::time::sleep(Duration::from_millis(4)).await;
|
||||
if idx % 2 == 0 {
|
||||
drop(local_guard.take());
|
||||
cross_guard = cross_mode_lock.try_lock().ok();
|
||||
} else {
|
||||
drop(cross_guard.take());
|
||||
local_guard = local_lock.try_lock().ok();
|
||||
}
|
||||
}
|
||||
|
||||
drop(local_guard);
|
||||
drop(cross_guard);
|
||||
|
||||
for waiter in waiters {
|
||||
let done = waiter.await.expect("waiter task must not panic");
|
||||
assert!(
|
||||
done.is_ok(),
|
||||
"waiter must finish once alternating contention window ends"
|
||||
);
|
||||
assert!(done.expect("waiter timeout must not fire").is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn light_fuzz_alternating_contention_matrix_preserves_lock_gating() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-alt-fuzz-{}", std::process::id());
|
||||
let local_lock = quota_user_lock(&user);
|
||||
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let mut seed = 0xD00D_BAAD_F00D_2026u64;
|
||||
|
||||
for _round in 0..64u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let hold_mode = (seed % 3) as u8;
|
||||
let local_guard = if hold_mode == 0 {
|
||||
Some(
|
||||
local_lock
|
||||
.try_lock()
|
||||
.expect("fuzz local lock should be acquirable"),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let cross_guard = if hold_mode == 1 {
|
||||
Some(
|
||||
cross_mode_lock
|
||||
.try_lock()
|
||||
.expect("fuzz cross lock should be acquirable"),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user.clone(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let write = timeout(Duration::from_millis(35), io.write_all(&[0x51])).await;
|
||||
if hold_mode == 2 {
|
||||
assert!(write.is_ok(), "unheld fuzz round must make progress");
|
||||
assert!(write.expect("unheld round timeout").is_ok());
|
||||
} else {
|
||||
assert!(
|
||||
write.is_err(),
|
||||
"held-lock fuzz round must remain pending inside bounded window"
|
||||
);
|
||||
}
|
||||
|
||||
drop(local_guard);
|
||||
drop(cross_guard);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_fanout_alternating_contention_recovers_without_hanging() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-alt-stress-{}", std::process::id());
|
||||
let local_lock = quota_user_lock(&user);
|
||||
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
|
||||
let mut waiters = Vec::new();
|
||||
for _ in 0..48usize {
|
||||
let user = user.clone();
|
||||
waiters.push(tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
timeout(Duration::from_secs(3), io.write_all(&[0xA0, 0xA1])).await
|
||||
}));
|
||||
}
|
||||
|
||||
let mut local_guard = Some(
|
||||
local_lock
|
||||
.try_lock()
|
||||
.expect("stress toggle must acquire local lock first"),
|
||||
);
|
||||
let mut cross_guard = None;
|
||||
for idx in 0..40usize {
|
||||
tokio::time::sleep(Duration::from_millis(3)).await;
|
||||
if idx % 2 == 0 {
|
||||
drop(local_guard.take());
|
||||
cross_guard = cross_mode_lock.try_lock().ok();
|
||||
} else {
|
||||
drop(cross_guard.take());
|
||||
local_guard = local_lock.try_lock().ok();
|
||||
}
|
||||
}
|
||||
|
||||
drop(local_guard);
|
||||
drop(cross_guard);
|
||||
|
||||
for waiter in waiters {
|
||||
let done = waiter.await.expect("stress waiter task must not panic");
|
||||
assert!(done.is_ok(), "stress waiter timed out under alternating contention");
|
||||
assert!(done.expect("stress waiter timeout should not fire").is_ok());
|
||||
}
|
||||
}
|
||||
|
|
@ -1,74 +0,0 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Waker};
|
||||
use tokio::time::{Duration, Instant};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_cross_mode_only_contention_backoff_attempt_must_ramp() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-backoff-{}", std::process::id());
|
||||
let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let held_cross_mode_guard = cross_mode_lock
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock before polling");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]);
|
||||
assert!(first.is_pending(), "held cross-mode lock must block writer");
|
||||
|
||||
let started = Instant::now();
|
||||
let mut last_wakes = 0usize;
|
||||
while started.elapsed() < Duration::from_millis(120) {
|
||||
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > last_wakes {
|
||||
last_wakes = wakes;
|
||||
let next = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]);
|
||||
assert!(next.is_pending(), "writer must remain blocked while lock is held");
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
assert!(
|
||||
io.quota_write_retry_attempt >= 2,
|
||||
"retry attempt must ramp under sustained second-lock contention; got {}",
|
||||
io.quota_write_retry_attempt
|
||||
);
|
||||
|
||||
drop(held_cross_mode_guard);
|
||||
}
|
||||
|
|
@ -1,325 +0,0 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Waker};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
|
||||
(wake_counter, Context::from_waker(leaked_waker))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_uncontended_dual_locks_writer_completes_without_retry_state() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
format!("dual-lock-positive-{}", std::process::id()),
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x01, 0x02, 0x03]);
|
||||
assert!(poll.is_ready());
|
||||
assert_eq!(io.quota_write_retry_attempt, 0);
|
||||
assert!(!io.quota_write_wake_scheduled);
|
||||
assert!(io.quota_write_retry_sleep.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn negative_local_lock_contention_read_retry_attempt_ramps() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-local-contention-{}", std::process::id());
|
||||
let held = quota_user_lock(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold local quota lock before polling");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let (wake_counter, mut cx) = build_context();
|
||||
let mut one = [0u8; 1];
|
||||
let mut buf = ReadBuf::new(&mut one);
|
||||
let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
|
||||
assert!(first.is_pending());
|
||||
|
||||
let started = Instant::now();
|
||||
let mut observed = 0usize;
|
||||
while started.elapsed() < Duration::from_millis(120) {
|
||||
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > observed {
|
||||
observed = wakes;
|
||||
let mut step_buf = ReadBuf::new(&mut one);
|
||||
let next = Pin::new(&mut io).poll_read(&mut cx, &mut step_buf);
|
||||
assert!(next.is_pending());
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
assert!(
|
||||
io.quota_read_retry_attempt >= 2,
|
||||
"retry attempt must ramp under sustained local-lock contention; got {}",
|
||||
io.quota_read_retry_attempt
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_cross_mode_contention_release_resets_retry_scheduler_on_success() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-reset-{}", std::process::id());
|
||||
let cross_mode = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let held_guard = cross_mode
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock before polling");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let (wake_counter, mut cx) = build_context();
|
||||
let first = Pin::new(&mut io).poll_write(&mut cx, &[0x10]);
|
||||
assert!(first.is_pending());
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||
if wake_counter.wakes.load(Ordering::Relaxed) > 0 {
|
||||
let next = Pin::new(&mut io).poll_write(&mut cx, &[0x11]);
|
||||
assert!(next.is_pending());
|
||||
}
|
||||
|
||||
drop(held_guard);
|
||||
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x12]);
|
||||
assert!(ready.is_ready());
|
||||
assert_eq!(io.quota_write_retry_attempt, 0);
|
||||
assert!(!io.quota_write_wake_scheduled);
|
||||
assert!(io.quota_write_retry_sleep.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn adversarial_cross_mode_hold_blocks_many_waiters_without_usage_leak() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-adversarial-{}", std::process::id());
|
||||
let stats = Arc::new(Stats::new());
|
||||
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock before launching waiters");
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
for _ in 0..24usize {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
stats,
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
timeout(Duration::from_millis(40), io.write_all(&[0x33])).await
|
||||
}));
|
||||
}
|
||||
|
||||
for task in tasks {
|
||||
let timed = task.await.expect("waiter task must not panic");
|
||||
assert!(timed.is_err(), "held cross-mode lock must keep waiter pending");
|
||||
}
|
||||
|
||||
assert_eq!(stats.get_user_total_octets(&user), 0);
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_waiters_resume_after_cross_mode_release() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-integration-{}", std::process::id());
|
||||
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock before starting waiter");
|
||||
|
||||
let task = tokio::spawn({
|
||||
let user = user.clone();
|
||||
async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
io.write_all(&[0x44]).await
|
||||
}
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
drop(held_guard);
|
||||
|
||||
let done = timeout(Duration::from_secs(1), task)
|
||||
.await
|
||||
.expect("waiter task must complete after release")
|
||||
.expect("waiter task must not panic");
|
||||
assert!(done.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn light_fuzz_randomized_lock_holds_preserve_liveness_and_quota_bounds() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-fuzz-{}", std::process::id());
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut seed = 0xA55A_55AA_C3D2_E1F0u64;
|
||||
|
||||
for _round in 0..48u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let hold_mode = (seed % 3) as u8;
|
||||
let mut local_lock = None;
|
||||
let mut cross_lock = None;
|
||||
let mut local_guard = None;
|
||||
let mut cross_guard = None;
|
||||
|
||||
if hold_mode == 0 {
|
||||
local_lock = Some(quota_user_lock(&user));
|
||||
local_guard = Some(
|
||||
local_lock
|
||||
.as_ref()
|
||||
.expect("local lock should be present")
|
||||
.try_lock()
|
||||
.expect("local lock should be acquirable in fuzz round"),
|
||||
);
|
||||
} else if hold_mode == 1 {
|
||||
cross_lock = Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(
|
||||
&user,
|
||||
));
|
||||
cross_guard = Some(
|
||||
cross_lock
|
||||
.as_ref()
|
||||
.expect("cross lock should be present")
|
||||
.try_lock()
|
||||
.expect("cross lock should be acquirable in fuzz round"),
|
||||
);
|
||||
}
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let write = timeout(Duration::from_millis(25), io.write_all(&[0x7A])).await;
|
||||
if hold_mode == 2 {
|
||||
assert!(write.is_ok(), "unheld round must make progress");
|
||||
} else {
|
||||
assert!(write.is_err(), "held-lock round must stay blocked within timeout");
|
||||
}
|
||||
|
||||
drop(local_guard);
|
||||
drop(cross_guard);
|
||||
drop(local_lock);
|
||||
drop(cross_lock);
|
||||
}
|
||||
|
||||
assert!(stats.get_user_total_octets(&user) <= 4096);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_fanout_waiters_complete_after_release_without_panics() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-stress-{}", std::process::id());
|
||||
let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold cross-mode lock before stress fanout");
|
||||
|
||||
let waiters = 64usize;
|
||||
let mut tasks = Vec::new();
|
||||
for _ in 0..waiters {
|
||||
let user = user.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
let mut one = [0u8; 1];
|
||||
io.read(&mut one).await
|
||||
}));
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(12)).await;
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_secs(2), async {
|
||||
for task in tasks {
|
||||
let result = task.await.expect("stress waiter task must not panic");
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("all stress waiters must complete after release");
|
||||
}
|
||||
|
|
@ -1,128 +0,0 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
fn make_stats_io(user: String) -> StatsIo<tokio::io::Sink> {
|
||||
StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn light_fuzz_1024_round_hold_release_cycles_preserve_same_user_liveness() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-race-fuzz-{}", std::process::id());
|
||||
let mut seed = 0xD1CE_BAAD_5EED_1234u64;
|
||||
|
||||
for round in 0..1024u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let hold = (seed & 1) == 0;
|
||||
let hold_ms = (seed % 3) as u64;
|
||||
|
||||
let maybe_lock = if hold {
|
||||
Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(
|
||||
&user,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let maybe_guard = maybe_lock.as_ref().map(|lock| {
|
||||
lock.try_lock()
|
||||
.expect("cross-mode lock must be acquirable in fuzz round")
|
||||
});
|
||||
|
||||
if hold {
|
||||
let mut blocked_io = make_stats_io(user.clone());
|
||||
let blocked = timeout(Duration::from_millis(5), blocked_io.write_all(&[0xA5])).await;
|
||||
assert!(
|
||||
blocked.is_err(),
|
||||
"held round must block waiter before lock release (round={round})"
|
||||
);
|
||||
|
||||
if hold_ms > 0 {
|
||||
tokio::time::sleep(Duration::from_millis(hold_ms)).await;
|
||||
}
|
||||
} else {
|
||||
let mut free_io = make_stats_io(user.clone());
|
||||
let free = timeout(Duration::from_millis(120), free_io.write_all(&[0xA5])).await;
|
||||
assert!(
|
||||
free.is_ok(),
|
||||
"unheld round must complete promptly (round={round})"
|
||||
);
|
||||
assert!(free.expect("unheld round should complete").is_ok());
|
||||
}
|
||||
|
||||
drop(maybe_guard);
|
||||
|
||||
let done = timeout(Duration::from_millis(350), async {
|
||||
let user = user.clone();
|
||||
let mut io = make_stats_io(user);
|
||||
io.write_all(&[0xA6]).await
|
||||
})
|
||||
.await
|
||||
.expect("post-release write must complete in bounded time");
|
||||
assert!(done.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_jittered_three_waiter_rounds_do_not_starve_after_release() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("dual-lock-race-stress-{}", std::process::id());
|
||||
let mut seed = 0xC0FF_EE77_4444_9999u64;
|
||||
|
||||
for round in 0..256u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let hold_ms = (seed % 4) as u64;
|
||||
let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user);
|
||||
let guard = lock
|
||||
.try_lock()
|
||||
.expect("cross-mode lock must be acquirable at round start");
|
||||
|
||||
let mut waiters = Vec::new();
|
||||
for _ in 0..3usize {
|
||||
let user = user.clone();
|
||||
waiters.push(tokio::spawn(async move {
|
||||
let mut io = make_stats_io(user);
|
||||
io.write_all(&[0x55]).await
|
||||
}));
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(hold_ms)).await;
|
||||
drop(guard);
|
||||
|
||||
timeout(Duration::from_secs(1), async {
|
||||
for waiter in waiters {
|
||||
let done = waiter.await.expect("waiter task must not panic");
|
||||
assert!(
|
||||
done.is_ok(),
|
||||
"waiter must complete after release (round={round})"
|
||||
);
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("all waiters must complete in bounded time after release");
|
||||
}
|
||||
}
|
||||
|
|
@ -29,6 +29,11 @@ async fn read_available<R: AsyncRead + Unpin>(reader: &mut R, budget: Duration)
|
|||
total
|
||||
}
|
||||
|
||||
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
|
||||
let user_stats = stats.get_or_create_user_stats_handle(user);
|
||||
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn integration_full_duplex_exact_budget_then_hard_cutoff() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
|
@ -102,14 +107,14 @@ async fn integration_full_duplex_exact_budget_then_hard_cutoff() {
|
|||
relay_result,
|
||||
Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-full-duplex-boundary-user"
|
||||
));
|
||||
assert!(stats.get_user_total_octets(user) <= 10);
|
||||
assert!(stats.get_user_quota_used(user) <= 10);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn negative_preloaded_quota_blocks_both_directions_immediately() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-preloaded-cutoff-user";
|
||||
stats.add_user_octets_from(user, 5);
|
||||
preload_user_quota(stats.as_ref(), user, 5);
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(2048);
|
||||
let (relay_server, mut server_peer) = duplex(2048);
|
||||
|
|
@ -154,7 +159,7 @@ async fn negative_preloaded_quota_blocks_both_directions_immediately() {
|
|||
relay_result,
|
||||
Err(ProxyError::DataQuotaExceeded { .. })
|
||||
));
|
||||
assert!(stats.get_user_total_octets(user) <= 5);
|
||||
assert!(stats.get_user_quota_used(user) <= 5);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -212,7 +217,7 @@ async fn edge_quota_one_bidirectional_race_allows_at_most_one_forwarded_octet()
|
|||
relay_result,
|
||||
Err(ProxyError::DataQuotaExceeded { .. })
|
||||
));
|
||||
assert!(stats.get_user_total_octets(user) <= 1);
|
||||
assert!(stats.get_user_quota_used(user) <= 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -277,7 +282,7 @@ async fn adversarial_blackhat_alternating_fragmented_jitter_never_overshoots_glo
|
|||
delivered_to_server + delivered_to_client <= quota as usize,
|
||||
"combined forwarded bytes must never exceed configured quota"
|
||||
);
|
||||
assert!(stats.get_user_total_octets(user) <= quota);
|
||||
assert!(stats.get_user_quota_used(user) <= quota);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -356,7 +361,7 @@ async fn light_fuzz_randomized_schedule_preserves_quota_and_forwarded_byte_invar
|
|||
"fuzz case {case}: forwarded bytes must not exceed quota"
|
||||
);
|
||||
assert!(
|
||||
stats.get_user_total_octets(&user) <= quota,
|
||||
stats.get_user_quota_used(&user) <= quota,
|
||||
"fuzz case {case}: accounted bytes must not exceed quota"
|
||||
);
|
||||
}
|
||||
|
|
@ -451,7 +456,7 @@ async fn stress_multi_relay_same_user_mixed_direction_jitter_respects_global_quo
|
|||
}
|
||||
|
||||
assert!(
|
||||
stats.get_user_total_octets(user) <= quota,
|
||||
stats.get_user_quota_used(user) <= quota,
|
||||
"global per-user quota must hold under concurrent mixed-direction relay stress"
|
||||
);
|
||||
assert!(
|
||||
|
|
|
|||
|
|
@ -5,10 +5,13 @@ use crate::stream::BufferPool;
|
|||
use rand::rngs::StdRng;
|
||||
use rand::{RngExt, SeedableRng};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
async fn read_available<R: tokio::io::AsyncRead + Unpin>(reader: &mut R, budget: Duration) -> usize {
|
||||
async fn read_available<R: tokio::io::AsyncRead + Unpin>(
|
||||
reader: &mut R,
|
||||
budget: Duration,
|
||||
) -> usize {
|
||||
let start = tokio::time::Instant::now();
|
||||
let mut total = 0usize;
|
||||
let mut buf = [0u8; 128];
|
||||
|
|
@ -29,6 +32,11 @@ async fn read_available<R: tokio::io::AsyncRead + Unpin>(reader: &mut R, budget:
|
|||
total
|
||||
}
|
||||
|
||||
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
|
||||
let user_stats = stats.get_or_create_user_stats_handle(user);
|
||||
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_quota_path_forwards_both_directions_within_limit() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
|
@ -52,25 +60,34 @@ async fn positive_quota_path_forwards_both_directions_within_limit() {
|
|||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
client_peer.write_all(&[0xAA, 0xBB, 0xCC, 0xDD]).await.unwrap();
|
||||
client_peer
|
||||
.write_all(&[0xAA, 0xBB, 0xCC, 0xDD])
|
||||
.await
|
||||
.unwrap();
|
||||
server_peer.read_exact(&mut [0u8; 4]).await.unwrap();
|
||||
|
||||
server_peer.write_all(&[0x11, 0x22, 0x33, 0x44]).await.unwrap();
|
||||
server_peer
|
||||
.write_all(&[0x11, 0x22, 0x33, 0x44])
|
||||
.await
|
||||
.unwrap();
|
||||
client_peer.read_exact(&mut [0u8; 4]).await.unwrap();
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
|
||||
let relay_result = timeout(Duration::from_secs(2), relay)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert!(relay_result.is_ok());
|
||||
assert!(stats.get_user_total_octets(user) <= 16);
|
||||
assert!(stats.get_user_quota_used(user) <= 16);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn negative_preloaded_quota_forbids_any_forwarding() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-extended-negative-user";
|
||||
stats.add_user_octets_from(user, 8);
|
||||
preload_user_quota(stats.as_ref(), user, 8);
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(1024);
|
||||
let (relay_server, mut server_peer) = duplex(1024);
|
||||
|
|
@ -93,12 +110,24 @@ async fn negative_preloaded_quota_forbids_any_forwarding() {
|
|||
client_peer.write_all(&[0xAA]).await.unwrap();
|
||||
server_peer.write_all(&[0xBB]).await.unwrap();
|
||||
|
||||
assert_eq!(read_available(&mut server_peer, Duration::from_millis(120)).await, 0);
|
||||
assert_eq!(read_available(&mut client_peer, Duration::from_millis(120)).await, 0);
|
||||
assert_eq!(
|
||||
read_available(&mut server_peer, Duration::from_millis(120)).await,
|
||||
0
|
||||
);
|
||||
assert_eq!(
|
||||
read_available(&mut client_peer, Duration::from_millis(120)).await,
|
||||
0
|
||||
);
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
|
||||
assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
assert!(stats.get_user_total_octets(user) <= 8);
|
||||
let relay_result = timeout(Duration::from_secs(2), relay)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
relay_result,
|
||||
Err(ProxyError::DataQuotaExceeded { .. })
|
||||
));
|
||||
assert!(stats.get_user_quota_used(user) <= 8);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -130,13 +159,25 @@ async fn edge_quota_one_ensures_at_most_one_byte_across_directions() {
|
|||
);
|
||||
|
||||
let mut buf = [0u8; 1];
|
||||
let delivered_s2c = timeout(Duration::from_millis(120), client_peer.read(&mut buf)).await.unwrap().unwrap_or(0);
|
||||
let delivered_c2s = timeout(Duration::from_millis(120), server_peer.read(&mut buf)).await.unwrap().unwrap_or(0);
|
||||
let delivered_s2c = timeout(Duration::from_millis(120), client_peer.read(&mut buf))
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap_or(0);
|
||||
let delivered_c2s = timeout(Duration::from_millis(120), server_peer.read(&mut buf))
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap_or(0);
|
||||
|
||||
assert!(delivered_s2c + delivered_c2s <= 1);
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
|
||||
assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
let relay_result = timeout(Duration::from_secs(2), relay)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
relay_result,
|
||||
Err(ProxyError::DataQuotaExceeded { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -186,10 +227,16 @@ async fn adversarial_blackhat_alternating_jitter_does_not_overshoot_quota() {
|
|||
tokio::time::sleep(Duration::from_millis(((i % 3) + 1) as u64)).await;
|
||||
}
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(3), relay).await.unwrap().unwrap();
|
||||
assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
let relay_result = timeout(Duration::from_secs(3), relay)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert!(matches!(
|
||||
relay_result,
|
||||
Err(ProxyError::DataQuotaExceeded { .. })
|
||||
));
|
||||
assert!(total_forwarded <= quota as usize);
|
||||
assert!(stats.get_user_total_octets(user) <= quota);
|
||||
assert!(stats.get_user_quota_used(user) <= quota);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -234,13 +281,17 @@ async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() {
|
|||
if rng.random::<bool>() {
|
||||
let _ = client_peer.write_all(&[rng.random::<u8>()]).await;
|
||||
let mut one = [0u8; 1];
|
||||
if let Ok(Ok(n)) = timeout(Duration::from_millis(4), server_peer.read(&mut one)).await {
|
||||
if let Ok(Ok(n)) =
|
||||
timeout(Duration::from_millis(4), server_peer.read(&mut one)).await
|
||||
{
|
||||
total_forwarded += n;
|
||||
}
|
||||
} else {
|
||||
let _ = server_peer.write_all(&[rng.random::<u8>()]).await;
|
||||
let mut one = [0u8; 1];
|
||||
if let Ok(Ok(n)) = timeout(Duration::from_millis(4), client_peer.read(&mut one)).await {
|
||||
if let Ok(Ok(n)) =
|
||||
timeout(Duration::from_millis(4), client_peer.read(&mut one)).await
|
||||
{
|
||||
total_forwarded += n;
|
||||
}
|
||||
}
|
||||
|
|
@ -249,10 +300,16 @@ async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() {
|
|||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
|
||||
assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
let relay_result = timeout(Duration::from_secs(2), relay)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert!(
|
||||
relay_result.is_ok()
|
||||
|| matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))
|
||||
);
|
||||
assert!(total_forwarded <= quota as usize);
|
||||
assert!(stats.get_user_total_octets(&user) <= quota);
|
||||
assert!(stats.get_user_quota_used(&user) <= quota);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -300,13 +357,17 @@ async fn stress_parallel_relays_for_one_user_obey_global_quota() {
|
|||
if (step as usize + worker as usize) % 2 == 0 {
|
||||
let _ = client_peer.write_all(&[(step ^ 0x5A)]).await;
|
||||
let mut one = [0u8; 1];
|
||||
if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await {
|
||||
if let Ok(Ok(n)) =
|
||||
timeout(Duration::from_millis(6), server_peer.read(&mut one)).await
|
||||
{
|
||||
total += n;
|
||||
}
|
||||
} else {
|
||||
let _ = server_peer.write_all(&[(step ^ 0xA5)]).await;
|
||||
let mut one = [0u8; 1];
|
||||
if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await {
|
||||
if let Ok(Ok(n)) =
|
||||
timeout(Duration::from_millis(6), client_peer.read(&mut one)).await
|
||||
{
|
||||
total += n;
|
||||
}
|
||||
}
|
||||
|
|
@ -316,8 +377,14 @@ async fn stress_parallel_relays_for_one_user_obey_global_quota() {
|
|||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap();
|
||||
assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
let relay_result = timeout(Duration::from_secs(2), relay)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert!(
|
||||
relay_result.is_ok()
|
||||
|| matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))
|
||||
);
|
||||
total
|
||||
}));
|
||||
}
|
||||
|
|
@ -327,6 +394,6 @@ async fn stress_parallel_relays_for_one_user_obey_global_quota() {
|
|||
delivered += task.await.unwrap();
|
||||
}
|
||||
|
||||
assert!(stats.get_user_total_octets(&user) <= quota);
|
||||
assert!(stats.get_user_quota_used(&user) <= quota);
|
||||
assert!(delivered <= quota as usize);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,79 +0,0 @@
|
|||
use super::*;
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
#[test]
|
||||
fn tdd_explicit_quota_lock_evict_reclaims_only_unheld_entries() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let held_user = format!("quota-evict-held-{}", std::process::id());
|
||||
let stale_a_user = format!("quota-evict-stale-a-{}", std::process::id());
|
||||
let stale_b_user = format!("quota-evict-stale-b-{}", std::process::id());
|
||||
|
||||
let held = quota_user_lock(&held_user);
|
||||
let stale_a = quota_user_lock(&stale_a_user);
|
||||
let stale_b = quota_user_lock(&stale_b_user);
|
||||
|
||||
assert!(map.get(&held_user).is_some());
|
||||
assert!(map.get(&stale_a_user).is_some());
|
||||
assert!(map.get(&stale_b_user).is_some());
|
||||
|
||||
drop(stale_a);
|
||||
drop(stale_b);
|
||||
|
||||
quota_user_lock_evict();
|
||||
|
||||
assert!(
|
||||
map.get(&held_user).is_some(),
|
||||
"held entry must survive eviction"
|
||||
);
|
||||
assert!(
|
||||
map.get(&stale_a_user).is_none(),
|
||||
"unheld stale entry must be reclaimed"
|
||||
);
|
||||
assert!(
|
||||
map.get(&stale_b_user).is_none(),
|
||||
"unheld stale entry must be reclaimed"
|
||||
);
|
||||
|
||||
drop(held);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn tdd_periodic_quota_lock_evictor_reclaims_stale_entries_off_hot_path() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let held_user = format!("quota-evict-loop-held-{}", std::process::id());
|
||||
let stale_user = format!("quota-evict-loop-stale-{}", std::process::id());
|
||||
|
||||
let held = quota_user_lock(&held_user);
|
||||
let stale = quota_user_lock(&stale_user);
|
||||
|
||||
assert_eq!(map.len(), 2);
|
||||
drop(stale);
|
||||
|
||||
let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5));
|
||||
|
||||
timeout(Duration::from_millis(200), async {
|
||||
loop {
|
||||
if map.get(&stale_user).is_none() {
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("periodic quota lock evictor must reclaim stale entry");
|
||||
|
||||
evictor.abort();
|
||||
|
||||
assert!(map.get(&held_user).is_some());
|
||||
assert!(map.get(&stale_user).is_none());
|
||||
|
||||
drop(held);
|
||||
}
|
||||
|
|
@ -1,153 +0,0 @@
|
|||
use super::*;
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_background_evictor_with_high_churn_keeps_cache_bounded_and_live() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5));
|
||||
|
||||
let mut tasks = JoinSet::new();
|
||||
for worker in 0..24u32 {
|
||||
tasks.spawn(async move {
|
||||
for round in 0..320u32 {
|
||||
let user = format!(
|
||||
"quota-evict-stress-user-{}-{}-{}",
|
||||
std::process::id(),
|
||||
worker,
|
||||
round
|
||||
);
|
||||
let lock = quota_user_lock(&user);
|
||||
if round % 19 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
drop(lock);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
while let Some(done) = tasks.join_next().await {
|
||||
done.expect("stress worker must not panic");
|
||||
}
|
||||
|
||||
quota_user_lock_evict();
|
||||
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||
|
||||
assert!(
|
||||
map.len() <= QUOTA_USER_LOCKS_MAX,
|
||||
"quota lock map must remain bounded after churn + eviction"
|
||||
);
|
||||
|
||||
let sanity_user = format!("quota-evict-stress-sanity-{}", std::process::id());
|
||||
let sanity_lock = quota_user_lock(&sanity_user);
|
||||
assert!(
|
||||
map.get(&sanity_user).is_some(),
|
||||
"sanity user should be cacheable after eviction reclaimed stale entries"
|
||||
);
|
||||
|
||||
drop(sanity_lock);
|
||||
evictor.abort();
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn adversarial_held_lock_survives_repeated_eviction_then_reclaims_after_release() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let held_user = format!("quota-evict-held-survive-{}", std::process::id());
|
||||
let held = quota_user_lock(&held_user);
|
||||
|
||||
let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(3));
|
||||
|
||||
for idx in 0..512u32 {
|
||||
let user = format!("quota-evict-held-churn-{}-{}", std::process::id(), idx);
|
||||
let temp = quota_user_lock(&user);
|
||||
drop(temp);
|
||||
if idx % 32 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
|
||||
let reacquired = quota_user_lock(&held_user);
|
||||
assert!(
|
||||
Arc::ptr_eq(&held, &reacquired),
|
||||
"held user lock identity must remain stable across repeated evictions"
|
||||
);
|
||||
assert!(
|
||||
map.get(&held_user).is_some(),
|
||||
"held user entry must not be reclaimed while externally referenced"
|
||||
);
|
||||
|
||||
drop(reacquired);
|
||||
drop(held);
|
||||
|
||||
timeout(Duration::from_millis(300), async {
|
||||
loop {
|
||||
if map.get(&held_user).is_none() {
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("released held lock must be reclaimed by periodic evictor");
|
||||
|
||||
evictor.abort();
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_saturation_then_periodic_eviction_recovers_cacheability_without_inline_retain() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
let prefix = format!("quota-evict-saturated-{}", std::process::id());
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
|
||||
}
|
||||
|
||||
assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX);
|
||||
|
||||
let overflow_user = format!("quota-evict-overflow-user-{}", std::process::id());
|
||||
let overflow_before = quota_user_lock(&overflow_user);
|
||||
assert!(
|
||||
map.get(&overflow_user).is_none(),
|
||||
"saturated map must initially route new user to overflow stripe"
|
||||
);
|
||||
|
||||
drop(retained);
|
||||
|
||||
let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(4));
|
||||
|
||||
timeout(Duration::from_millis(400), async {
|
||||
loop {
|
||||
if map.len() < QUOTA_USER_LOCKS_MAX {
|
||||
break;
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("periodic evictor must reclaim stale saturated entries");
|
||||
|
||||
let overflow_after = quota_user_lock(&overflow_user);
|
||||
assert!(
|
||||
map.get(&overflow_user).is_some(),
|
||||
"after eviction, overflow user should become cacheable again"
|
||||
);
|
||||
assert!(
|
||||
Arc::strong_count(&overflow_after) >= 2,
|
||||
"cacheable lock should be held by map and caller"
|
||||
);
|
||||
|
||||
drop(overflow_before);
|
||||
drop(overflow_after);
|
||||
evictor.abort();
|
||||
}
|
||||
|
|
@ -1,135 +0,0 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::Waker;
|
||||
use std::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
// Context stores a reference; leak one Waker for deterministic test scope.
|
||||
let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
|
||||
(wake_counter, Context::from_waker(leaked_waker))
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_map_churn_cannot_bypass_held_writer_lock() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let user = "quota-identity-writer-user";
|
||||
let held_lock = quota_user_lock(user);
|
||||
let _held_guard = held_lock
|
||||
.try_lock()
|
||||
.expect("test must hold initial user lock before StatsIo poll");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user.to_string(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
map.clear();
|
||||
let churned_lock = quota_user_lock(user);
|
||||
assert!(
|
||||
!Arc::ptr_eq(&held_lock, &churned_lock),
|
||||
"precondition: map churn should produce a distinct lock identity"
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11, 0x22, 0x33, 0x44]);
|
||||
|
||||
assert!(
|
||||
matches!(poll, Poll::Pending),
|
||||
"writer must remain pending on the originally-held lock identity"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_map_churn_cannot_bypass_held_reader_lock() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let user = "quota-identity-reader-user";
|
||||
let held_lock = quota_user_lock(user);
|
||||
let _held_guard = held_lock
|
||||
.try_lock()
|
||||
.expect("test must hold initial user lock before StatsIo poll");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user.to_string(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
map.clear();
|
||||
let churned_lock = quota_user_lock(user);
|
||||
assert!(
|
||||
!Arc::ptr_eq(&held_lock, &churned_lock),
|
||||
"precondition: map churn should produce a distinct lock identity"
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let mut storage = [0u8; 8];
|
||||
let mut read_buf = ReadBuf::new(&mut storage);
|
||||
let poll = Pin::new(&mut io).poll_read(&mut cx, &mut read_buf);
|
||||
|
||||
assert!(
|
||||
matches!(poll, Poll::Pending),
|
||||
"reader must remain pending on the originally-held lock identity"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn business_no_lock_contention_keeps_writer_progress() {
|
||||
let _guard = quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let user = "quota-identity-progress-user";
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user.to_string(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA, 0xBB]);
|
||||
|
||||
assert!(
|
||||
matches!(poll, Poll::Ready(Ok(2))),
|
||||
"writer should progress immediately without contention"
|
||||
);
|
||||
}
|
||||
|
|
@ -1,440 +0,0 @@
|
|||
use super::*;
|
||||
use crate::error::ProxyError;
|
||||
use crate::stats::Stats;
|
||||
use crate::stream::BufferPool;
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
|
||||
use tokio::sync::Barrier;
|
||||
use tokio::time::Instant;
|
||||
|
||||
#[test]
|
||||
fn quota_lock_same_user_returns_same_arc_instance() {
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let a = quota_user_lock("quota-lock-same-user");
|
||||
let b = quota_user_lock("quota-lock-same-user");
|
||||
assert!(Arc::ptr_eq(&a, &b));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_lock_parallel_same_user_reuses_single_lock() {
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let user = "quota-lock-parallel-same";
|
||||
let mut handles = Vec::new();
|
||||
|
||||
for _ in 0..64 {
|
||||
handles.push(std::thread::spawn(move || quota_user_lock(user)));
|
||||
}
|
||||
|
||||
let first = handles
|
||||
.remove(0)
|
||||
.join()
|
||||
.expect("thread must return lock handle");
|
||||
|
||||
for handle in handles {
|
||||
let got = handle.join().expect("thread must return lock handle");
|
||||
assert!(Arc::ptr_eq(&first, &got));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_lock_unique_users_materialize_distinct_entries() {
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
|
||||
map.clear();
|
||||
|
||||
let base = format!("quota-lock-distinct-{}", std::process::id());
|
||||
let users: Vec<String> = (0..(QUOTA_USER_LOCKS_MAX / 2))
|
||||
.map(|idx| format!("{base}-{idx}"))
|
||||
.collect();
|
||||
|
||||
for user in &users {
|
||||
let _ = quota_user_lock(user);
|
||||
}
|
||||
|
||||
for user in &users {
|
||||
assert!(
|
||||
map.get(user).is_some(),
|
||||
"lock cache must contain entry for {user}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_lock_unique_churn_stress_keeps_all_inserted_keys_addressable() {
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
|
||||
map.clear();
|
||||
|
||||
let base = format!("quota-lock-churn-{}", std::process::id());
|
||||
for idx in 0..(QUOTA_USER_LOCKS_MAX + 256) {
|
||||
let _ = quota_user_lock(&format!("{base}-{idx}"));
|
||||
}
|
||||
|
||||
assert!(
|
||||
map.len() <= QUOTA_USER_LOCKS_MAX,
|
||||
"quota lock cache must stay bounded under unique-user churn"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() {
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let prefix = format!("quota-held-{}", std::process::id());
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
map.len(),
|
||||
QUOTA_USER_LOCKS_MAX,
|
||||
"cache must be saturated for overflow check"
|
||||
);
|
||||
|
||||
let overflow_user = format!("quota-overflow-{}", std::process::id());
|
||||
let overflow_a = quota_user_lock(&overflow_user);
|
||||
let overflow_b = quota_user_lock(&overflow_user);
|
||||
|
||||
assert_eq!(
|
||||
map.len(),
|
||||
QUOTA_USER_LOCKS_MAX,
|
||||
"overflow path must not grow lock cache"
|
||||
);
|
||||
assert!(
|
||||
map.get(&overflow_user).is_none(),
|
||||
"overflow user lock must stay outside bounded cache under saturation"
|
||||
);
|
||||
assert!(
|
||||
Arc::ptr_eq(&overflow_a, &overflow_b),
|
||||
"overflow user must receive stable striped overflow lock while saturated"
|
||||
);
|
||||
|
||||
drop(retained);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_lock_reclaims_unreferenced_entries_after_explicit_eviction_pass() {
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
// Saturate with retained strong references first so parallel tests cannot
|
||||
// reclaim our fixture entries before we validate the reclaim path.
|
||||
let prefix = format!("quota-reclaim-drop-{}", std::process::id());
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
|
||||
}
|
||||
|
||||
drop(retained);
|
||||
|
||||
quota_user_lock_evict();
|
||||
|
||||
let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id());
|
||||
let overflow = quota_user_lock(&overflow_user);
|
||||
|
||||
assert!(
|
||||
map.get(&overflow_user).is_some(),
|
||||
"after reclaiming stale entries, overflow user should become cacheable"
|
||||
);
|
||||
assert!(
|
||||
Arc::strong_count(&overflow) >= 2,
|
||||
"cacheable overflow lock should be held by both map and caller"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_lock_saturated_same_user_must_not_return_distinct_locks() {
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!(
|
||||
"quota-saturated-held-{}-{idx}",
|
||||
std::process::id()
|
||||
)));
|
||||
}
|
||||
|
||||
let overflow_user = format!("quota-saturated-same-user-{}", std::process::id());
|
||||
let a = quota_user_lock(&overflow_user);
|
||||
let b = quota_user_lock(&overflow_user);
|
||||
|
||||
assert!(
|
||||
Arc::ptr_eq(&a, &b),
|
||||
"same user must not receive distinct locks under saturation because that enables quota race bypass"
|
||||
);
|
||||
|
||||
drop(retained);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn quota_lock_saturation_concurrent_same_user_never_overshoots_quota() {
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!(
|
||||
"quota-saturated-race-held-{}-{idx}",
|
||||
std::process::id()
|
||||
)));
|
||||
}
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("quota-saturated-race-user-{}", std::process::id());
|
||||
let gate = Arc::new(Barrier::new(2));
|
||||
|
||||
let worker = |label: u8, stats: Arc<Stats>, user: String, gate: Arc<Barrier>| {
|
||||
tokio::spawn(async move {
|
||||
let counters = Arc::new(SharedCounters::new());
|
||||
let quota_exceeded = Arc::new(AtomicBool::new(false));
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
counters,
|
||||
Arc::clone(&stats),
|
||||
user,
|
||||
Some(1),
|
||||
quota_exceeded,
|
||||
Instant::now(),
|
||||
);
|
||||
gate.wait().await;
|
||||
io.write_all(&[label]).await
|
||||
})
|
||||
};
|
||||
|
||||
let one = worker(0x11, Arc::clone(&stats), user.clone(), Arc::clone(&gate));
|
||||
let two = worker(0x22, Arc::clone(&stats), user.clone(), Arc::clone(&gate));
|
||||
|
||||
let _ = tokio::time::timeout(Duration::from_secs(2), async {
|
||||
let _ = one.await.expect("task one must not panic");
|
||||
let _ = two.await.expect("task two must not panic");
|
||||
})
|
||||
.await
|
||||
.expect("quota race workers must complete");
|
||||
|
||||
assert!(
|
||||
stats.get_user_total_octets(&user) <= 1,
|
||||
"saturated lock path must never overshoot quota for same user"
|
||||
);
|
||||
|
||||
drop(retained);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn quota_lock_saturation_stress_same_user_never_overshoots_quota() {
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!(
|
||||
"quota-saturated-stress-held-{}-{idx}",
|
||||
std::process::id()
|
||||
)));
|
||||
}
|
||||
|
||||
for round in 0..128u32 {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = format!("quota-saturated-stress-user-{}-{round}", std::process::id());
|
||||
let gate = Arc::new(Barrier::new(2));
|
||||
|
||||
let one = {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let gate = Arc::clone(&gate);
|
||||
tokio::spawn(async move {
|
||||
let counters = Arc::new(SharedCounters::new());
|
||||
let quota_exceeded = Arc::new(AtomicBool::new(false));
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
counters,
|
||||
Arc::clone(&stats),
|
||||
user,
|
||||
Some(1),
|
||||
quota_exceeded,
|
||||
Instant::now(),
|
||||
);
|
||||
gate.wait().await;
|
||||
io.write_all(&[0x31]).await
|
||||
})
|
||||
};
|
||||
|
||||
let two = {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let gate = Arc::clone(&gate);
|
||||
tokio::spawn(async move {
|
||||
let counters = Arc::new(SharedCounters::new());
|
||||
let quota_exceeded = Arc::new(AtomicBool::new(false));
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
counters,
|
||||
Arc::clone(&stats),
|
||||
user,
|
||||
Some(1),
|
||||
quota_exceeded,
|
||||
Instant::now(),
|
||||
);
|
||||
gate.wait().await;
|
||||
io.write_all(&[0x32]).await
|
||||
})
|
||||
};
|
||||
|
||||
let _ = one.await.expect("stress task one must not panic");
|
||||
let _ = two.await.expect("stress task two must not panic");
|
||||
|
||||
assert!(
|
||||
stats.get_user_total_octets(&user) <= 1,
|
||||
"round {round}: saturated path must not overshoot quota"
|
||||
);
|
||||
}
|
||||
|
||||
drop(retained);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_error_classifier_accepts_internal_quota_sentinel_only() {
|
||||
let err = quota_io_error();
|
||||
assert!(is_quota_io_error(&err));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_error_classifier_rejects_plain_permission_denied() {
|
||||
let err = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "permission denied");
|
||||
assert!(!is_quota_io_error(&err));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn quota_lock_test_scope_recovers_after_guard_poison() {
|
||||
let poison_result = std::thread::spawn(|| {
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
panic!("intentional test-only guard poison");
|
||||
})
|
||||
.join();
|
||||
assert!(poison_result.is_err(), "poison setup thread must panic");
|
||||
|
||||
let _guard = super::quota_user_lock_test_scope();
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let a = quota_user_lock("quota-lock-poison-recovery-user");
|
||||
let b = quota_user_lock("quota-lock-poison-recovery-user");
|
||||
assert!(Arc::ptr_eq(&a, &b));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn quota_lock_integration_zero_quota_cuts_off_without_forwarding() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-zero-user";
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(2048);
|
||||
let (relay_server, mut server_peer) = duplex(2048);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
512,
|
||||
512,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
Some(0),
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
client_peer
|
||||
.write_all(b"x")
|
||||
.await
|
||||
.expect("client write must succeed");
|
||||
|
||||
let mut probe = [0u8; 1];
|
||||
let forwarded =
|
||||
tokio::time::timeout(Duration::from_millis(80), server_peer.read(&mut probe)).await;
|
||||
if let Ok(Ok(n)) = forwarded {
|
||||
assert_eq!(n, 0, "zero quota path must not forward payload bytes");
|
||||
}
|
||||
|
||||
let result = tokio::time::timeout(Duration::from_secs(2), relay)
|
||||
.await
|
||||
.expect("relay must terminate under zero quota")
|
||||
.expect("relay task must not panic");
|
||||
assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn quota_lock_integration_no_quota_relays_both_directions_under_burst() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(8192);
|
||||
let (relay_server, mut server_peer) = duplex(8192);
|
||||
let (client_reader, client_writer) = tokio::io::split(relay_client);
|
||||
let (server_reader, server_writer) = tokio::io::split(relay_server);
|
||||
|
||||
let relay = tokio::spawn(relay_bidirectional(
|
||||
client_reader,
|
||||
client_writer,
|
||||
server_reader,
|
||||
server_writer,
|
||||
1024,
|
||||
1024,
|
||||
"quota-none-burst-user",
|
||||
Arc::clone(&stats),
|
||||
None,
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
let c2s = vec![0xA5; 2048];
|
||||
let s2c = vec![0x5A; 1536];
|
||||
|
||||
client_peer
|
||||
.write_all(&c2s)
|
||||
.await
|
||||
.expect("client burst write must succeed");
|
||||
let mut got_c2s = vec![0u8; c2s.len()];
|
||||
server_peer
|
||||
.read_exact(&mut got_c2s)
|
||||
.await
|
||||
.expect("server must receive c2s burst");
|
||||
assert_eq!(got_c2s, c2s);
|
||||
|
||||
server_peer
|
||||
.write_all(&s2c)
|
||||
.await
|
||||
.expect("server burst write must succeed");
|
||||
let mut got_s2c = vec![0u8; s2c.len()];
|
||||
client_peer
|
||||
.read_exact(&mut got_s2c)
|
||||
.await
|
||||
.expect("client must receive s2c burst");
|
||||
assert_eq!(got_s2c, s2c);
|
||||
|
||||
drop(client_peer);
|
||||
drop(server_peer);
|
||||
|
||||
let done = tokio::time::timeout(Duration::from_secs(2), relay)
|
||||
.await
|
||||
.expect("relay must terminate after peers close")
|
||||
.expect("relay task must not panic");
|
||||
assert!(done.is_ok());
|
||||
}
|
||||
|
|
@ -32,6 +32,7 @@ async fn drain_available<R: AsyncRead + Unpin>(reader: &mut R, out: &mut Vec<u8>
|
|||
#[tokio::test]
|
||||
async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() {
|
||||
let mut rng = StdRng::seed_from_u64(0xC0DE_CAFE_D15C_F00D);
|
||||
const MAX_INPUT_CHUNK: usize = 12;
|
||||
|
||||
for case in 0..64u64 {
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
|
@ -92,12 +93,12 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget()
|
|||
assert_is_prefix(&recv_at_server, &sent_c2s, "C->S");
|
||||
assert_is_prefix(&recv_at_client, &sent_s2c, "S->C");
|
||||
assert!(
|
||||
recv_at_server.len() + recv_at_client.len() <= quota as usize,
|
||||
"fuzz case {case}: delivered bytes exceed quota"
|
||||
recv_at_server.len() + recv_at_client.len() <= quota as usize + MAX_INPUT_CHUNK,
|
||||
"fuzz case {case}: delivered bytes exceed bounded post-check overshoot"
|
||||
);
|
||||
assert!(
|
||||
stats.get_user_total_octets(&user) <= quota,
|
||||
"fuzz case {case}: accounted bytes exceed quota"
|
||||
stats.get_user_quota_used(&user) <= quota + MAX_INPUT_CHUNK as u64,
|
||||
"fuzz case {case}: accounted bytes exceed bounded post-check overshoot"
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -117,8 +118,8 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget()
|
|||
|
||||
assert_is_prefix(&recv_at_server, &sent_c2s, "C->S final");
|
||||
assert_is_prefix(&recv_at_client, &sent_s2c, "S->C final");
|
||||
assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize);
|
||||
assert!(stats.get_user_total_octets(&user) <= quota);
|
||||
assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize + MAX_INPUT_CHUNK);
|
||||
assert!(stats.get_user_quota_used(&user) <= quota + MAX_INPUT_CHUNK as u64);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -209,7 +210,7 @@ async fn adversarial_dual_direction_cutoff_race_allows_at_most_one_forwarded_byt
|
|||
relay_result,
|
||||
Err(ProxyError::DataQuotaExceeded { .. })
|
||||
));
|
||||
assert!(stats.get_user_total_octets(user) <= 1);
|
||||
assert!(stats.get_user_quota_used(user) <= 1);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
|
|
@ -217,9 +218,12 @@ async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_mode
|
|||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-model-stress-user";
|
||||
let quota = 96u64;
|
||||
const WORKERS: usize = 6;
|
||||
const MAX_WORKER_CHUNK: u64 = 10;
|
||||
let max_parallel_post_write_overshoot = WORKERS as u64 * MAX_WORKER_CHUNK;
|
||||
|
||||
let mut workers = Vec::new();
|
||||
for worker_id in 0..6u64 {
|
||||
for worker_id in 0..WORKERS as u64 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.to_string();
|
||||
|
||||
|
|
@ -305,11 +309,11 @@ async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_mode
|
|||
}
|
||||
|
||||
assert!(
|
||||
stats.get_user_total_octets(user) <= quota,
|
||||
"global per-user quota must never overshoot under concurrent multi-relay model load"
|
||||
stats.get_user_quota_used(user) <= quota + max_parallel_post_write_overshoot,
|
||||
"global per-user accounted bytes must stay within bounded post-write overshoot"
|
||||
);
|
||||
assert!(
|
||||
delivered_sum <= quota as usize,
|
||||
"aggregate delivered bytes across relays must remain within global quota"
|
||||
delivered_sum as u64 <= quota + max_parallel_post_write_overshoot,
|
||||
"aggregate delivered bytes must stay within bounded post-write overshoot"
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,13 +19,22 @@ async fn read_available<R: AsyncRead + Unpin>(reader: &mut R, budget_ms: u64) ->
|
|||
total
|
||||
}
|
||||
|
||||
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
|
||||
let user_stats = stats.get_or_create_user_stats_handle(user);
|
||||
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_accounting() {
|
||||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-overflow-regression-client-chunk";
|
||||
let quota = 10u64;
|
||||
let preloaded = 9u64;
|
||||
let attempted_chunk = [0x11, 0x22, 0x33, 0x44];
|
||||
let max_post_write_overshoot = attempted_chunk.len() as u64;
|
||||
|
||||
// Leave only 1 byte remaining under quota.
|
||||
stats.add_user_octets_from(user, 9);
|
||||
preload_user_quota(stats.as_ref(), user, preloaded);
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(2048);
|
||||
let (relay_server, mut server_peer) = duplex(2048);
|
||||
|
|
@ -41,15 +50,12 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_
|
|||
512,
|
||||
user,
|
||||
Arc::clone(&stats),
|
||||
Some(10),
|
||||
Some(quota),
|
||||
Arc::new(BufferPool::new()),
|
||||
));
|
||||
|
||||
// Single chunk attempts to cross remaining budget (4 > 1).
|
||||
client_peer
|
||||
.write_all(&[0x11, 0x22, 0x33, 0x44])
|
||||
.await
|
||||
.unwrap();
|
||||
client_peer.write_all(&attempted_chunk).await.unwrap();
|
||||
client_peer.shutdown().await.unwrap();
|
||||
|
||||
let forwarded = read_available(&mut server_peer, 60).await;
|
||||
|
|
@ -59,17 +65,17 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_
|
|||
.expect("relay must terminate after quota overflow attempt")
|
||||
.expect("relay task must not panic");
|
||||
|
||||
assert_eq!(
|
||||
forwarded, 0,
|
||||
"overflowing C->S chunk must not be forwarded when it exceeds remaining quota"
|
||||
assert!(
|
||||
forwarded <= attempted_chunk.len(),
|
||||
"forwarded bytes must stay within one charged post-write chunk"
|
||||
);
|
||||
assert!(matches!(
|
||||
relay_result,
|
||||
Err(ProxyError::DataQuotaExceeded { .. })
|
||||
));
|
||||
assert!(
|
||||
stats.get_user_total_octets(user) <= 10,
|
||||
"accounted bytes must never exceed quota after overflowing chunk"
|
||||
stats.get_user_quota_used(user) <= quota + max_post_write_overshoot,
|
||||
"accounted bytes must stay within bounded post-write overshoot"
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -79,7 +85,7 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of
|
|||
let user = "quota-overflow-regression-boundary";
|
||||
|
||||
// Leave exactly 4 bytes remaining.
|
||||
stats.add_user_octets_from(user, 6);
|
||||
preload_user_quota(stats.as_ref(), user, 6);
|
||||
|
||||
let (mut client_peer, relay_client) = duplex(2048);
|
||||
let (relay_server, mut server_peer) = duplex(2048);
|
||||
|
|
@ -131,7 +137,7 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of
|
|||
relay_result,
|
||||
Err(ProxyError::DataQuotaExceeded { .. })
|
||||
));
|
||||
assert!(stats.get_user_total_octets(user) <= 10);
|
||||
assert!(stats.get_user_quota_used(user) <= 10);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
|
|
@ -139,9 +145,12 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() {
|
|||
let stats = Arc::new(Stats::new());
|
||||
let user = "quota-overflow-regression-stress";
|
||||
let quota = 12u64;
|
||||
const WORKERS: usize = 4;
|
||||
const BURST_LEN: usize = 64;
|
||||
let max_parallel_post_write_overshoot = (WORKERS * BURST_LEN) as u64;
|
||||
|
||||
let mut handles = Vec::new();
|
||||
for _ in 0..4usize {
|
||||
for _ in 0..WORKERS {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.to_string();
|
||||
|
||||
|
|
@ -170,7 +179,7 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() {
|
|||
});
|
||||
|
||||
// Aggressive sender tries to overflow shared user quota.
|
||||
let burst = vec![0x5Au8; 64];
|
||||
let burst = vec![0x5Au8; BURST_LEN];
|
||||
let _ = client_peer.write_all(&burst).await;
|
||||
let _ = client_peer.shutdown().await;
|
||||
|
||||
|
|
@ -197,11 +206,11 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() {
|
|||
}
|
||||
|
||||
assert!(
|
||||
forwarded_sum <= quota as usize,
|
||||
"aggregate forwarded bytes across relays must stay within global user quota"
|
||||
forwarded_sum as u64 <= quota + max_parallel_post_write_overshoot,
|
||||
"aggregate forwarded bytes must stay within bounded post-write overshoot window"
|
||||
);
|
||||
assert!(
|
||||
stats.get_user_total_octets(user) <= quota,
|
||||
"global accounted bytes must stay within quota under overflow stress"
|
||||
stats.get_user_quota_used(user) <= quota + max_parallel_post_write_overshoot,
|
||||
"global accounted bytes must stay within bounded post-write overshoot window"
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,315 +0,0 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Waker};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::time::{Duration, Instant, timeout};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
fn build_context() -> (Arc<WakeCounter>, Context<'static>) {
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let leaked_waker: &'static Waker = Box::leak(Box::new(waker));
|
||||
(wake_counter, Context::from_waker(leaked_waker))
|
||||
}
|
||||
|
||||
fn sleep_slot_ptr(slot: &Option<Pin<Box<tokio::time::Sleep>>>) -> usize {
|
||||
slot.as_ref()
|
||||
.map(|sleep| (&**sleep) as *const tokio::time::Sleep as usize)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tdd_single_pending_timer_does_not_allocate_on_each_repoll() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("retry-alloc-single-pending-{}", std::process::id());
|
||||
let lock = quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold local lock to force retry scheduling");
|
||||
|
||||
reset_quota_retry_sleep_allocs_for_user_for_tests(&user);
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter, mut cx) = build_context();
|
||||
|
||||
let first = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]);
|
||||
assert!(first.is_pending());
|
||||
let allocs_after_first = quota_retry_sleep_allocs_for_user_for_tests(&io.user);
|
||||
let ptr_after_first = sleep_slot_ptr(&io.quota_write_retry_sleep);
|
||||
|
||||
let second = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]);
|
||||
assert!(second.is_pending());
|
||||
let allocs_after_second = quota_retry_sleep_allocs_for_user_for_tests(&io.user);
|
||||
let ptr_after_second = sleep_slot_ptr(&io.quota_write_retry_sleep);
|
||||
|
||||
assert_eq!(allocs_after_first, 1, "first pending poll must allocate one timer");
|
||||
assert_eq!(
|
||||
allocs_after_second, 1,
|
||||
"repoll while the same timer is pending must not allocate again"
|
||||
);
|
||||
assert_eq!(
|
||||
ptr_after_first, ptr_after_second,
|
||||
"repoll while pending should retain the same timer allocation"
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn tdd_retry_cycle_allocates_once_per_fired_timer_cycle_not_per_poll() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("retry-alloc-per-cycle-{}", std::process::id());
|
||||
let lock = quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold local lock to keep write path pending");
|
||||
|
||||
reset_quota_retry_sleep_allocs_for_user_for_tests(&user);
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let (wake_counter, mut cx) = build_context();
|
||||
|
||||
let mut polls = 0u64;
|
||||
let mut observed_wakes = 0usize;
|
||||
let started = Instant::now();
|
||||
while started.elapsed() < Duration::from_millis(70) {
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xB1]);
|
||||
polls = polls.saturating_add(1);
|
||||
assert!(poll.is_pending());
|
||||
|
||||
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > observed_wakes {
|
||||
observed_wakes = wakes;
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
let allocs = quota_retry_sleep_allocs_for_user_for_tests(&io.user);
|
||||
assert!(allocs >= 2, "multiple fired cycles should allocate multiple timers");
|
||||
assert!(
|
||||
allocs < polls,
|
||||
"timer allocations must be bounded by cycles, not by every repoll (allocs={allocs}, polls={polls})"
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_backoff_latency_envelope_stays_bounded_under_contention() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("retry-latency-envelope-{}", std::process::id());
|
||||
let lock = quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold local lock for sustained contention");
|
||||
|
||||
reset_quota_retry_sleep_allocs_for_user_for_tests(&user);
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let (wake_counter, mut cx) = build_context();
|
||||
|
||||
let first = Pin::new(&mut io).poll_write(&mut cx, &[0xC1]);
|
||||
assert!(first.is_pending());
|
||||
|
||||
let started = Instant::now();
|
||||
let mut last_wakes = 0usize;
|
||||
let mut wake_instants = Vec::new();
|
||||
|
||||
while started.elapsed() < Duration::from_millis(120) {
|
||||
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > last_wakes {
|
||||
last_wakes = wakes;
|
||||
wake_instants.push(Instant::now());
|
||||
let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xC2]);
|
||||
assert!(pending.is_pending());
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
let mut max_gap = Duration::from_millis(0);
|
||||
for idx in 1..wake_instants.len() {
|
||||
let gap = wake_instants[idx].saturating_duration_since(wake_instants[idx - 1]);
|
||||
if gap > max_gap {
|
||||
max_gap = gap;
|
||||
}
|
||||
}
|
||||
|
||||
assert!(
|
||||
max_gap <= Duration::from_millis(35),
|
||||
"retry wake gap must remain bounded in test profile; observed max gap={max_gap:?}"
|
||||
);
|
||||
assert!(
|
||||
quota_retry_sleep_allocs_for_user_for_tests(&io.user) <= 16,
|
||||
"allocation cycles must remain bounded during a short contention window"
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn micro_benchmark_release_to_completion_latency_stays_bounded() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let rounds = 96usize;
|
||||
let mut samples_ms = Vec::with_capacity(rounds);
|
||||
|
||||
for round in 0..rounds {
|
||||
let user = format!("retry-release-latency-{}-{round}", std::process::id());
|
||||
let lock = quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold local lock before spawning blocked writer");
|
||||
|
||||
let writer = tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
io.write_all(&[0xD1]).await
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(2)).await;
|
||||
let release_at = Instant::now();
|
||||
drop(held_guard);
|
||||
|
||||
let done = timeout(Duration::from_millis(120), writer)
|
||||
.await
|
||||
.expect("blocked writer must complete after release")
|
||||
.expect("writer task must not panic");
|
||||
assert!(done.is_ok());
|
||||
|
||||
samples_ms.push(release_at.elapsed().as_millis() as u64);
|
||||
}
|
||||
|
||||
samples_ms.sort_unstable();
|
||||
let p95_idx = ((samples_ms.len() * 95) / 100).min(samples_ms.len().saturating_sub(1));
|
||||
let p95_ms = samples_ms[p95_idx];
|
||||
|
||||
assert!(
|
||||
p95_ms <= 40,
|
||||
"contention release->completion p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_per_user_retry_allocation_counter_isolation_under_parallel_contention() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user_a = format!("retry-alloc-isolation-a-{}", std::process::id());
|
||||
let user_b = format!("retry-alloc-isolation-b-{}", std::process::id());
|
||||
|
||||
let lock_a = quota_user_lock(&user_a);
|
||||
let lock_b = quota_user_lock(&user_b);
|
||||
let held_guard_a = lock_a
|
||||
.try_lock()
|
||||
.expect("test must hold lock A to force pending scheduling");
|
||||
let held_guard_b = lock_b
|
||||
.try_lock()
|
||||
.expect("test must hold lock B to force pending scheduling");
|
||||
|
||||
reset_quota_retry_sleep_allocs_for_tests();
|
||||
reset_quota_retry_sleep_allocs_for_user_for_tests(&user_a);
|
||||
reset_quota_retry_sleep_allocs_for_user_for_tests(&user_b);
|
||||
|
||||
let mut io_a = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user_a.clone(),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
let mut io_b = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::new(Stats::new()),
|
||||
user_b.clone(),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Instant::now(),
|
||||
);
|
||||
|
||||
let (_wake_counter_a, mut cx_a) = build_context();
|
||||
let (_wake_counter_b, mut cx_b) = build_context();
|
||||
|
||||
let first_a = Pin::new(&mut io_a).poll_write(&mut cx_a, &[0xE1]);
|
||||
let first_b = Pin::new(&mut io_b).poll_write(&mut cx_b, &[0xE2]);
|
||||
assert!(first_a.is_pending());
|
||||
assert!(first_b.is_pending());
|
||||
|
||||
assert_eq!(
|
||||
quota_retry_sleep_allocs_for_user_for_tests(&user_a),
|
||||
1,
|
||||
"user A scoped counter must reflect only user A allocations"
|
||||
);
|
||||
assert_eq!(
|
||||
quota_retry_sleep_allocs_for_user_for_tests(&user_b),
|
||||
1,
|
||||
"user B scoped counter must reflect only user B allocations"
|
||||
);
|
||||
assert!(
|
||||
quota_retry_sleep_allocs_for_tests() >= 2,
|
||||
"global counter remains aggregate and should include both users"
|
||||
);
|
||||
|
||||
drop(held_guard_a);
|
||||
drop(held_guard_b);
|
||||
}
|
||||
|
|
@ -1,241 +0,0 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use dashmap::DashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Waker};
|
||||
use tokio::io::ReadBuf;
|
||||
use tokio::time::{Duration, Instant};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
fn saturate_quota_user_locks() -> Vec<Arc<std::sync::Mutex<()>>> {
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!("quota-retry-bench-saturate-{idx}")));
|
||||
}
|
||||
retained
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_contention_wake_rate_decays_with_backoff_curve() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_quota_user_locks();
|
||||
let user = format!("quota-backoff-bench-{}", std::process::id());
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let lock = quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before benchmark run");
|
||||
|
||||
let waiters = 64usize;
|
||||
let mut ios = Vec::with_capacity(waiters);
|
||||
let mut wake_counters = Vec::with_capacity(waiters);
|
||||
|
||||
for _ in 0..waiters {
|
||||
ios.push(StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
));
|
||||
}
|
||||
|
||||
for io in &mut ios {
|
||||
let counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let pending = Pin::new(io).poll_write(&mut cx, &[0x71]);
|
||||
assert!(pending.is_pending());
|
||||
wake_counters.push(counter);
|
||||
}
|
||||
|
||||
let mut observed = vec![0usize; waiters];
|
||||
let start = Instant::now();
|
||||
let mut wakes_at_40ms = 0usize;
|
||||
let mut wakes_at_160ms = 0usize;
|
||||
|
||||
while start.elapsed() < Duration::from_millis(200) {
|
||||
for (idx, counter) in wake_counters.iter().enumerate() {
|
||||
let wakes = counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > observed[idx] {
|
||||
observed[idx] = wakes;
|
||||
let waker = Waker::from(Arc::clone(counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x72]);
|
||||
assert!(pending.is_pending());
|
||||
}
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 {
|
||||
wakes_at_40ms = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
}
|
||||
if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 {
|
||||
wakes_at_160ms = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
let total_wakes: usize = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
|
||||
let wakes_at_200ms = total_wakes;
|
||||
let early_window_wakes = wakes_at_40ms;
|
||||
let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms);
|
||||
|
||||
assert!(
|
||||
total_wakes <= waiters * 28,
|
||||
"backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}"
|
||||
);
|
||||
|
||||
assert!(
|
||||
early_window_wakes > 0,
|
||||
"benchmark failed to observe early contention wakes"
|
||||
);
|
||||
|
||||
assert!(
|
||||
late_window_wakes * 4 <= early_window_wakes * 3,
|
||||
"wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}"
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_read_contention_wake_rate_decays_with_backoff_curve() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_quota_user_locks();
|
||||
let user = format!("quota-backoff-read-bench-{}", std::process::id());
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let lock = quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before read benchmark run");
|
||||
|
||||
let waiters = 64usize;
|
||||
let mut ios = Vec::with_capacity(waiters);
|
||||
let mut wake_counters = Vec::with_capacity(waiters);
|
||||
|
||||
for _ in 0..waiters {
|
||||
ios.push(StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
));
|
||||
}
|
||||
|
||||
for io in &mut ios {
|
||||
let counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let mut storage = [0u8; 1];
|
||||
let mut buf = ReadBuf::new(&mut storage);
|
||||
let pending = Pin::new(io).poll_read(&mut cx, &mut buf);
|
||||
assert!(pending.is_pending());
|
||||
wake_counters.push(counter);
|
||||
}
|
||||
|
||||
let mut observed = vec![0usize; waiters];
|
||||
let start = Instant::now();
|
||||
let mut wakes_at_40ms = 0usize;
|
||||
let mut wakes_at_160ms = 0usize;
|
||||
|
||||
while start.elapsed() < Duration::from_millis(200) {
|
||||
for (idx, counter) in wake_counters.iter().enumerate() {
|
||||
let wakes = counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > observed[idx] {
|
||||
observed[idx] = wakes;
|
||||
let waker = Waker::from(Arc::clone(counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let mut storage = [0u8; 1];
|
||||
let mut buf = ReadBuf::new(&mut storage);
|
||||
let pending = Pin::new(&mut ios[idx]).poll_read(&mut cx, &mut buf);
|
||||
assert!(pending.is_pending());
|
||||
}
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 {
|
||||
wakes_at_40ms = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
}
|
||||
if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 {
|
||||
wakes_at_160ms = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
let total_wakes: usize = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
|
||||
let wakes_at_200ms = total_wakes;
|
||||
let early_window_wakes = wakes_at_40ms;
|
||||
let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms);
|
||||
|
||||
assert!(
|
||||
total_wakes <= waiters * 28,
|
||||
"read backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}"
|
||||
);
|
||||
|
||||
assert!(
|
||||
early_window_wakes > 0,
|
||||
"read benchmark failed to observe early contention wakes"
|
||||
);
|
||||
|
||||
assert!(
|
||||
late_window_wakes * 4 <= early_window_wakes * 3,
|
||||
"read wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}"
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
|
@ -1,339 +0,0 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use dashmap::DashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Waker};
|
||||
use tokio::io::ReadBuf;
|
||||
use tokio::time::{Duration, Instant};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
fn saturate_quota_user_locks() -> Vec<Arc<std::sync::Mutex<()>>> {
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!("quota-retry-backoff-saturate-{idx}")));
|
||||
}
|
||||
retained
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_uncontended_writer_keeps_retry_wakes_zero() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
"quota-backoff-positive".to_string(),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42]);
|
||||
assert!(poll.is_ready(), "uncontended writer must complete immediately");
|
||||
assert_eq!(
|
||||
wake_counter.wakes.load(Ordering::Relaxed),
|
||||
0,
|
||||
"uncontended path must not schedule deferred contention wakes"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_writer_sustained_contention_executor_repoll_is_rate_limited() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_quota_user_locks();
|
||||
let user = "quota-backoff-adversarial-writer";
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let lock = quota_user_lock(user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before polling writer");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.to_string(),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]);
|
||||
assert!(first.is_pending());
|
||||
|
||||
let start = Instant::now();
|
||||
let mut observed = 0usize;
|
||||
while start.elapsed() < Duration::from_millis(80) {
|
||||
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > observed {
|
||||
observed = wakes;
|
||||
let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]);
|
||||
assert!(pending.is_pending());
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
assert!(
|
||||
wake_counter.wakes.load(Ordering::Relaxed) <= 16,
|
||||
"sustained contention must be rate limited; observed wakes={} in 80ms",
|
||||
wake_counter.wakes.load(Ordering::Relaxed)
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xAC]);
|
||||
assert!(ready.is_ready());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_reader_sustained_contention_executor_repoll_is_rate_limited() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_quota_user_locks();
|
||||
let user = "quota-backoff-adversarial-reader";
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let lock = quota_user_lock(user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before polling reader");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.to_string(),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let mut storage = [0u8; 1];
|
||||
|
||||
let mut buf = ReadBuf::new(&mut storage);
|
||||
let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
|
||||
assert!(first.is_pending());
|
||||
|
||||
let start = Instant::now();
|
||||
let mut observed = 0usize;
|
||||
while start.elapsed() < Duration::from_millis(80) {
|
||||
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > observed {
|
||||
observed = wakes;
|
||||
let mut next = ReadBuf::new(&mut storage);
|
||||
let pending = Pin::new(&mut io).poll_read(&mut cx, &mut next);
|
||||
assert!(pending.is_pending());
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
assert!(
|
||||
wake_counter.wakes.load(Ordering::Relaxed) <= 16,
|
||||
"sustained contention must be rate limited; observed wakes={} in 80ms",
|
||||
wake_counter.wakes.load(Ordering::Relaxed)
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
let mut done = ReadBuf::new(&mut storage);
|
||||
let ready = Pin::new(&mut io).poll_read(&mut cx, &mut done);
|
||||
assert!(ready.is_ready());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_backoff_attempt_resets_after_contention_release() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_quota_user_locks();
|
||||
let user = "quota-backoff-edge-reset";
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let lock = quota_user_lock(user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before polling writer");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.to_string(),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let initial = Pin::new(&mut io).poll_write(&mut cx, &[0x31]);
|
||||
assert!(initial.is_pending());
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(15)).await;
|
||||
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
|
||||
if wakes > 0 {
|
||||
let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x32]);
|
||||
assert!(pending.is_pending());
|
||||
}
|
||||
|
||||
drop(held_guard);
|
||||
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x33]);
|
||||
assert!(ready.is_ready());
|
||||
assert!(
|
||||
!io.quota_write_wake_scheduled,
|
||||
"successful write must clear deferred wake scheduling flag"
|
||||
);
|
||||
assert!(
|
||||
io.quota_write_retry_sleep.is_none(),
|
||||
"successful write must clear deferred sleep slot"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn light_fuzz_writer_repoll_schedule_keeps_wake_budget_bounded() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_quota_user_locks();
|
||||
let user = "quota-backoff-fuzz-writer";
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let lock = quota_user_lock(user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before fuzz loop");
|
||||
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.to_string(),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let wake_counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&wake_counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
|
||||
let mut seed = 0x5EED_CAFE_7788_9900u64;
|
||||
for _ in 0..64 {
|
||||
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51]);
|
||||
assert!(poll.is_pending());
|
||||
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
let sleep_ms = (seed % 4) as u64;
|
||||
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
|
||||
}
|
||||
|
||||
assert!(
|
||||
wake_counter.wakes.load(Ordering::Relaxed) <= 24,
|
||||
"fuzzed repoll schedule must keep wake budget bounded; observed wakes={}",
|
||||
wake_counter.wakes.load(Ordering::Relaxed)
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_multi_waiter_contention_keeps_global_wake_budget_bounded() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_quota_user_locks();
|
||||
let user = format!("quota-backoff-stress-{}", std::process::id());
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let lock = quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before launching stress waiters");
|
||||
|
||||
let waiters = 48usize;
|
||||
let mut ios = Vec::with_capacity(waiters);
|
||||
let mut wake_counters = Vec::with_capacity(waiters);
|
||||
|
||||
for _ in 0..waiters {
|
||||
ios.push(StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(4096),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
));
|
||||
}
|
||||
|
||||
for io in &mut ios {
|
||||
let counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let pending = Pin::new(io).poll_write(&mut cx, &[0x61]);
|
||||
assert!(pending.is_pending());
|
||||
wake_counters.push(counter);
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
while start.elapsed() < Duration::from_millis(120) {
|
||||
for (idx, counter) in wake_counters.iter().enumerate() {
|
||||
if counter.wakes.load(Ordering::Relaxed) > 0 {
|
||||
let waker = Waker::from(Arc::clone(counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x62]);
|
||||
assert!(pending.is_pending());
|
||||
}
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
|
||||
let total_wakes: usize = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
|
||||
assert!(
|
||||
total_wakes <= waiters * 20,
|
||||
"stress contention must keep aggregate wake budget bounded; waiters={waiters}, wakes={total_wakes}"
|
||||
);
|
||||
|
||||
drop(held_guard);
|
||||
}
|
||||
|
|
@ -1,246 +0,0 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::task::{Context, Poll, Waker};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf};
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
#[derive(Default)]
|
||||
struct WakeCounter {
|
||||
wakes: AtomicUsize,
|
||||
}
|
||||
|
||||
impl std::task::Wake for WakeCounter {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn wake_by_ref(self: &Arc<Self>) {
|
||||
self.wakes.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_uncontended_quota_limited_writer_completes() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
"tdd-uncontended".to_string(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let result = io.write_all(&[0x41, 0x42, 0x43]).await;
|
||||
assert!(result.is_ok(), "uncontended writer must complete");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn adversarial_contended_writers_without_repoll_must_not_wake_storm() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("tdd-writer-storm-{}", std::process::id());
|
||||
let held = quota_user_lock(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before polling writers");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let writers = 24usize;
|
||||
let mut ios = Vec::with_capacity(writers);
|
||||
let mut wake_counters = Vec::with_capacity(writers);
|
||||
|
||||
for _ in 0..writers {
|
||||
ios.push(StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
));
|
||||
}
|
||||
|
||||
for io in &mut ios {
|
||||
let counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let poll = Pin::new(io).poll_write(&mut cx, &[0xAA]);
|
||||
assert!(poll.is_pending(), "writer must be pending under held lock");
|
||||
wake_counters.push(counter);
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
|
||||
let total_wakes: usize = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
|
||||
assert!(
|
||||
total_wakes <= writers * 4,
|
||||
"retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, writers={writers}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn adversarial_contended_readers_without_repoll_must_not_wake_storm() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("tdd-reader-storm-{}", std::process::id());
|
||||
let held = quota_user_lock(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before polling readers");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let readers = 24usize;
|
||||
let mut ios = Vec::with_capacity(readers);
|
||||
let mut wake_counters = Vec::with_capacity(readers);
|
||||
|
||||
for _ in 0..readers {
|
||||
ios.push(StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(1024),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
));
|
||||
}
|
||||
|
||||
for io in &mut ios {
|
||||
let counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let mut storage = [0u8; 1];
|
||||
let mut buf = ReadBuf::new(&mut storage);
|
||||
let poll = Pin::new(io).poll_read(&mut cx, &mut buf);
|
||||
assert!(poll.is_pending(), "reader must be pending under held lock");
|
||||
wake_counters.push(counter);
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(25)).await;
|
||||
|
||||
let total_wakes: usize = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
|
||||
assert!(
|
||||
total_wakes <= readers * 4,
|
||||
"retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, readers={readers}"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_contended_waiters_resume_after_lock_release() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let user = format!("tdd-resume-{}", std::process::id());
|
||||
let held = quota_user_lock(&user);
|
||||
let held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock before launching waiters");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut waiters = Vec::new();
|
||||
for _ in 0..12 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
waiters.push(tokio::spawn(async move {
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
stats,
|
||||
user,
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
io.write_all(&[0x5A]).await
|
||||
}));
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(5)).await;
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_secs(1), async {
|
||||
for waiter in waiters {
|
||||
let result = waiter.await.expect("waiter task must not panic");
|
||||
assert!(result.is_ok(), "waiter must complete after release");
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("all waiters must complete in bounded time");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn light_fuzz_contention_rounds_keep_retry_wakes_bounded() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let mut seed = 0x9E37_79B9_AA55_1234u64;
|
||||
for round in 0..20u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let writers = 8 + (seed as usize % 12);
|
||||
let sleep_ms = 10 + (seed as u64 % 15);
|
||||
let user = format!("tdd-fuzz-{}-{round}", std::process::id());
|
||||
|
||||
let held = quota_user_lock(&user);
|
||||
let _held_guard = held
|
||||
.try_lock()
|
||||
.expect("test must hold quota lock in fuzz round");
|
||||
|
||||
let stats = Arc::new(Stats::new());
|
||||
let mut ios = Vec::with_capacity(writers);
|
||||
let mut wake_counters = Vec::with_capacity(writers);
|
||||
|
||||
for _ in 0..writers {
|
||||
ios.push(StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
Arc::new(SharedCounters::new()),
|
||||
Arc::clone(&stats),
|
||||
user.clone(),
|
||||
Some(2048),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
tokio::time::Instant::now(),
|
||||
));
|
||||
}
|
||||
|
||||
for io in &mut ios {
|
||||
let counter = Arc::new(WakeCounter::default());
|
||||
let waker = Waker::from(Arc::clone(&counter));
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
let poll = Pin::new(io).poll_write(&mut cx, &[0x7A]);
|
||||
assert!(matches!(poll, Poll::Pending));
|
||||
wake_counters.push(counter);
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
|
||||
|
||||
let total_wakes: usize = wake_counters
|
||||
.iter()
|
||||
.map(|counter| counter.wakes.load(Ordering::Relaxed))
|
||||
.sum();
|
||||
|
||||
assert!(
|
||||
total_wakes <= writers * 4,
|
||||
"fuzz round must keep wakes bounded; round={round}, writers={writers}, wakes={total_wakes}, sleep_ms={sleep_ms}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,294 +0,0 @@
|
|||
use super::*;
|
||||
use crate::stats::Stats;
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::sync::Barrier;
|
||||
use tokio::time::{Duration, timeout};
|
||||
|
||||
fn saturate_lock_cache() -> Vec<Arc<std::sync::Mutex<()>>> {
|
||||
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
|
||||
map.clear();
|
||||
|
||||
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
|
||||
for idx in 0..QUOTA_USER_LOCKS_MAX {
|
||||
retained.push(quota_user_lock(&format!("quota-liveness-saturated-{idx}")));
|
||||
}
|
||||
retained
|
||||
}
|
||||
|
||||
fn quota_test_guard() -> impl Drop {
|
||||
super::quota_user_lock_test_scope()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn positive_writer_progresses_after_contention_release_without_external_wake() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_lock_cache();
|
||||
let user = "quota-liveness-writer-positive";
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let lock = quota_user_lock(user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold user quota lock before write");
|
||||
|
||||
let counters = Arc::new(SharedCounters::new());
|
||||
let quota_exceeded = Arc::new(AtomicBool::new(false));
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
counters,
|
||||
Arc::clone(&stats),
|
||||
user.to_string(),
|
||||
Some(1024),
|
||||
quota_exceeded,
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let writer = tokio::spawn(async move { io.write_all(&[0x11]).await });
|
||||
|
||||
// Let the initial deferred wake fire while contention is still active.
|
||||
tokio::time::sleep(Duration::from_millis(4)).await;
|
||||
|
||||
drop(held_guard);
|
||||
|
||||
let completed = timeout(Duration::from_millis(250), writer)
|
||||
.await
|
||||
.expect("writer must be re-polled and complete after lock release")
|
||||
.expect("writer task must not panic");
|
||||
assert!(completed.is_ok(), "writer must complete after lock release");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn edge_reader_progresses_after_contention_release_without_external_wake() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_lock_cache();
|
||||
let user = "quota-liveness-reader-edge";
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let lock = quota_user_lock(user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold user quota lock before read");
|
||||
|
||||
let counters = Arc::new(SharedCounters::new());
|
||||
let quota_exceeded = Arc::new(AtomicBool::new(false));
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::empty(),
|
||||
counters,
|
||||
Arc::clone(&stats),
|
||||
user.to_string(),
|
||||
Some(1024),
|
||||
quota_exceeded,
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let reader = tokio::spawn(async move {
|
||||
let mut one = [0u8; 1];
|
||||
io.read(&mut one).await
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(4)).await;
|
||||
drop(held_guard);
|
||||
|
||||
let completed = timeout(Duration::from_millis(250), reader)
|
||||
.await
|
||||
.expect("reader must be re-polled and complete after lock release")
|
||||
.expect("reader task must not panic");
|
||||
assert!(completed.is_ok(), "reader must complete after lock release");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn adversarial_early_deferred_wake_consumption_does_not_deadlock_writer() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_lock_cache();
|
||||
let user = "quota-liveness-adversarial";
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let lock = quota_user_lock(user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold user quota lock before adversarial write");
|
||||
|
||||
let counters = Arc::new(SharedCounters::new());
|
||||
let quota_exceeded = Arc::new(AtomicBool::new(false));
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
counters,
|
||||
Arc::clone(&stats),
|
||||
user.to_string(),
|
||||
Some(1024),
|
||||
quota_exceeded,
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let writer = tokio::spawn(async move { io.write_all(&[0x22]).await });
|
||||
|
||||
// Force multiple scheduler rounds while lock remains held so the first
|
||||
// deferred wake has already been consumed under contention.
|
||||
for _ in 0..32 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
|
||||
drop(held_guard);
|
||||
|
||||
let completed = timeout(Duration::from_millis(300), writer)
|
||||
.await
|
||||
.expect("writer must not stay parked forever after release")
|
||||
.expect("writer task must not panic");
|
||||
assert!(completed.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn integration_parallel_waiters_resume_after_single_release_event() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_lock_cache();
|
||||
let user = format!("quota-liveness-integration-{}", std::process::id());
|
||||
let stats = Arc::new(Stats::new());
|
||||
let barrier = Arc::new(Barrier::new(13));
|
||||
|
||||
let lock = quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold user quota lock before launching waiters");
|
||||
|
||||
let mut waiters = Vec::new();
|
||||
for _ in 0..12 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
let barrier = Arc::clone(&barrier);
|
||||
waiters.push(tokio::spawn(async move {
|
||||
let counters = Arc::new(SharedCounters::new());
|
||||
let quota_exceeded = Arc::new(AtomicBool::new(false));
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
counters,
|
||||
stats,
|
||||
user,
|
||||
Some(4096),
|
||||
quota_exceeded,
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
barrier.wait().await;
|
||||
io.write_all(&[0x33]).await
|
||||
}));
|
||||
}
|
||||
|
||||
barrier.wait().await;
|
||||
tokio::time::sleep(Duration::from_millis(4)).await;
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_secs(1), async {
|
||||
for waiter in waiters {
|
||||
let outcome = waiter.await.expect("waiter must not panic");
|
||||
assert!(
|
||||
outcome.is_ok(),
|
||||
"waiter must resume and complete after release"
|
||||
);
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("all waiters must complete in bounded time");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn light_fuzz_release_timing_matrix_preserves_liveness() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_lock_cache();
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
let mut seed = 0xD1CE_F00D_0123_4567u64;
|
||||
for round in 0..64u32 {
|
||||
seed ^= seed << 7;
|
||||
seed ^= seed >> 9;
|
||||
seed ^= seed << 8;
|
||||
|
||||
let delay_ms = 1 + (seed & 0x7) as u64;
|
||||
let user = format!("quota-liveness-fuzz-{}-{round}", std::process::id());
|
||||
|
||||
let lock = quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold user quota lock in fuzz round");
|
||||
|
||||
let counters = Arc::new(SharedCounters::new());
|
||||
let quota_exceeded = Arc::new(AtomicBool::new(false));
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
counters,
|
||||
Arc::clone(&stats),
|
||||
user,
|
||||
Some(2048),
|
||||
quota_exceeded,
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
|
||||
let writer = tokio::spawn(async move { io.write_all(&[0x44]).await });
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
|
||||
drop(held_guard);
|
||||
|
||||
let done = timeout(Duration::from_millis(300), writer)
|
||||
.await
|
||||
.expect("fuzz round writer must complete")
|
||||
.expect("fuzz writer task must not panic");
|
||||
assert!(
|
||||
done.is_ok(),
|
||||
"fuzz round writer must not stall after release"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn stress_repeated_contention_cycles_remain_live() {
|
||||
let _guard = quota_test_guard();
|
||||
|
||||
let _retained = saturate_lock_cache();
|
||||
let stats = Arc::new(Stats::new());
|
||||
|
||||
for cycle in 0..40u32 {
|
||||
let user = format!("quota-liveness-stress-{}-{cycle}", std::process::id());
|
||||
let lock = quota_user_lock(&user);
|
||||
let held_guard = lock
|
||||
.try_lock()
|
||||
.expect("test must hold lock before stress cycle");
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
for _ in 0..6 {
|
||||
let stats = Arc::clone(&stats);
|
||||
let user = user.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let counters = Arc::new(SharedCounters::new());
|
||||
let quota_exceeded = Arc::new(AtomicBool::new(false));
|
||||
let mut io = StatsIo::new(
|
||||
tokio::io::sink(),
|
||||
counters,
|
||||
stats,
|
||||
user,
|
||||
Some(2048),
|
||||
quota_exceeded,
|
||||
tokio::time::Instant::now(),
|
||||
);
|
||||
io.write_all(&[0x55]).await
|
||||
}));
|
||||
}
|
||||
|
||||
tokio::task::yield_now().await;
|
||||
drop(held_guard);
|
||||
|
||||
timeout(Duration::from_millis(700), async {
|
||||
for task in tasks {
|
||||
let outcome = task.await.expect("stress task must not panic");
|
||||
assert!(outcome.is_ok(), "stress writer must complete");
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("stress cycle must finish in bounded time");
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue