diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..1b6e455 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,39 @@ +name: Build + +on: + push: + branches: [ "*" ] + pull_request: + branches: [ "*" ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + name: Build + runs-on: ubuntu-latest + + permissions: + contents: read + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Install latest stable Rust toolchain + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo registry & build artifacts + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- + + - name: Build Release + run: cargo build --release --verbose \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1012793..d01293e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -26,6 +26,9 @@ jobs: name: GNU ${{ matrix.target }} runs-on: ubuntu-latest + container: + image: rust:slim-bookworm + strategy: fail-fast: false matrix: @@ -47,8 +50,8 @@ jobs: - name: Install deps run: | - sudo apt-get update - sudo apt-get install -y \ + apt-get update + apt-get install -y \ build-essential \ clang \ lld \ @@ -69,14 +72,10 @@ jobs: if [ "${{ matrix.target }}" = "aarch64-unknown-linux-gnu" ]; then export CC=aarch64-linux-gnu-gcc export CXX=aarch64-linux-gnu-g++ - export CC_aarch64_unknown_linux_gnu=aarch64-linux-gnu-gcc - export CXX_aarch64_unknown_linux_gnu=aarch64-linux-gnu-g++ export RUSTFLAGS="-C linker=aarch64-linux-gnu-gcc" else export CC=clang export CXX=clang++ - export CC_x86_64_unknown_linux_gnu=clang - export CXX_x86_64_unknown_linux_gnu=clang++ export RUSTFLAGS="-C linker=clang -C link-arg=-fuse-ld=lld" fi @@ -85,20 +84,19 @@ jobs: - name: Package run: | mkdir -p dist - BIN=target/${{ matrix.target }}/release/${{ env.BINARY_NAME }} - - cp "$BIN" dist/${{ env.BINARY_NAME }}-${{ matrix.target }} + cp target/${{ matrix.target }}/release/${{ env.BINARY_NAME }} dist/telemt cd dist - tar -czf ${{ matrix.asset }}.tar.gz ${{ env.BINARY_NAME }}-${{ matrix.target }} + tar -czf ${{ matrix.asset }}.tar.gz \ + --owner=0 --group=0 --numeric-owner \ + telemt + sha256sum ${{ matrix.asset }}.tar.gz > ${{ matrix.asset }}.sha256 - uses: actions/upload-artifact@v4 with: name: ${{ matrix.asset }} - path: | - dist/${{ matrix.asset }}.tar.gz - dist/${{ matrix.asset }}.sha256 + path: dist/* # ========================== # MUSL @@ -125,43 +123,7 @@ jobs: - name: Install deps run: | apt-get update - apt-get install -y \ - musl-tools \ - pkg-config \ - curl - - - uses: actions/cache@v4 - if: matrix.target == 'aarch64-unknown-linux-musl' - with: - path: ~/.musl-aarch64 - key: musl-toolchain-aarch64-v1 - - - name: Install aarch64 musl toolchain - if: matrix.target == 'aarch64-unknown-linux-musl' - run: | - set -e - - TOOLCHAIN_DIR="$HOME/.musl-aarch64" - ARCHIVE="aarch64-linux-musl-cross.tgz" - URL="https://github.com/telemt/telemt/releases/download/toolchains/$ARCHIVE" - - if [ -x "$TOOLCHAIN_DIR/bin/aarch64-linux-musl-gcc" ]; then - echo "✅ MUSL toolchain already installed" - else - echo "⬇️ Downloading musl toolchain from Telemt GitHub Releases..." - - curl -fL \ - --retry 5 \ - --retry-delay 3 \ - --connect-timeout 10 \ - --max-time 120 \ - -o "$ARCHIVE" "$URL" - - mkdir -p "$TOOLCHAIN_DIR" - tar -xzf "$ARCHIVE" --strip-components=1 -C "$TOOLCHAIN_DIR" - fi - - echo "$TOOLCHAIN_DIR/bin" >> $GITHUB_PATH + apt-get install -y musl-tools pkg-config curl - name: Add rust target run: rustup target add ${{ matrix.target }} @@ -178,11 +140,9 @@ jobs: run: | if [ "${{ matrix.target }}" = "aarch64-unknown-linux-musl" ]; then export CC=aarch64-linux-musl-gcc - export CC_aarch64_unknown_linux_musl=aarch64-linux-musl-gcc export RUSTFLAGS="-C target-feature=+crt-static -C linker=aarch64-linux-musl-gcc" else export CC=musl-gcc - export CC_x86_64_unknown_linux_musl=musl-gcc export RUSTFLAGS="-C target-feature=+crt-static" fi @@ -191,119 +151,93 @@ jobs: - name: Package run: | mkdir -p dist - BIN=target/${{ matrix.target }}/release/${{ env.BINARY_NAME }} - - cp "$BIN" dist/${{ env.BINARY_NAME }}-${{ matrix.target }} + cp target/${{ matrix.target }}/release/${{ env.BINARY_NAME }} dist/telemt cd dist - tar -czf ${{ matrix.asset }}.tar.gz ${{ env.BINARY_NAME }}-${{ matrix.target }} + tar -czf ${{ matrix.asset }}.tar.gz \ + --owner=0 --group=0 --numeric-owner \ + telemt + sha256sum ${{ matrix.asset }}.tar.gz > ${{ matrix.asset }}.sha256 - uses: actions/upload-artifact@v4 with: name: ${{ matrix.asset }} - path: | - dist/${{ matrix.asset }}.tar.gz - dist/${{ matrix.asset }}.sha256 + path: dist/* # ========================== -# OpenBSD +# Release # ========================== - build-openbsd: - name: OpenBSD ${{ matrix.arch }} - runs-on: ubuntu-latest - - strategy: - fail-fast: false - matrix: - include: - - arch: x86_64 - asset: telemt-x86_64-openbsd - rustflags: -C opt-level=3 - - arch: aarch64 - asset: telemt-aarch64-openbsd - rustflags: -C opt-level=3 - - steps: - - uses: actions/checkout@v4 - - - name: Build in OpenBSD VM - uses: vmactions/openbsd-vm@v1 - with: - release: "7.8" - arch: ${{ matrix.arch }} - usesh: true - sync: sshfs - envs: RUSTFLAGS CARGO_TERM_COLOR - prepare: | - pkg_add rust - run: | - set -e - RUSTC_VERSION=$(rustc --version | awk '{print $2}') - RUSTC_MAJOR=$(echo "$RUSTC_VERSION" | cut -d. -f1) - RUSTC_MINOR=$(echo "$RUSTC_VERSION" | cut -d. -f2) - REQUIRED_MAJOR=1 - REQUIRED_MINOR=85 - if [ "$RUSTC_MAJOR" -lt "$REQUIRED_MAJOR" ] || { [ "$RUSTC_MAJOR" -eq "$REQUIRED_MAJOR" ] && [ "$RUSTC_MINOR" -lt "$REQUIRED_MINOR" ]; }; then - echo "rustc ${REQUIRED_MAJOR}.${REQUIRED_MINOR}.0 or newer is required for this project (found ${RUSTC_VERSION})." - exit 1 - fi - cargo build --release --locked --verbose - env: - RUSTFLAGS: ${{ matrix.rustflags }} - - - name: Package - run: | - mkdir -p dist - cp target/release/${{ env.BINARY_NAME }} dist/${{ env.BINARY_NAME }}-${{ matrix.arch }}-unknown-openbsd - cd dist - tar -czf ${{ matrix.asset }}.tar.gz ${{ env.BINARY_NAME }}-${{ matrix.arch }}-unknown-openbsd - sha256sum ${{ matrix.asset }}.tar.gz > ${{ matrix.asset }}.sha256 - - - uses: actions/upload-artifact@v4 - with: - name: ${{ matrix.asset }} - path: | - dist/${{ matrix.asset }}.tar.gz - dist/${{ matrix.asset }}.sha256 - -# ========================== -# Docker -# ========================== - docker: - name: Docker + release: + name: Release runs-on: ubuntu-latest needs: [build-gnu, build-musl] - continue-on-error: true + + permissions: + contents: write steps: - - uses: actions/checkout@v4 - - uses: actions/download-artifact@v4 with: path: artifacts - - name: Extract binaries + - name: Flatten run: | mkdir dist - find artifacts -name "*.tar.gz" -exec tar -xzf {} -C dist \; + find artifacts -type f -exec cp {} dist/ \; - cp dist/telemt-x86_64-unknown-linux-musl dist/telemt || true - - - uses: docker/setup-qemu-action@v3 - - uses: docker/setup-buildx-action@v3 - - - name: Login to GHCR - uses: docker/login-action@v3 + - name: Create Release + uses: softprops/action-gh-release@v2 with: - registry: ghcr.io - username: ${{ github.actor }} - password: ${{ secrets.GITHUB_TOKEN }} + files: dist/* + generate_release_notes: true + prerelease: ${{ contains(github.ref, '-') }} + +# ========================== +# Docker (FROM RELEASE) +# ========================== + docker: + name: Docker (from release) + runs-on: ubuntu-latest + needs: release + + permissions: + contents: read + packages: write + + steps: + - uses: actions/checkout@v4 + + - name: Install gh + run: apt-get update && apt-get install -y gh - name: Extract version id: vars run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT + - name: Download binary + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + mkdir dist + + gh release download ${{ steps.vars.outputs.VERSION }} \ + --repo ${{ github.repository }} \ + --pattern "telemt-x86_64-linux-musl.tar.gz" \ + --dir dist + + tar -xzf dist/telemt-x86_64-linux-musl.tar.gz -C dist + chmod +x dist/telemt + + - uses: docker/setup-qemu-action@v3 + - uses: docker/setup-buildx-action@v3 + + - uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Build & Push uses: docker/build-push-action@v6 with: @@ -314,33 +248,4 @@ jobs: ghcr.io/${{ github.repository }}:${{ steps.vars.outputs.VERSION }} ghcr.io/${{ github.repository }}:latest build-args: | - BINARY=dist/telemt - -# ========================== -# Release -# ========================== - release: - name: Release - runs-on: ubuntu-latest - needs: [build-gnu, build-musl, build-openbsd] - - permissions: - contents: write - - steps: - - uses: actions/download-artifact@v4 - with: - path: artifacts - - - name: Flatten artifacts - run: | - mkdir dist - find artifacts -type f -exec cp {} dist/ \; - - - name: Create Release - uses: softprops/action-gh-release@v2 - with: - files: dist/* - generate_release_notes: true - draft: false - prerelease: ${{ contains(github.ref, '-rc') || contains(github.ref, '-beta') || contains(github.ref, '-alpha') }} + BINARY=dist/telemt \ No newline at end of file diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml deleted file mode 100644 index d7f67e0..0000000 --- a/.github/workflows/rust.yml +++ /dev/null @@ -1,54 +0,0 @@ -name: Rust - -on: - push: - branches: [ "*" ] - pull_request: - branches: [ "*" ] - -env: - CARGO_TERM_COLOR: always - -jobs: - build: - name: Compile, Test, Lint - runs-on: ubuntu-latest - - permissions: - contents: read - actions: write - checks: write - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Install latest stable Rust toolchain - uses: dtolnay/rust-toolchain@stable - with: - components: rustfmt, clippy - - - name: Cache cargo registry & build artifacts - uses: actions/cache@v4 - with: - path: | - ~/.cargo/registry - ~/.cargo/git - target - key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-cargo- - - - name: Compile (no tests) - run: cargo check --workspace --all-features --bins --verbose - - - name: Run tests (single pass) - run: cargo test --workspace --all-features --verbose - -# clippy dont fail on warnings because of active development of telemt -# and many warnings - - name: Run clippy - run: cargo clippy -- --cap-lints warn - - - name: Check for unused dependencies - run: cargo udeps || true diff --git a/.github/workflows/stress.yml b/.github/workflows/stress.yml deleted file mode 100644 index 96b9a1b..0000000 --- a/.github/workflows/stress.yml +++ /dev/null @@ -1,57 +0,0 @@ -name: Stress Tests - -on: - workflow_dispatch: - schedule: - - cron: '0 2 * * *' - pull_request: - branches: ["*"] - paths: - - src/proxy/** - - src/transport/** - - src/stream/** - - src/protocol/** - - src/tls_front/** - - Cargo.toml - - Cargo.lock - -env: - CARGO_TERM_COLOR: always - -jobs: - quota-lock-stress: - name: Quota-lock stress loop - runs-on: ubuntu-latest - - permissions: - contents: read - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - - name: Install latest stable Rust toolchain - uses: dtolnay/rust-toolchain@stable - - - name: Cache cargo registry and build artifacts - uses: actions/cache@v4 - with: - path: | - ~/.cargo/registry - ~/.cargo/git - target - key: ${{ runner.os }}-cargo-stress-${{ hashFiles('**/Cargo.lock') }} - restore-keys: | - ${{ runner.os }}-cargo-stress- - ${{ runner.os }}-cargo- - - - name: Run quota-lock stress suites - env: - RUST_TEST_THREADS: 16 - run: | - set -euo pipefail - for i in $(seq 1 12); do - echo "[quota-lock-stress] iteration ${i}/12" - cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16 - cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16 - done diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..46d3b0a --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,127 @@ +name: Check + +on: + push: + branches: [ "*" ] + pull_request: + branches: [ "*" ] + +env: + CARGO_TERM_COLOR: always + +concurrency: + group: test-${{ github.ref }} + cancel-in-progress: true + +jobs: +# ========================== +# Formatting +# ========================== + fmt: + name: Fmt + runs-on: ubuntu-latest + + permissions: + contents: read + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + + - run: cargo fmt -- --check + +# ========================== +# Tests +# ========================== + test: + name: Test + runs-on: ubuntu-latest + + permissions: + contents: read + actions: write + checks: write + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- + + - run: cargo test --verbose + +# ========================== +# Clippy +# ========================== + clippy: + name: Clippy + runs-on: ubuntu-latest + + permissions: + contents: read + checks: write + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@stable + with: + components: clippy + + - name: Cache cargo + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- + + - run: cargo clippy -- --cap-lints warn + +# ========================== +# Udeps +# ========================== + udeps: + name: Udeps + runs-on: ubuntu-latest + + permissions: + contents: read + + steps: + - uses: actions/checkout@v4 + + - uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ${{ runner.os }}-cargo- + + - name: Install cargo-udeps + run: cargo install cargo-udeps || true + + # тоже не валит билд + - run: cargo udeps || true \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 372f702..eac46f0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,29 +1,9 @@ # syntax=docker/dockerfile:1 -# ========================== -# Stage 1: Build -# ========================== -FROM rust:1.88-slim-bookworm AS builder - -RUN apt-get update && apt-get install -y --no-install-recommends \ - pkg-config \ - ca-certificates \ - && rm -rf /var/lib/apt/lists/* - -WORKDIR /build - -# Depcache -COPY Cargo.toml Cargo.lock* ./ -RUN mkdir src && echo 'fn main() {}' > src/main.rs && \ - cargo build --release 2>/dev/null || true && \ - rm -rf src - -# Build -COPY . . -RUN cargo build --release && strip target/release/telemt +ARG BINARY # ========================== -# Stage 2: Compress (strip + UPX) +# Stage: minimal # ========================== FROM debian:12-slim AS minimal @@ -33,7 +13,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ ca-certificates \ && rm -rf /var/lib/apt/lists/* \ \ - # install UPX from Telemt releases && curl -fL \ --retry 5 \ --retry-delay 3 \ @@ -46,15 +25,15 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ && chmod +x /usr/local/bin/upx \ && rm -rf /tmp/upx* -COPY --from=builder /build/target/release/telemt /telemt +COPY ${BINARY} /telemt RUN strip /telemt || true RUN upx --best --lzma /telemt || true # ========================== -# Stage 3: Debug base +# Debug image # ========================== -FROM debian:12-slim AS debug-base +FROM debian:12-slim AS debug RUN apt-get update && apt-get install -y --no-install-recommends \ ca-certificates \ @@ -64,48 +43,29 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ busybox \ && rm -rf /var/lib/apt/lists/* -# ========================== -# Stage 4: Debug image -# ========================== -FROM debug-base AS debug - WORKDIR /app COPY --from=minimal /telemt /app/telemt COPY config.toml /app/config.toml -USER root - -EXPOSE 443 -EXPOSE 9090 -EXPOSE 9091 +EXPOSE 443 9090 9091 ENTRYPOINT ["/app/telemt"] CMD ["config.toml"] # ========================== -# Stage 5: Production (distroless) +# Production (REAL distroless) # ========================== -FROM gcr.io/distroless/base-debian12 AS prod +FROM gcr.io/distroless/static-debian12 AS prod WORKDIR /app COPY --from=minimal /telemt /app/telemt COPY config.toml /app/config.toml -# TLS + timezone + shell -COPY --from=debug-base /etc/ssl/certs /etc/ssl/certs -COPY --from=debug-base /usr/share/zoneinfo /usr/share/zoneinfo -COPY --from=debug-base /bin/busybox /bin/busybox - -RUN ["/bin/busybox", "--install", "-s", "/bin"] - -# distroless user USER nonroot:nonroot -EXPOSE 443 -EXPOSE 9090 -EXPOSE 9091 +EXPOSE 443 9090 9091 ENTRYPOINT ["/app/telemt"] -CMD ["config.toml"] +CMD ["config.toml"] \ No newline at end of file diff --git a/docs/VPS_DOUBLE_HOP.en.md b/docs/VPS_DOUBLE_HOP.en.md index 9463b79..6b6abe5 100644 --- a/docs/VPS_DOUBLE_HOP.en.md +++ b/docs/VPS_DOUBLE_HOP.en.md @@ -63,7 +63,7 @@ recommended range from 5 to 2147483647 inclusive > [!IMPORTANT] > It is recommended to use your own, unique values.\ -> You can use the [generator](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/e8b269ff0089a27effd88f8d925179b78e5666c4/awg-gen.html) to select parameters. +> You can use the [generator](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/13f5517ca473b47c412b9a99407066de973732bd/awg-gen.html) to select parameters. #### Server B Configuration (Netherlands): @@ -84,6 +84,8 @@ Jmin = 8 Jmax = 80 S1 = 29 S2 = 15 +S3 = 18 +S4 = 0 H1 = 2087563914 H2 = 188817757 H3 = 101784570 @@ -121,6 +123,8 @@ Jmin = 8 Jmax = 80 S1 = 29 S2 = 15 +S3 = 18 +S4 = 0 H1 = 2087563914 H2 = 188817757 H3 = 101784570 diff --git a/docs/VPS_DOUBLE_HOP.ru.md b/docs/VPS_DOUBLE_HOP.ru.md index 625c64c..037dfcb 100644 --- a/docs/VPS_DOUBLE_HOP.ru.md +++ b/docs/VPS_DOUBLE_HOP.ru.md @@ -44,7 +44,7 @@ awg genkey | tee private.key | awg pubkey > public.key Параметры обфускации `S1`, `S2`, `H1`, `H2`, `H3`, `H4` должны быть строго идентичными на обоих серверах.\ Параметры `Jc`, `Jmin` и `Jmax` могут отличатся.\ -Параметры `I1-I5` [(Custom Protocol Signature)](https://docs.amnezia.org/documentation/amnezia-wg/) нужно указывать на стороне _клиента_ (Сервер **А**). +Параметры `I1-I5` ([Custom Protocol Signature](https://docs.amnezia.org/documentation/amnezia-wg/)) нужно указывать на стороне _клиента_ (Сервер **А**). Рекомендации по выбору значений: ```text @@ -62,7 +62,7 @@ H1/H2/H3/H4 — должны быть уникальны и отличаться ``` > [!IMPORTANT] > Рекомендуется использовать собственные, уникальные значения.\ -> Для выбора параметров можете воспользоваться [генератором](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/e8b269ff0089a27effd88f8d925179b78e5666c4/awg-gen.html). +> Для выбора параметров можете воспользоваться [генератором](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/13f5517ca473b47c412b9a99407066de973732bd/awg-gen.html). #### Конфигурация Сервера B (_Нидерланды_): @@ -83,6 +83,8 @@ Jmin = 8 Jmax = 80 S1 = 29 S2 = 15 +S3 = 18 +S4 = 0 H1 = 2087563914 H2 = 188817757 H3 = 101784570 @@ -121,6 +123,8 @@ Jmin = 8 Jmax = 80 S1 = 29 S2 = 15 +S3 = 18 +S4 = 0 H1 = 2087563914 H2 = 188817757 H3 = 101784570 @@ -272,7 +276,7 @@ backend telemt_nodes ``` >[!WARNING] ->**Файл должен заканчиваться пустой строкой, иначе HAProxy не запуститься!** +>**Файл должен заканчиваться пустой строкой, иначе HAProxy не запустится!** #### Разрешаем порт 443\tcp в фаерволе (если включен) ```bash diff --git a/src/config/defaults.rs b/src/config/defaults.rs index 09d146a..b0aaf5b 100644 --- a/src/config/defaults.rs +++ b/src/config/defaults.rs @@ -71,6 +71,22 @@ pub(crate) fn default_tls_fetch_scope() -> String { String::new() } +pub(crate) fn default_tls_fetch_attempt_timeout_ms() -> u64 { + 5_000 +} + +pub(crate) fn default_tls_fetch_total_budget_ms() -> u64 { + 15_000 +} + +pub(crate) fn default_tls_fetch_strict_route() -> bool { + true +} + +pub(crate) fn default_tls_fetch_profile_cache_ttl_secs() -> u64 { + 600 +} + pub(crate) fn default_mask_port() -> u16 { 443 } @@ -185,6 +201,10 @@ pub(crate) fn default_proxy_protocol_header_timeout_ms() -> u64 { 500 } +pub(crate) fn default_proxy_protocol_trusted_cidrs() -> Vec { + vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()] +} + pub(crate) fn default_server_max_connections() -> u32 { 10_000 } diff --git a/src/config/hot_reload.rs b/src/config/hot_reload.rs index a3f795a..7f7499e 100644 --- a/src/config/hot_reload.rs +++ b/src/config/hot_reload.rs @@ -228,7 +228,9 @@ impl HotFields { me_d2c_flush_batch_max_delay_us: cfg.general.me_d2c_flush_batch_max_delay_us, me_d2c_ack_flush_immediate: cfg.general.me_d2c_ack_flush_immediate, me_quota_soft_overshoot_bytes: cfg.general.me_quota_soft_overshoot_bytes, - me_d2c_frame_buf_shrink_threshold_bytes: cfg.general.me_d2c_frame_buf_shrink_threshold_bytes, + me_d2c_frame_buf_shrink_threshold_bytes: cfg + .general + .me_d2c_frame_buf_shrink_threshold_bytes, direct_relay_copy_buf_c2s_bytes: cfg.general.direct_relay_copy_buf_c2s_bytes, direct_relay_copy_buf_s2c_bytes: cfg.general.direct_relay_copy_buf_s2c_bytes, me_health_interval_ms_unhealthy: cfg.general.me_health_interval_ms_unhealthy, diff --git a/src/config/load.rs b/src/config/load.rs index fc54ec2..3cb6627 100644 --- a/src/config/load.rs +++ b/src/config/load.rs @@ -1,6 +1,6 @@ #![allow(deprecated)] -use std::collections::{BTreeSet, HashMap}; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::hash::{DefaultHasher, Hash, Hasher}; use std::net::{IpAddr, SocketAddr}; use std::path::{Path, PathBuf}; @@ -444,8 +444,7 @@ impl ProxyConfig { if !(5..=50).contains(&config.censorship.mask_classifier_prefetch_timeout_ms) { return Err(ProxyError::Config( - "censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]" - .to_string(), + "censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]".to_string(), )); } @@ -558,7 +557,9 @@ impl ProxyConfig { )); } - if !(4096..=16 * 1024 * 1024).contains(&config.general.me_d2c_frame_buf_shrink_threshold_bytes) { + if !(4096..=16 * 1024 * 1024) + .contains(&config.general.me_d2c_frame_buf_shrink_threshold_bytes) + { return Err(ProxyError::Config( "general.me_d2c_frame_buf_shrink_threshold_bytes must be within [4096, 16777216]" .to_string(), @@ -976,6 +977,28 @@ impl ProxyConfig { // Normalize optional TLS fetch scope: whitespace-only values disable scoped routing. config.censorship.tls_fetch_scope = config.censorship.tls_fetch_scope.trim().to_string(); + if config.censorship.tls_fetch.profiles.is_empty() { + config.censorship.tls_fetch.profiles = TlsFetchConfig::default().profiles; + } else { + let mut seen = HashSet::new(); + config + .censorship + .tls_fetch + .profiles + .retain(|profile| seen.insert(*profile)); + } + + if config.censorship.tls_fetch.attempt_timeout_ms == 0 { + return Err(ProxyError::Config( + "censorship.tls_fetch.attempt_timeout_ms must be > 0".to_string(), + )); + } + if config.censorship.tls_fetch.total_budget_ms == 0 { + return Err(ProxyError::Config( + "censorship.tls_fetch.total_budget_ms must be > 0".to_string(), + )); + } + // Merge primary + extra TLS domains, deduplicate (primary always first). if !config.censorship.tls_domains.is_empty() { let mut all = Vec::with_capacity(1 + config.censorship.tls_domains.len()); @@ -1262,6 +1285,11 @@ mod tests { assert_eq!(cfg.general.update_every, default_update_every()); assert_eq!(cfg.server.listen_addr_ipv4, default_listen_addr_ipv4()); assert_eq!(cfg.server.listen_addr_ipv6, default_listen_addr_ipv6_opt()); + assert_eq!( + cfg.server.proxy_protocol_trusted_cidrs, + default_proxy_protocol_trusted_cidrs() + ); + assert_eq!(cfg.censorship.unknown_sni_action, UnknownSniAction::Drop); assert_eq!(cfg.server.api.listen, default_api_listen()); assert_eq!(cfg.server.api.whitelist, default_api_whitelist()); assert_eq!( @@ -1394,6 +1422,14 @@ mod tests { let server = ServerConfig::default(); assert_eq!(server.listen_addr_ipv6, Some(default_listen_addr_ipv6())); + assert_eq!( + server.proxy_protocol_trusted_cidrs, + default_proxy_protocol_trusted_cidrs() + ); + assert_eq!( + AntiCensorshipConfig::default().unknown_sni_action, + UnknownSniAction::Drop + ); assert_eq!(server.api.listen, default_api_listen()); assert_eq!(server.api.whitelist, default_api_whitelist()); assert_eq!( @@ -1429,6 +1465,75 @@ mod tests { assert_eq!(access.users, default_access_users()); } + #[test] + fn proxy_protocol_trusted_cidrs_missing_uses_trust_all_but_explicit_empty_stays_empty() { + let cfg_missing: ProxyConfig = toml::from_str( + r#" + [server] + [general] + [network] + [access] + "#, + ) + .unwrap(); + assert_eq!( + cfg_missing.server.proxy_protocol_trusted_cidrs, + default_proxy_protocol_trusted_cidrs() + ); + + let cfg_explicit_empty: ProxyConfig = toml::from_str( + r#" + [server] + proxy_protocol_trusted_cidrs = [] + + [general] + [network] + [access] + "#, + ) + .unwrap(); + assert!( + cfg_explicit_empty + .server + .proxy_protocol_trusted_cidrs + .is_empty() + ); + } + + #[test] + fn unknown_sni_action_parses_and_defaults_to_drop() { + let cfg_default: ProxyConfig = toml::from_str( + r#" + [server] + [general] + [network] + [access] + [censorship] + "#, + ) + .unwrap(); + assert_eq!( + cfg_default.censorship.unknown_sni_action, + UnknownSniAction::Drop + ); + + let cfg_mask: ProxyConfig = toml::from_str( + r#" + [server] + [general] + [network] + [access] + [censorship] + unknown_sni_action = "mask" + "#, + ) + .unwrap(); + assert_eq!( + cfg_mask.censorship.unknown_sni_action, + UnknownSniAction::Mask + ); + } + #[test] fn dc_overrides_allow_string_and_array() { let toml = r#" @@ -2376,6 +2481,94 @@ mod tests { let _ = std::fs::remove_file(path); } + #[test] + fn tls_fetch_defaults_are_applied() { + let toml = r#" + [censorship] + tls_domain = "example.com" + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_defaults_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + assert_eq!( + cfg.censorship.tls_fetch.profiles, + TlsFetchConfig::default().profiles + ); + assert!(cfg.censorship.tls_fetch.strict_route); + assert_eq!(cfg.censorship.tls_fetch.attempt_timeout_ms, 5_000); + assert_eq!(cfg.censorship.tls_fetch.total_budget_ms, 15_000); + assert_eq!(cfg.censorship.tls_fetch.profile_cache_ttl_secs, 600); + let _ = std::fs::remove_file(path); + } + + #[test] + fn tls_fetch_profiles_are_deduplicated_preserving_order() { + let toml = r#" + [censorship] + tls_domain = "example.com" + [censorship.tls_fetch] + profiles = ["compat_tls12", "modern_chrome_like", "compat_tls12", "legacy_minimal"] + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_profiles_dedup_test.toml"); + std::fs::write(&path, toml).unwrap(); + let cfg = ProxyConfig::load(&path).unwrap(); + assert_eq!( + cfg.censorship.tls_fetch.profiles, + vec![ + TlsFetchProfile::CompatTls12, + TlsFetchProfile::ModernChromeLike, + TlsFetchProfile::LegacyMinimal + ] + ); + let _ = std::fs::remove_file(path); + } + + #[test] + fn tls_fetch_attempt_timeout_zero_is_rejected() { + let toml = r#" + [censorship] + tls_domain = "example.com" + [censorship.tls_fetch] + attempt_timeout_ms = 0 + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_attempt_timeout_zero_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("censorship.tls_fetch.attempt_timeout_ms must be > 0")); + let _ = std::fs::remove_file(path); + } + + #[test] + fn tls_fetch_total_budget_zero_is_rejected() { + let toml = r#" + [censorship] + tls_domain = "example.com" + [censorship.tls_fetch] + total_budget_ms = 0 + + [access.users] + user = "00000000000000000000000000000000" + "#; + let dir = std::env::temp_dir(); + let path = dir.join("telemt_tls_fetch_total_budget_zero_test.toml"); + std::fs::write(&path, toml).unwrap(); + let err = ProxyConfig::load(&path).unwrap_err().to_string(); + assert!(err.contains("censorship.tls_fetch.total_budget_ms must be > 0")); + let _ = std::fs::remove_file(path); + } + #[test] fn invalid_ad_tag_is_disabled_during_load() { let toml = r#" diff --git a/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs b/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs index 49ee953..0b3d543 100644 --- a/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs +++ b/src/config/tests/load_mask_classifier_prefetch_timeout_security_tests.rs @@ -8,8 +8,9 @@ fn write_temp_config(contents: &str) -> PathBuf { .duration_since(UNIX_EPOCH) .expect("system time must be after unix epoch") .as_nanos(); - let path = std::env::temp_dir() - .join(format!("telemt-load-mask-prefetch-timeout-security-{nonce}.toml")); + let path = std::env::temp_dir().join(format!( + "telemt-load-mask-prefetch-timeout-security-{nonce}.toml" + )); fs::write(&path, contents).expect("temp config write must succeed"); path } @@ -67,8 +68,8 @@ mask_classifier_prefetch_timeout_ms = 20 "#, ); - let cfg = ProxyConfig::load(&path) - .expect("prefetch timeout within security bounds must be accepted"); + let cfg = + ProxyConfig::load(&path).expect("prefetch timeout within security bounds must be accepted"); assert_eq!(cfg.censorship.mask_classifier_prefetch_timeout_ms, 20); remove_temp_config(&path); diff --git a/src/config/tests/load_mask_shape_security_tests.rs b/src/config/tests/load_mask_shape_security_tests.rs index 2e4aa41..bccd36f 100644 --- a/src/config/tests/load_mask_shape_security_tests.rs +++ b/src/config/tests/load_mask_shape_security_tests.rs @@ -265,8 +265,8 @@ mask_relay_max_bytes = 67108865 "#, ); - let err = ProxyConfig::load(&path) - .expect_err("mask_relay_max_bytes above hard cap must be rejected"); + let err = + ProxyConfig::load(&path).expect_err("mask_relay_max_bytes above hard cap must be rejected"); let msg = err.to_string(); assert!( msg.contains("censorship.mask_relay_max_bytes must be <= 67108864"), diff --git a/src/config/types.rs b/src/config/types.rs index 5dc9719..3939664 100644 --- a/src/config/types.rs +++ b/src/config/types.rs @@ -954,7 +954,8 @@ impl Default for GeneralConfig { me_d2c_flush_batch_max_delay_us: default_me_d2c_flush_batch_max_delay_us(), me_d2c_ack_flush_immediate: default_me_d2c_ack_flush_immediate(), me_quota_soft_overshoot_bytes: default_me_quota_soft_overshoot_bytes(), - me_d2c_frame_buf_shrink_threshold_bytes: default_me_d2c_frame_buf_shrink_threshold_bytes(), + me_d2c_frame_buf_shrink_threshold_bytes: + default_me_d2c_frame_buf_shrink_threshold_bytes(), direct_relay_copy_buf_c2s_bytes: default_direct_relay_copy_buf_c2s_bytes(), direct_relay_copy_buf_s2c_bytes: default_direct_relay_copy_buf_s2c_bytes(), me_warmup_stagger_enabled: default_true(), @@ -1239,9 +1240,10 @@ pub struct ServerConfig { /// Trusted source CIDRs allowed to send incoming PROXY protocol headers. /// - /// When non-empty, connections from addresses outside this allowlist are - /// rejected before `src_addr` is applied. - #[serde(default)] + /// If this field is omitted in config, it defaults to trust-all CIDRs + /// (`0.0.0.0/0` and `::/0`). If it is explicitly set to an empty list, + /// all PROXY protocol headers are rejected. + #[serde(default = "default_proxy_protocol_trusted_cidrs")] pub proxy_protocol_trusted_cidrs: Vec, /// Port for the Prometheus-compatible metrics endpoint. @@ -1286,7 +1288,7 @@ impl Default for ServerConfig { listen_tcp: None, proxy_protocol: false, proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(), - proxy_protocol_trusted_cidrs: Vec::new(), + proxy_protocol_trusted_cidrs: default_proxy_protocol_trusted_cidrs(), metrics_port: None, metrics_listen: None, metrics_whitelist: default_metrics_whitelist(), @@ -1357,6 +1359,90 @@ impl Default for TimeoutsConfig { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum UnknownSniAction { + #[default] + Drop, + Mask, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum TlsFetchProfile { + ModernChromeLike, + ModernFirefoxLike, + CompatTls12, + LegacyMinimal, +} + +impl TlsFetchProfile { + pub fn as_str(self) -> &'static str { + match self { + TlsFetchProfile::ModernChromeLike => "modern_chrome_like", + TlsFetchProfile::ModernFirefoxLike => "modern_firefox_like", + TlsFetchProfile::CompatTls12 => "compat_tls12", + TlsFetchProfile::LegacyMinimal => "legacy_minimal", + } + } +} + +fn default_tls_fetch_profiles() -> Vec { + vec![ + TlsFetchProfile::ModernChromeLike, + TlsFetchProfile::ModernFirefoxLike, + TlsFetchProfile::CompatTls12, + TlsFetchProfile::LegacyMinimal, + ] +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TlsFetchConfig { + /// Ordered list of ClientHello profiles used for adaptive fallback. + #[serde(default = "default_tls_fetch_profiles")] + pub profiles: Vec, + + /// When true and upstream route is configured, TLS fetch fails closed on + /// upstream connect errors and does not fallback to direct TCP. + #[serde(default = "default_tls_fetch_strict_route")] + pub strict_route: bool, + + /// Timeout per one profile attempt in milliseconds. + #[serde(default = "default_tls_fetch_attempt_timeout_ms")] + pub attempt_timeout_ms: u64, + + /// Total wall-clock budget in milliseconds across all profile attempts. + #[serde(default = "default_tls_fetch_total_budget_ms")] + pub total_budget_ms: u64, + + /// Adds GREASE-style values into selected ClientHello extensions. + #[serde(default)] + pub grease_enabled: bool, + + /// Produces deterministic ClientHello randomness for debugging/tests. + #[serde(default)] + pub deterministic: bool, + + /// TTL for winner-profile cache entries in seconds. + /// Set to 0 to disable profile cache. + #[serde(default = "default_tls_fetch_profile_cache_ttl_secs")] + pub profile_cache_ttl_secs: u64, +} + +impl Default for TlsFetchConfig { + fn default() -> Self { + Self { + profiles: default_tls_fetch_profiles(), + strict_route: default_tls_fetch_strict_route(), + attempt_timeout_ms: default_tls_fetch_attempt_timeout_ms(), + total_budget_ms: default_tls_fetch_total_budget_ms(), + grease_enabled: false, + deterministic: false, + profile_cache_ttl_secs: default_tls_fetch_profile_cache_ttl_secs(), + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AntiCensorshipConfig { #[serde(default = "default_tls_domain")] @@ -1366,11 +1452,19 @@ pub struct AntiCensorshipConfig { #[serde(default)] pub tls_domains: Vec, + /// Policy for TLS ClientHello with unknown (non-configured) SNI. + #[serde(default)] + pub unknown_sni_action: UnknownSniAction, + /// Upstream scope used for TLS front metadata fetches. /// Empty value keeps default upstream routing behavior. #[serde(default = "default_tls_fetch_scope")] pub tls_fetch_scope: String, + /// Fetch strategy for TLS front metadata bootstrap and periodic refresh. + #[serde(default)] + pub tls_fetch: TlsFetchConfig, + #[serde(default = "default_true")] pub mask: bool, @@ -1476,7 +1570,9 @@ impl Default for AntiCensorshipConfig { Self { tls_domain: default_tls_domain(), tls_domains: Vec::new(), + unknown_sni_action: UnknownSniAction::Drop, tls_fetch_scope: default_tls_fetch_scope(), + tls_fetch: TlsFetchConfig::default(), mask: default_true(), mask_host: None, mask_port: default_mask_port(), diff --git a/src/error.rs b/src/error.rs index d9aeb22..49c8c81 100644 --- a/src/error.rs +++ b/src/error.rs @@ -216,6 +216,9 @@ pub enum ProxyError { #[error("Invalid proxy protocol header")] InvalidProxyProtocol, + #[error("Unknown TLS SNI")] + UnknownTlsSni, + #[error("Proxy error: {0}")] Proxy(String), diff --git a/src/maestro/helpers.rs b/src/maestro/helpers.rs index 35f796f..032460c 100644 --- a/src/maestro/helpers.rs +++ b/src/maestro/helpers.rs @@ -8,8 +8,10 @@ use tracing::{debug, error, info, warn}; use crate::cli; use crate::config::ProxyConfig; +use crate::transport::UpstreamManager; use crate::transport::middle_proxy::{ - ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache, + ProxyConfigData, fetch_proxy_config_with_raw_via_upstream, load_proxy_config_cache, + save_proxy_config_cache, }; pub(crate) fn resolve_runtime_config_path( @@ -288,9 +290,10 @@ pub(crate) async fn load_startup_proxy_config_snapshot( cache_path: Option<&str>, me2dc_fallback: bool, label: &'static str, + upstream: Option>, ) -> Option { loop { - match fetch_proxy_config_with_raw(url).await { + match fetch_proxy_config_with_raw_via_upstream(url, upstream.clone()).await { Ok((cfg, raw)) => { if !cfg.map.is_empty() { if let Some(path) = cache_path diff --git a/src/maestro/me_startup.rs b/src/maestro/me_startup.rs index 022f8ae..b1e605c 100644 --- a/src/maestro/me_startup.rs +++ b/src/maestro/me_startup.rs @@ -63,9 +63,10 @@ pub(crate) async fn initialize_me_pool( let proxy_secret_path = config.general.proxy_secret_path.as_deref(); let pool_size = config.general.middle_proxy_pool_size.max(1); let proxy_secret = loop { - match crate::transport::middle_proxy::fetch_proxy_secret( + match crate::transport::middle_proxy::fetch_proxy_secret_with_upstream( proxy_secret_path, config.general.proxy_secret_len_max, + Some(upstream_manager.clone()), ) .await { @@ -129,6 +130,7 @@ pub(crate) async fn initialize_me_pool( config.general.proxy_config_v4_cache_path.as_deref(), me2dc_fallback, "getProxyConfig", + Some(upstream_manager.clone()), ) .await; if cfg_v4.is_some() { @@ -160,6 +162,7 @@ pub(crate) async fn initialize_me_pool( config.general.proxy_config_v6_cache_path.as_deref(), me2dc_fallback, "getProxyConfigV6", + Some(upstream_manager.clone()), ) .await; if cfg_v6.is_some() { diff --git a/src/maestro/runtime_tasks.rs b/src/maestro/runtime_tasks.rs index 066c853..d553eb9 100644 --- a/src/maestro/runtime_tasks.rs +++ b/src/maestro/runtime_tasks.rs @@ -32,14 +32,6 @@ pub(crate) struct RuntimeWatches { pub(crate) detected_ip_v6: Option, } -const QUOTA_USER_LOCK_EVICT_INTERVAL_SECS: u64 = 60; - -fn spawn_quota_lock_maintenance_task() -> tokio::task::JoinHandle<()> { - crate::proxy::relay::spawn_quota_user_lock_evictor(std::time::Duration::from_secs( - QUOTA_USER_LOCK_EVICT_INTERVAL_SECS, - )) -} - #[allow(clippy::too_many_arguments)] pub(crate) async fn spawn_runtime_tasks( config: &Arc, @@ -77,8 +69,6 @@ pub(crate) async fn spawn_runtime_tasks( rc_clone.run_periodic_cleanup().await; }); - spawn_quota_lock_maintenance_task(); - let detected_ip_v4: Option = probe.detected_ipv4.map(IpAddr::V4); let detected_ip_v6: Option = probe.detected_ipv6.map(IpAddr::V6); debug!( @@ -370,24 +360,3 @@ pub(crate) async fn mark_runtime_ready(startup_tracker: &Arc) { .await; startup_tracker.mark_ready().await; } - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn tdd_runtime_quota_lock_maintenance_path_spawns_single_evictor_task() { - crate::proxy::relay::reset_quota_user_lock_evictor_spawn_count_for_tests(); - - let handle = spawn_quota_lock_maintenance_task(); - tokio::time::sleep(std::time::Duration::from_millis(5)).await; - - assert_eq!( - crate::proxy::relay::quota_user_lock_evictor_spawn_count_for_tests(), - 1, - "runtime maintenance path must spawn exactly one quota lock evictor task per call" - ); - - handle.abort(); - } -} diff --git a/src/maestro/tls_bootstrap.rs b/src/maestro/tls_bootstrap.rs index 342a2f9..7cf3039 100644 --- a/src/maestro/tls_bootstrap.rs +++ b/src/maestro/tls_bootstrap.rs @@ -7,6 +7,7 @@ use tracing::warn; use crate::config::ProxyConfig; use crate::startup::{COMPONENT_TLS_FRONT_BOOTSTRAP, StartupTracker}; use crate::tls_front::TlsFrontCache; +use crate::tls_front::fetcher::TlsFetchStrategy; use crate::transport::UpstreamManager; pub(crate) async fn bootstrap_tls_front( @@ -40,7 +41,17 @@ pub(crate) async fn bootstrap_tls_front( let mask_unix_sock = config.censorship.mask_unix_sock.clone(); let tls_fetch_scope = (!config.censorship.tls_fetch_scope.is_empty()) .then(|| config.censorship.tls_fetch_scope.clone()); - let fetch_timeout = Duration::from_secs(5); + let tls_fetch = config.censorship.tls_fetch.clone(); + let fetch_strategy = TlsFetchStrategy { + profiles: tls_fetch.profiles, + strict_route: tls_fetch.strict_route, + attempt_timeout: Duration::from_millis(tls_fetch.attempt_timeout_ms.max(1)), + total_budget: Duration::from_millis(tls_fetch.total_budget_ms.max(1)), + grease_enabled: tls_fetch.grease_enabled, + deterministic: tls_fetch.deterministic, + profile_cache_ttl: Duration::from_secs(tls_fetch.profile_cache_ttl_secs), + }; + let fetch_timeout = fetch_strategy.total_budget; let cache_initial = cache.clone(); let domains_initial = tls_domains.to_vec(); @@ -48,6 +59,7 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_initial = mask_unix_sock.clone(); let scope_initial = tls_fetch_scope.clone(); let upstream_initial = upstream_manager.clone(); + let strategy_initial = fetch_strategy.clone(); tokio::spawn(async move { let mut join = tokio::task::JoinSet::new(); for domain in domains_initial { @@ -56,12 +68,13 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_domain = unix_sock_initial.clone(); let scope_domain = scope_initial.clone(); let upstream_domain = upstream_initial.clone(); + let strategy_domain = strategy_initial.clone(); join.spawn(async move { - match crate::tls_front::fetcher::fetch_real_tls( + match crate::tls_front::fetcher::fetch_real_tls_with_strategy( &host_domain, port, &domain, - fetch_timeout, + &strategy_domain, Some(upstream_domain), scope_domain.as_deref(), proxy_protocol, @@ -107,6 +120,7 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_refresh = mask_unix_sock.clone(); let scope_refresh = tls_fetch_scope.clone(); let upstream_refresh = upstream_manager.clone(); + let strategy_refresh = fetch_strategy.clone(); tokio::spawn(async move { loop { let base_secs = rand::rng().random_range(4 * 3600..=6 * 3600); @@ -120,12 +134,13 @@ pub(crate) async fn bootstrap_tls_front( let unix_sock_domain = unix_sock_refresh.clone(); let scope_domain = scope_refresh.clone(); let upstream_domain = upstream_refresh.clone(); + let strategy_domain = strategy_refresh.clone(); join.spawn(async move { - match crate::tls_front::fetcher::fetch_real_tls( + match crate::tls_front::fetcher::fetch_real_tls_with_strategy( &host_domain, port, &domain, - fetch_timeout, + &strategy_domain, Some(upstream_domain), scope_domain.as_deref(), proxy_protocol, diff --git a/src/main.rs b/src/main.rs index c512e6b..e5d931f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,12 +7,12 @@ mod crypto; mod error; mod ip_tracker; #[cfg(test)] -#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"] -mod ip_tracker_hotpath_adversarial_tests; -#[cfg(test)] #[path = "tests/ip_tracker_encapsulation_adversarial_tests.rs"] mod ip_tracker_encapsulation_adversarial_tests; #[cfg(test)] +#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"] +mod ip_tracker_hotpath_adversarial_tests; +#[cfg(test)] #[path = "tests/ip_tracker_regression_tests.rs"] mod ip_tracker_regression_tests; mod maestro; @@ -29,5 +29,6 @@ mod util; #[tokio::main] async fn main() -> std::result::Result<(), Box> { + let _ = rustls::crypto::ring::default_provider().install_default(); maestro::run().await } diff --git a/src/metrics.rs b/src/metrics.rs index a821d4d..f9475f6 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -1233,10 +1233,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp out, "# HELP telemt_me_d2c_batch_bytes_bucket_total DC->Client batch byte size buckets" ); - let _ = writeln!( - out, - "# TYPE telemt_me_d2c_batch_bytes_bucket_total counter" - ); + let _ = writeln!(out, "# TYPE telemt_me_d2c_batch_bytes_bucket_total counter"); let _ = writeln!( out, "telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"0_1k\"}} {}", diff --git a/src/proxy/client.rs b/src/proxy/client.rs index a804a2c..8ce3e96 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -210,7 +210,9 @@ fn should_prefetch_mask_classifier_window(initial_data: &[u8]) -> bool { return false; } - initial_data.iter().all(|b| b.is_ascii_alphabetic() || *b == b' ') + initial_data + .iter() + .all(|b| b.is_ascii_alphabetic() || *b == b' ') } #[cfg(test)] @@ -218,16 +220,19 @@ async fn extend_masking_initial_window(reader: &mut R, initial_data: &mut Vec where R: AsyncRead + Unpin, { - extend_masking_initial_window_with_timeout(reader, initial_data, MASK_CLASSIFIER_PREFETCH_TIMEOUT) - .await; + extend_masking_initial_window_with_timeout( + reader, + initial_data, + MASK_CLASSIFIER_PREFETCH_TIMEOUT, + ) + .await; } async fn extend_masking_initial_window_with_timeout( reader: &mut R, initial_data: &mut Vec, prefetch_timeout: Duration, -) -where +) where R: AsyncRead + Unpin, { if !should_prefetch_mask_classifier_window(initial_data) { @@ -312,13 +317,20 @@ fn record_handshake_failure_class( record_beobachten_class(beobachten, config, peer_ip, class); } +#[inline] +fn increment_bad_on_unknown_tls_sni(stats: &Stats, error: &ProxyError) { + if matches!(error, ProxyError::UnknownTlsSni) { + stats.increment_connects_bad(); + } +} + fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool { if trusted.is_empty() { static EMPTY_PROXY_TRUST_WARNED: OnceLock = OnceLock::new(); let warned = EMPTY_PROXY_TRUST_WARNED.get_or_init(|| AtomicBool::new(false)); if !warned.swap(true, Ordering::Relaxed) { warn!( - "PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers by default" + "PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers" ); } return false; @@ -503,7 +515,10 @@ where beobachten.clone(), )); } - HandshakeResult::Error(e) => return Err(e), + HandshakeResult::Error(e) => { + increment_bad_on_unknown_tls_sni(stats.as_ref(), &e); + return Err(e); + } }; debug!(peer = %peer, "Reading MTProto handshake through TLS"); @@ -954,7 +969,10 @@ impl RunningClientHandler { self.beobachten.clone(), )); } - HandshakeResult::Error(e) => return Err(e), + HandshakeResult::Error(e) => { + increment_bad_on_unknown_tls_sni(stats.as_ref(), &e); + return Err(e); + } }; debug!(peer = %peer, "Reading MTProto handshake through TLS"); @@ -1223,7 +1241,7 @@ impl RunningClientHandler { } if let Some(quota) = config.access.user_data_quota.get(user) - && stats.get_user_total_octets(user) >= *quota + && stats.get_user_quota_used(user) >= *quota { return Err(ProxyError::DataQuotaExceeded { user: user.to_string(), @@ -1282,7 +1300,7 @@ impl RunningClientHandler { } if let Some(quota) = config.access.user_data_quota.get(user) - && stats.get_user_total_octets(user) >= *quota + && stats.get_user_quota_used(user) >= *quota { return Err(ProxyError::DataQuotaExceeded { user: user.to_string(), diff --git a/src/proxy/handshake.rs b/src/proxy/handshake.rs index 96994c7..2ef8e1b 100644 --- a/src/proxy/handshake.rs +++ b/src/proxy/handshake.rs @@ -16,7 +16,7 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, trace, warn}; use zeroize::{Zeroize, Zeroizing}; -use crate::config::ProxyConfig; +use crate::config::{ProxyConfig, UnknownSniAction}; use crate::crypto::{AesCtr, SecureRandom, sha256}; use crate::error::{HandshakeResult, ProxyError}; use crate::protocol::constants::*; @@ -282,30 +282,9 @@ fn auth_probe_record_failure_with_state( let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None; let state_len = state.len(); let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT); - let start_offset = auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit); - let mut scanned = 0usize; - for entry in state.iter().skip(start_offset) { - let key = *entry.key(); - let fail_streak = entry.value().fail_streak; - let last_seen = entry.value().last_seen; - match eviction_candidate { - Some((_, current_fail, current_seen)) - if fail_streak > current_fail - || (fail_streak == current_fail && last_seen >= current_seen) => {} - _ => eviction_candidate = Some((key, fail_streak, last_seen)), - } - if auth_probe_state_expired(entry.value(), now) { - stale_keys.push(key); - } - scanned += 1; - if scanned >= scan_limit { - break; - } - } - - if scanned < scan_limit { - for entry in state.iter().take(scan_limit - scanned) { + if state_len <= AUTH_PROBE_PRUNE_SCAN_LIMIT { + for entry in state.iter() { let key = *entry.key(); let fail_streak = entry.value().fail_streak; let last_seen = entry.value().last_seen; @@ -319,6 +298,46 @@ fn auth_probe_record_failure_with_state( stale_keys.push(key); } } + } else { + let start_offset = + auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit); + let mut scanned = 0usize; + for entry in state.iter().skip(start_offset) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail && last_seen >= current_seen) => {} + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + scanned += 1; + if scanned >= scan_limit { + break; + } + } + + if scanned < scan_limit { + for entry in state.iter().take(scan_limit - scanned) { + let key = *entry.key(); + let fail_streak = entry.value().fail_streak; + let last_seen = entry.value().last_seen; + match eviction_candidate { + Some((_, current_fail, current_seen)) + if fail_streak > current_fail + || (fail_streak == current_fail + && last_seen >= current_seen) => {} + _ => eviction_candidate = Some((key, fail_streak, last_seen)), + } + if auth_probe_state_expired(entry.value(), now) { + stale_keys.push(key); + } + } + } } for stale_key in stale_keys { @@ -510,6 +529,21 @@ fn decode_user_secrets( secrets } +#[inline] +fn find_matching_tls_domain<'a>(config: &'a ProxyConfig, sni: &str) -> Option<&'a str> { + if config.censorship.tls_domain.eq_ignore_ascii_case(sni) { + return Some(config.censorship.tls_domain.as_str()); + } + + for domain in &config.censorship.tls_domains { + if domain.eq_ignore_ascii_case(sni) { + return Some(domain.as_str()); + } + } + + None +} + async fn maybe_apply_server_hello_delay(config: &ProxyConfig) { if config.censorship.server_hello_delay_max_ms == 0 { return; @@ -593,60 +627,12 @@ where } let client_sni = tls::extract_sni_from_client_hello(handshake); - let secrets = decode_user_secrets(config, client_sni.as_deref()); - - let validation = match tls::validate_tls_handshake_with_replay_window( - handshake, - &secrets, - config.access.ignore_time_skew, - config.access.replay_window_secs, - ) { - Some(v) => v, - None => { - auth_probe_record_failure(peer.ip(), Instant::now()); - maybe_apply_server_hello_delay(config).await; - debug!( - peer = %peer, - ignore_time_skew = config.access.ignore_time_skew, - "TLS handshake validation failed - no matching user or time skew" - ); - return HandshakeResult::BadClient { reader, writer }; - } - }; - - let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { - Some((_, s)) => s, - None => { - maybe_apply_server_hello_delay(config).await; - return HandshakeResult::BadClient { reader, writer }; - } - }; - - let cached = if config.censorship.tls_emulation { - if let Some(cache) = tls_cache.as_ref() { - let selected_domain = if let Some(sni) = client_sni.as_ref() { - if cache.contains_domain(sni).await { - sni.clone() - } else { - config.censorship.tls_domain.clone() - } - } else { - config.censorship.tls_domain.clone() - }; - let cached_entry = cache.get(&selected_domain).await; - let use_full_cert_payload = cache - .take_full_cert_budget_for_ip( - peer.ip(), - Duration::from_secs(config.censorship.tls_full_cert_ttl_secs), - ) - .await; - Some((cached_entry, use_full_cert_payload)) - } else { - None - } - } else { - None - }; + let preferred_user_hint = client_sni + .as_deref() + .filter(|sni| config.access.users.contains_key(*sni)); + let matched_tls_domain = client_sni + .as_deref() + .and_then(|sni| find_matching_tls_domain(config, sni)); let alpn_list = if config.censorship.alpn_enforce { tls::extract_alpn_from_client_hello(handshake) @@ -669,16 +655,81 @@ where None }; - // Replay tracking is applied only after full policy validation (including - // ALPN checks) so rejected handshakes cannot poison replay state. + if client_sni.is_some() && matched_tls_domain.is_none() && preferred_user_hint.is_none() { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + debug!( + peer = %peer, + sni = ?client_sni, + action = ?config.censorship.unknown_sni_action, + "TLS handshake rejected by unknown SNI policy" + ); + return match config.censorship.unknown_sni_action { + UnknownSniAction::Drop => HandshakeResult::Error(ProxyError::UnknownTlsSni), + UnknownSniAction::Mask => HandshakeResult::BadClient { reader, writer }, + }; + } + + let secrets = decode_user_secrets(config, preferred_user_hint); + + let validation = match tls::validate_tls_handshake_with_replay_window( + handshake, + &secrets, + config.access.ignore_time_skew, + config.access.replay_window_secs, + ) { + Some(v) => v, + None => { + auth_probe_record_failure(peer.ip(), Instant::now()); + maybe_apply_server_hello_delay(config).await; + debug!( + peer = %peer, + ignore_time_skew = config.access.ignore_time_skew, + "TLS handshake validation failed - no matching user or time skew" + ); + return HandshakeResult::BadClient { reader, writer }; + } + }; + + // Reject known replay digests before expensive cache/domain/ALPN policy work. let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN]; - if replay_checker.check_and_add_tls_digest(digest_half) { + if replay_checker.check_tls_digest(digest_half) { auth_probe_record_failure(peer.ip(), Instant::now()); maybe_apply_server_hello_delay(config).await; warn!(peer = %peer, "TLS replay attack detected (duplicate digest)"); return HandshakeResult::BadClient { reader, writer }; } + let secret = match secrets.iter().find(|(name, _)| *name == validation.user) { + Some((_, s)) => s, + None => { + maybe_apply_server_hello_delay(config).await; + return HandshakeResult::BadClient { reader, writer }; + } + }; + + let cached = if config.censorship.tls_emulation { + if let Some(cache) = tls_cache.as_ref() { + let selected_domain = + matched_tls_domain.unwrap_or(config.censorship.tls_domain.as_str()); + let cached_entry = cache.get(selected_domain).await; + let use_full_cert_payload = cache + .take_full_cert_budget_for_ip( + peer.ip(), + Duration::from_secs(config.censorship.tls_full_cert_ttl_secs), + ) + .await; + Some((cached_entry, use_full_cert_payload)) + } else { + None + } + } else { + None + }; + + // Add replay digest only for policy-valid handshakes. + replay_checker.add_tls_digest(digest_half); + let response = if let Some((cached_entry, use_full_cert_payload)) = cached { emulator::build_emulated_server_hello( secret, diff --git a/src/proxy/masking.rs b/src/proxy/masking.rs index 841749c..ba9f20a 100644 --- a/src/proxy/masking.rs +++ b/src/proxy/masking.rs @@ -10,10 +10,10 @@ use rand::rngs::StdRng; use rand::{Rng, RngExt, SeedableRng}; use std::net::{IpAddr, SocketAddr}; use std::str; -#[cfg(unix)] -use std::sync::{Mutex, OnceLock}; #[cfg(test)] use std::sync::atomic::{AtomicUsize, Ordering}; +#[cfg(unix)] +use std::sync::{Mutex, OnceLock}; use std::time::{Duration, Instant as StdInstant}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::net::TcpStream; @@ -60,7 +60,7 @@ where R: AsyncRead + Unpin, W: AsyncWrite + Unpin, { - let mut buf = [0u8; MASK_BUFFER_SIZE]; + let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]); let mut total = 0usize; let mut ended_by_eof = false; @@ -107,15 +107,7 @@ where fn is_http_probe(data: &[u8]) -> bool { // RFC 7540 section 3.5: HTTP/2 client preface starts with "PRI ". const HTTP_METHODS: [&[u8]; 10] = [ - b"GET ", - b"POST", - b"HEAD", - b"PUT ", - b"DELETE", - b"OPTIONS", - b"CONNECT", - b"TRACE", - b"PATCH", + b"GET ", b"POST", b"HEAD", b"PUT ", b"DELETE", b"OPTIONS", b"CONNECT", b"TRACE", b"PATCH", b"PRI ", ]; @@ -262,7 +254,11 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration { let floor = config.censorship.mask_timing_normalization_floor_ms; let ceiling = config.censorship.mask_timing_normalization_ceiling_ms; if floor == 0 { - return MASK_TIMEOUT; + if ceiling == 0 { + return Duration::from_millis(0); + } + let mut rng = rand::rng(); + return Duration::from_millis(rng.random_range(0..=ceiling)); } if ceiling > floor { let mut rng = rand::rng(); @@ -324,7 +320,10 @@ fn parse_mask_host_ip_literal(host: &str) -> Option { fn canonical_ip(ip: IpAddr) -> IpAddr { match ip { - IpAddr::V6(v6) => v6.to_ipv4_mapped().map(IpAddr::V4).unwrap_or(IpAddr::V6(v6)), + IpAddr::V6(v6) => v6 + .to_ipv4_mapped() + .map(IpAddr::V4) + .unwrap_or(IpAddr::V6(v6)), IpAddr::V4(v4) => IpAddr::V4(v4), } } @@ -660,12 +659,20 @@ pub async fn handle_bad_client( Ok(Err(e)) => { wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask unix socket"); - consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; + consume_client_data_with_timeout_and_cap( + reader, + config.censorship.mask_relay_max_bytes, + ) + .await; wait_mask_outcome_budget(outcome_started, config).await; } Err(_) => { debug!("Timeout connecting to mask unix socket"); - consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; + consume_client_data_with_timeout_and_cap( + reader, + config.censorship.mask_relay_max_bytes, + ) + .await; wait_mask_outcome_budget(outcome_started, config).await; } } @@ -694,7 +701,8 @@ pub async fn handle_bad_client( local = %local_addr, "Mask target resolves to local listener; refusing self-referential masking fallback" ); - consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; + consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes) + .await; wait_mask_outcome_budget(outcome_started, config).await; return; } @@ -754,12 +762,20 @@ pub async fn handle_bad_client( Ok(Err(e)) => { wait_mask_connect_budget_if_needed(connect_started, config).await; debug!(error = %e, "Failed to connect to mask host"); - consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; + consume_client_data_with_timeout_and_cap( + reader, + config.censorship.mask_relay_max_bytes, + ) + .await; wait_mask_outcome_budget(outcome_started, config).await; } Err(_) => { debug!("Timeout connecting to mask host"); - consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes).await; + consume_client_data_with_timeout_and_cap( + reader, + config.censorship.mask_relay_max_bytes, + ) + .await; wait_mask_outcome_budget(outcome_started, config).await; } } @@ -838,7 +854,7 @@ async fn consume_client_data(mut reader: R, byte_cap: usiz } // Keep drain path fail-closed under slow-loris stalls. - let mut buf = [0u8; MASK_BUFFER_SIZE]; + let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]); let mut total = 0usize; loop { diff --git a/src/proxy/middle_relay.rs b/src/proxy/middle_relay.rs index 14ea001..3259597 100644 --- a/src/proxy/middle_relay.rs +++ b/src/proxy/middle_relay.rs @@ -10,7 +10,7 @@ use std::time::{Duration, Instant}; use dashmap::DashMap; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::sync::{Mutex as AsyncMutex, mpsc, oneshot, watch}; +use tokio::sync::{mpsc, oneshot, watch}; use tokio::time::timeout; use tracing::{debug, info, trace, warn}; @@ -23,7 +23,9 @@ use crate::proxy::route_mode::{ ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state, cutover_stagger_delay, }; -use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, Stats}; +use crate::stats::{ + MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, QuotaReserveError, Stats, UserStats, +}; use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag}; @@ -53,20 +55,11 @@ const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1; const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096; const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2; const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024; -#[cfg(test)] -const QUOTA_USER_LOCKS_MAX: usize = 64; -#[cfg(not(test))] -const QUOTA_USER_LOCKS_MAX: usize = 4_096; -#[cfg(test)] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; -#[cfg(not(test))] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; +const QUOTA_RESERVE_SPIN_RETRIES: usize = 32; static DESYNC_DEDUP: OnceLock> = OnceLock::new(); static DESYNC_HASHER: OnceLock = OnceLock::new(); static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock>> = OnceLock::new(); static DESYNC_DEDUP_EVER_SATURATED: OnceLock = OnceLock::new(); -static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); -static QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock> = OnceLock::new(); static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0); @@ -100,7 +93,8 @@ fn relay_idle_candidate_registry() -> &'static Mutex RELAY_IDLE_CANDIDATE_REGISTRY.get_or_init(|| Mutex::new(RelayIdleCandidateRegistry::default())) } -fn relay_idle_candidate_registry_lock() -> std::sync::MutexGuard<'static, RelayIdleCandidateRegistry> { +fn relay_idle_candidate_registry_lock() -> std::sync::MutexGuard<'static, RelayIdleCandidateRegistry> +{ let registry = relay_idle_candidate_registry(); match registry.lock() { Ok(guard) => guard, @@ -538,36 +532,28 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET } -fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option) -> bool { - quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota) -} - -#[cfg_attr(not(test), allow(dead_code))] -fn quota_would_be_exceeded_for_user( - stats: &Stats, - user: &str, - quota_limit: Option, - bytes: u64, -) -> bool { - quota_limit.is_some_and(|quota| { - let used = stats.get_user_total_octets(user); - used >= quota || bytes > quota.saturating_sub(used) - }) -} - fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 { limit.saturating_add(overshoot) } -fn quota_would_be_exceeded_for_user_soft( - stats: &Stats, - user: &str, - quota_limit: Option, +async fn reserve_user_quota_with_yield( + user_stats: &UserStats, bytes: u64, - overshoot: u64, -) -> bool { - let capped_limit = quota_limit.map(|quota| quota_soft_cap(quota, overshoot)); - quota_would_be_exceeded_for_user(stats, user, capped_limit, bytes) + limit: u64, +) -> std::result::Result { + loop { + for _ in 0..QUOTA_RESERVE_SPIN_RETRIES { + match user_stats.quota_try_reserve(bytes, limit) { + Ok(total) => return Ok(total), + Err(QuotaReserveError::LimitExceeded) => { + return Err(QuotaReserveError::LimitExceeded); + } + Err(QuotaReserveError::Contended) => std::hint::spin_loop(), + } + } + + tokio::task::yield_now().await; + } } fn classify_me_d2c_flush_reason( @@ -613,29 +599,6 @@ fn observe_me_d2c_flush_event( } } -fn rollback_me2c_quota_reservation( - stats: &Stats, - user: &str, - bytes_me2c: &AtomicU64, - reserved_bytes: u64, -) { - stats.sub_user_octets_to(user, reserved_bytes); - bytes_me2c.fetch_sub(reserved_bytes, Ordering::Relaxed); -} - -#[cfg(test)] -fn quota_user_lock_test_guard() -> &'static Mutex<()> { - static TEST_LOCK: OnceLock> = OnceLock::new(); - TEST_LOCK.get_or_init(|| Mutex::new(())) -} - -#[cfg(test)] -fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> { - quota_user_lock_test_guard() - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) -} - #[cfg(test)] fn relay_idle_pressure_test_guard() -> &'static Mutex<()> { static TEST_LOCK: OnceLock> = OnceLock::new(); @@ -649,46 +612,6 @@ pub(crate) fn relay_idle_pressure_test_scope() -> std::sync::MutexGuard<'static, .unwrap_or_else(|poisoned| poisoned.into_inner()) } -fn quota_overflow_user_lock(user: &str) -> Arc> { - let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { - (0..QUOTA_OVERFLOW_LOCK_STRIPES) - .map(|_| Arc::new(AsyncMutex::new(()))) - .collect() - }); - - let hash = crc32fast::hash(user.as_bytes()) as usize; - Arc::clone(&stripes[hash % stripes.len()]) -} - -fn quota_user_lock(user: &str) -> Arc> { - let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - if let Some(existing) = locks.get(user) { - return Arc::clone(existing.value()); - } - - if locks.len() >= QUOTA_USER_LOCKS_MAX { - locks.retain(|_, value| Arc::strong_count(value) > 1); - } - - if locks.len() >= QUOTA_USER_LOCKS_MAX { - return quota_overflow_user_lock(user); - } - - let created = Arc::new(AsyncMutex::new(())); - match locks.entry(user.to_string()) { - dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), - dashmap::mapref::entry::Entry::Vacant(entry) => { - entry.insert(Arc::clone(&created)); - created - } - } -} - -#[cfg(test)] -pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user) -} - async fn enqueue_c2me_command( tx: &mpsc::Sender, cmd: C2MeCommand, @@ -744,8 +667,7 @@ where { let user = success.user.clone(); let quota_limit = config.access.user_data_quota.get(&user).copied(); - let cross_mode_quota_lock = - quota_limit.map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user)); + let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user)); let peer = success.peer; let proto_tag = success.proto_tag; let pool_generation = me_pool.current_generation(); @@ -872,7 +794,7 @@ where let stats_clone = stats.clone(); let rng_clone = rng.clone(); let user_clone = user.clone(); - let cross_mode_quota_lock_me_writer = cross_mode_quota_lock.clone(); + let quota_user_stats_me_writer = quota_user_stats.clone(); let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone(); let bytes_me2c_clone = bytes_me2c.clone(); let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config); @@ -894,7 +816,7 @@ where let first_is_downstream_activity = matches!(&first, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response_with_cross_mode_lock( + match process_me_writer_response( first, &mut writer, proto_tag, @@ -902,9 +824,9 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, - cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -953,7 +875,7 @@ where let next_is_downstream_activity = matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response_with_cross_mode_lock( + match process_me_writer_response( next, &mut writer, proto_tag, @@ -961,9 +883,9 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, - cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1015,7 +937,7 @@ where Ok(Some(next)) => { let next_is_downstream_activity = matches!(&next, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response_with_cross_mode_lock( + match process_me_writer_response( next, &mut writer, proto_tag, @@ -1023,9 +945,9 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, - cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1079,7 +1001,7 @@ where let extra_is_downstream_activity = matches!(&extra, MeResponse::Data { .. } | MeResponse::Ack(_)); - match process_me_writer_response_with_cross_mode_lock( + match process_me_writer_response( extra, &mut writer, proto_tag, @@ -1087,9 +1009,9 @@ where &mut frame_buf, stats_clone.as_ref(), &user_clone, + quota_user_stats_me_writer.as_deref(), quota_limit, d2c_flush_policy.quota_soft_overshoot_bytes, - cross_mode_quota_lock_me_writer.as_ref(), bytes_me2c_clone.as_ref(), conn_id, d2c_flush_policy.ack_flush_immediate, @@ -1259,24 +1181,23 @@ where forensics.bytes_c2me = forensics .bytes_c2me .saturating_add(payload.len() as u64); - if let Some(limit) = quota_limit { - let quota_lock = quota_user_lock(&user); - let _quota_guard = quota_lock.lock().await; - let Some(cross_mode_lock) = cross_mode_quota_lock.as_ref() else { - main_result = Err(ProxyError::Proxy( - "cross-mode quota lock missing for quota-limited session" - .to_string(), - )); - break; - }; - let _cross_mode_quota_guard = cross_mode_lock.lock().await; - stats.add_user_octets_from(&user, payload.len() as u64); - if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) { + if let (Some(limit), Some(user_stats)) = + (quota_limit, quota_user_stats.as_deref()) + { + if reserve_user_quota_with_yield( + user_stats, + payload.len() as u64, + limit, + ) + .await + .is_err() + { main_result = Err(ProxyError::DataQuotaExceeded { user: user.clone(), }); break; } + stats.add_user_octets_from_handle(user_stats, payload.len() as u64); } else { stats.add_user_octets_from(&user, payload.len() as u64); } @@ -1602,8 +1523,7 @@ where } if !idle_policy.enabled { - consecutive_zero_len_frames = - consecutive_zero_len_frames.saturating_add(1); + consecutive_zero_len_frames = consecutive_zero_len_frames.saturating_add(1); if consecutive_zero_len_frames > LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES { stats.increment_relay_protocol_desync_close_total(); return Err(ProxyError::Proxy( @@ -1755,7 +1675,6 @@ enum MeWriterResponseOutcome { Close, } -#[cfg(test)] async fn process_me_writer_response( response: MeResponse, client_writer: &mut CryptoWriter, @@ -1764,6 +1683,7 @@ async fn process_me_writer_response( frame_buf: &mut Vec, stats: &Stats, user: &str, + quota_user_stats: Option<&UserStats>, quota_limit: Option, quota_soft_overshoot_bytes: u64, bytes_me2c: &AtomicU64, @@ -1771,44 +1691,6 @@ async fn process_me_writer_response( ack_flush_immediate: bool, batched: bool, ) -> Result -where - W: AsyncWrite + Unpin + Send + 'static, -{ - process_me_writer_response_with_cross_mode_lock( - response, - client_writer, - proto_tag, - rng, - frame_buf, - stats, - user, - quota_limit, - quota_soft_overshoot_bytes, - None, - bytes_me2c, - conn_id, - ack_flush_immediate, - batched, - ) - .await -} - -async fn process_me_writer_response_with_cross_mode_lock( - response: MeResponse, - client_writer: &mut CryptoWriter, - proto_tag: ProtoTag, - rng: &SecureRandom, - frame_buf: &mut Vec, - stats: &Stats, - user: &str, - quota_limit: Option, - quota_soft_overshoot_bytes: u64, - cross_mode_quota_lock: Option<&Arc>>, - bytes_me2c: &AtomicU64, - conn_id: u64, - ack_flush_immediate: bool, - batched: bool, -) -> Result where W: AsyncWrite + Unpin + Send + 'static, { @@ -1820,78 +1702,43 @@ where trace!(conn_id, bytes = data.len(), flags, "ME->C data"); } let data_len = data.len() as u64; - if let Some(limit) = quota_limit { - let owned_cross_mode_lock; - let cross_mode_lock = if let Some(lock) = cross_mode_quota_lock { - lock - } else { - owned_cross_mode_lock = - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user); - &owned_cross_mode_lock - }; - let cross_mode_quota_guard = cross_mode_lock.lock().await; + if let (Some(limit), Some(user_stats)) = (quota_limit, quota_user_stats) { let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes); - if quota_would_be_exceeded_for_user_soft( - stats, - user, - Some(limit), - data_len, - quota_soft_overshoot_bytes, - ) { + if reserve_user_quota_with_yield(user_stats, data_len, soft_limit) + .await + .is_err() + { stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); return Err(ProxyError::DataQuotaExceeded { user: user.to_string(), }); } - - // Reserve quota before awaiting network I/O to avoid same-user HoL stalls. - // If reservation loses a race or write fails, we roll back immediately. - bytes_me2c.fetch_add(data_len, Ordering::Relaxed); - stats.add_user_octets_to(user, data_len); - - if stats.get_user_total_octets(user) > soft_limit { - rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len); - stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite); - return Err(ProxyError::DataQuotaExceeded { - user: user.to_string(), - }); - } - - // Keep cross-mode lock scope explicit and minimal: quota reservation is serialized, - // but socket I/O proceeds without holding same-user cross-mode admission lock. - drop(cross_mode_quota_guard); - - let write_mode = - match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) - .await - { - Ok(mode) => mode, - Err(err) => { - rollback_me2c_quota_reservation(stats, user, bytes_me2c, data_len); - return Err(err); - } - }; - - stats.increment_me_d2c_data_frames_total(); - stats.add_me_d2c_payload_bytes_total(data_len); - stats.increment_me_d2c_write_mode(write_mode); - - // Do not fail immediately on exact boundary after a successful write. - // Returning an error here can bypass batch flush in the caller and risk - // dropping buffered ciphertext from CryptoWriter. The next frame is - // rejected by the pre-check at function entry. - } else { - let write_mode = - write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) - .await?; - - bytes_me2c.fetch_add(data_len, Ordering::Relaxed); - stats.add_user_octets_to(user, data_len); - stats.increment_me_d2c_data_frames_total(); - stats.add_me_d2c_payload_bytes_total(data_len); - stats.increment_me_d2c_write_mode(write_mode); } + let write_mode = + match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf) + .await + { + Ok(mode) => mode, + Err(err) => { + if quota_limit.is_some() { + stats.add_quota_write_fail_bytes_total(data_len); + stats.increment_quota_write_fail_events_total(); + } + return Err(err); + } + }; + + bytes_me2c.fetch_add(data_len, Ordering::Relaxed); + if let Some(user_stats) = quota_user_stats { + stats.add_user_octets_to_handle(user_stats, data_len); + } else { + stats.add_user_octets_to(user, data_len); + } + stats.increment_me_d2c_data_frames_total(); + stats.add_me_d2c_payload_bytes_total(data_len); + stats.increment_me_d2c_write_mode(write_mode); + Ok(MeWriterResponseOutcome::Continue { frames: 1, bytes: data.len(), @@ -1990,8 +1837,14 @@ where MeD2cWriteMode::Coalesced } else { let header = [first]; - client_writer.write_all(&header).await.map_err(ProxyError::Io)?; - client_writer.write_all(data).await.map_err(ProxyError::Io)?; + client_writer + .write_all(&header) + .await + .map_err(ProxyError::Io)?; + client_writer + .write_all(data) + .await + .map_err(ProxyError::Io)?; MeD2cWriteMode::Split } } else if len_words < (1 << 24) { @@ -2013,8 +1866,14 @@ where MeD2cWriteMode::Coalesced } else { let header = [first, lw[0], lw[1], lw[2]]; - client_writer.write_all(&header).await.map_err(ProxyError::Io)?; - client_writer.write_all(data).await.map_err(ProxyError::Io)?; + client_writer + .write_all(&header) + .await + .map_err(ProxyError::Io)?; + client_writer + .write_all(data) + .await + .map_err(ProxyError::Io)?; MeD2cWriteMode::Split } } else { @@ -2056,8 +1915,14 @@ where MeD2cWriteMode::Coalesced } else { let header = len_val.to_le_bytes(); - client_writer.write_all(&header).await.map_err(ProxyError::Io)?; - client_writer.write_all(data).await.map_err(ProxyError::Io)?; + client_writer + .write_all(&header) + .await + .map_err(ProxyError::Io)?; + client_writer + .write_all(data) + .await + .map_err(ProxyError::Io)?; if padding_len > 0 { frame_buf.clear(); if frame_buf.capacity() < padding_len { @@ -2097,10 +1962,6 @@ where .map_err(ProxyError::Io) } -#[cfg(test)] -#[path = "tests/middle_relay_security_tests.rs"] -mod security_tests; - #[cfg(test)] #[path = "tests/middle_relay_idle_policy_security_tests.rs"] mod idle_policy_security_tests; @@ -2113,30 +1974,10 @@ mod desync_all_full_dedup_security_tests; #[path = "tests/middle_relay_stub_completion_security_tests.rs"] mod stub_completion_security_tests; -#[cfg(test)] -#[path = "tests/middle_relay_coverage_high_risk_security_tests.rs"] -mod coverage_high_risk_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_quota_overflow_lock_security_tests.rs"] -mod quota_overflow_lock_security_tests; - #[cfg(test)] #[path = "tests/middle_relay_length_cast_hardening_security_tests.rs"] mod length_cast_hardening_security_tests; -#[cfg(test)] -#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"] -mod blackhat_campaign_integration_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_hol_quota_security_tests.rs"] -mod hol_quota_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_quota_reservation_adversarial_tests.rs"] -mod quota_reservation_adversarial_tests; - #[cfg(test)] #[path = "tests/middle_relay_idle_registry_poison_security_tests.rs"] mod middle_relay_idle_registry_poison_security_tests; @@ -2158,25 +1999,5 @@ mod middle_relay_tiny_frame_debt_concurrency_security_tests; mod middle_relay_tiny_frame_debt_proto_chunking_security_tests; #[cfg(test)] -#[path = "tests/middle_relay_cross_mode_quota_reservation_security_tests.rs"] -mod middle_relay_cross_mode_quota_reservation_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs"] -mod middle_relay_cross_mode_quota_lock_matrix_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs"] -mod middle_relay_cross_mode_lookup_efficiency_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs"] -mod middle_relay_cross_mode_lock_release_regression_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_quota_extended_attack_surface_security_tests.rs"] -mod middle_relay_quota_extended_attack_surface_security_tests; - -#[cfg(test)] -#[path = "tests/middle_relay_quota_reservation_extreme_security_tests.rs"] -mod middle_relay_quota_reservation_extreme_security_tests; +#[path = "tests/middle_relay_atomic_quota_invariant_tests.rs"] +mod middle_relay_atomic_quota_invariant_tests; diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 519f1b3..5880558 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -4,58 +4,58 @@ #![cfg_attr(test, allow(warnings))] #![cfg_attr(not(test), forbid(clippy::undocumented_unsafe_blocks))] #![cfg_attr( - not(test), - deny( - clippy::unwrap_used, - clippy::expect_used, - clippy::panic, - clippy::todo, - clippy::unimplemented, - clippy::correctness, - clippy::option_if_let_else, - clippy::or_fun_call, - clippy::branches_sharing_code, - clippy::single_option_map, - clippy::useless_let_if_seq, - clippy::redundant_locals, - clippy::cloned_ref_to_slice_refs, - unsafe_code, - clippy::await_holding_lock, - clippy::await_holding_refcell_ref, - clippy::debug_assert_with_mut_call, - clippy::macro_use_imports, - clippy::cast_ptr_alignment, - clippy::cast_lossless, - clippy::ptr_as_ptr, - clippy::large_stack_arrays, - clippy::same_functions_in_if_condition, - trivial_casts, - trivial_numeric_casts, - unused_extern_crates, - unused_import_braces, - rust_2018_idioms - ) + not(test), + deny( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::todo, + clippy::unimplemented, + clippy::correctness, + clippy::option_if_let_else, + clippy::or_fun_call, + clippy::branches_sharing_code, + clippy::single_option_map, + clippy::useless_let_if_seq, + clippy::redundant_locals, + clippy::cloned_ref_to_slice_refs, + unsafe_code, + clippy::await_holding_lock, + clippy::await_holding_refcell_ref, + clippy::debug_assert_with_mut_call, + clippy::macro_use_imports, + clippy::cast_ptr_alignment, + clippy::cast_lossless, + clippy::ptr_as_ptr, + clippy::large_stack_arrays, + clippy::same_functions_in_if_condition, + trivial_casts, + trivial_numeric_casts, + unused_extern_crates, + unused_import_braces, + rust_2018_idioms + ) )] #![cfg_attr( - not(test), - allow( - clippy::use_self, - clippy::redundant_closure, - clippy::too_many_arguments, - clippy::doc_markdown, - clippy::missing_const_for_fn, - clippy::unnecessary_operation, - clippy::redundant_pub_crate, - clippy::derive_partial_eq_without_eq, - clippy::type_complexity, - clippy::new_ret_no_self, - clippy::cast_possible_truncation, - clippy::cast_possible_wrap, - clippy::significant_drop_tightening, - clippy::significant_drop_in_scrutinee, - clippy::float_cmp, - clippy::nursery - ) + not(test), + allow( + clippy::use_self, + clippy::redundant_closure, + clippy::too_many_arguments, + clippy::doc_markdown, + clippy::missing_const_for_fn, + clippy::unnecessary_operation, + clippy::redundant_pub_crate, + clippy::derive_partial_eq_without_eq, + clippy::type_complexity, + clippy::new_ret_no_self, + clippy::cast_possible_truncation, + clippy::cast_possible_wrap, + clippy::significant_drop_tightening, + clippy::significant_drop_in_scrutinee, + clippy::float_cmp, + clippy::nursery + ) )] pub mod adaptive_buffers; @@ -64,7 +64,6 @@ pub mod direct_relay; pub mod handshake; pub mod masking; pub mod middle_relay; -pub mod quota_lock_registry; pub mod relay; pub mod route_mode; pub mod session_eviction; diff --git a/src/proxy/quota_lock_registry.rs b/src/proxy/quota_lock_registry.rs deleted file mode 100644 index 7798b09..0000000 --- a/src/proxy/quota_lock_registry.rs +++ /dev/null @@ -1,88 +0,0 @@ -use dashmap::DashMap; -use std::sync::{Arc, OnceLock}; -use tokio::sync::Mutex; - -#[cfg(test)] -use std::sync::atomic::{AtomicUsize, Ordering}; - -#[cfg(test)] -const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 64; -#[cfg(not(test))] -const CROSS_MODE_QUOTA_USER_LOCKS_MAX: usize = 4_096; -#[cfg(test)] -const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; -#[cfg(not(test))] -const CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; - -static CROSS_MODE_QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); -static CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); - -#[cfg(test)] -static CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS: AtomicUsize = AtomicUsize::new(0); -#[cfg(test)] -static CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER: OnceLock> = OnceLock::new(); - -fn cross_mode_quota_overflow_user_lock(user: &str) -> Arc> { - let stripes = CROSS_MODE_QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { - (0..CROSS_MODE_QUOTA_OVERFLOW_LOCK_STRIPES) - .map(|_| Arc::new(Mutex::new(()))) - .collect() - }); - - let hash = crc32fast::hash(user.as_bytes()) as usize; - Arc::clone(&stripes[hash % stripes.len()]) -} - -pub(crate) fn cross_mode_quota_user_lock(user: &str) -> Arc> { - #[cfg(test)] - { - CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.fetch_add(1, Ordering::Relaxed); - let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new); - let mut entry = lookups.entry(user.to_string()).or_insert(0); - *entry += 1; - } - - let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); - if let Some(existing) = locks.get(user) { - return Arc::clone(existing.value()); - } - - if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX { - locks.retain(|_, value| Arc::strong_count(value) > 1); - } - - if locks.len() >= CROSS_MODE_QUOTA_USER_LOCKS_MAX { - return cross_mode_quota_overflow_user_lock(user); - } - - let created = Arc::new(Mutex::new(())); - match locks.entry(user.to_string()) { - dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), - dashmap::mapref::entry::Entry::Vacant(entry) => { - entry.insert(Arc::clone(&created)); - created - } - } -} - -#[cfg(test)] -pub(crate) fn reset_cross_mode_quota_user_lock_lookup_count_for_tests() { - CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.store(0, Ordering::Relaxed); - let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new); - lookups.clear(); -} - -#[cfg(test)] -pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_tests() -> usize { - CROSS_MODE_QUOTA_USER_LOCK_LOOKUPS.load(Ordering::Relaxed) -} - -#[cfg(test)] -pub(crate) fn cross_mode_quota_user_lock_lookup_count_for_user_for_tests(user: &str) -> usize { - let lookups = CROSS_MODE_QUOTA_USER_LOOKUPS_BY_USER.get_or_init(DashMap::new); - lookups.get(user).map(|entry| *entry).unwrap_or(0) -} - -#[cfg(test)] -#[path = "tests/quota_lock_registry_cross_mode_adversarial_tests.rs"] -mod quota_lock_registry_cross_mode_adversarial_tests; diff --git a/src/proxy/relay.rs b/src/proxy/relay.rs index 90b46d9..6000e18 100644 --- a/src/proxy/relay.rs +++ b/src/proxy/relay.rs @@ -52,18 +52,16 @@ //! - `SharedCounters` (atomics) let the watchdog read stats without locking use crate::error::{ProxyError, Result}; -use crate::stats::Stats; +use crate::stats::{Stats, UserStats}; use crate::stream::BufferPool; -use dashmap::DashMap; use std::io; use std::pin::Pin; +use std::sync::Arc; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::sync::{Arc, Mutex, OnceLock}; use std::task::{Context, Poll}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes}; -use tokio::sync::Mutex as AsyncMutex; -use tokio::time::{Instant, Sleep}; +use tokio::time::Instant; use tracing::{debug, trace, warn}; // ============= Constants ============= @@ -210,16 +208,10 @@ struct StatsIo { counters: Arc, stats: Arc, user: String, - quota_lock: Option>>, - cross_mode_quota_lock: Option>>, + user_stats: Arc, quota_limit: Option, quota_exceeded: Arc, - quota_read_wake_scheduled: bool, - quota_write_wake_scheduled: bool, - quota_read_retry_sleep: Option>>, - quota_write_retry_sleep: Option>>, - quota_read_retry_attempt: u8, - quota_write_retry_attempt: u8, + quota_bytes_since_check: u64, epoch: Instant, } @@ -235,24 +227,16 @@ impl StatsIo { ) -> Self { // Mark initial activity so the watchdog doesn't fire before data flows counters.touch(Instant::now(), epoch); - let quota_lock = quota_limit.map(|_| quota_user_lock(&user)); - let cross_mode_quota_lock = quota_limit - .map(|_| crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user)); + let user_stats = stats.get_or_create_user_stats_handle(&user); Self { inner, counters, stats, user, - quota_lock, - cross_mode_quota_lock, + user_stats, quota_limit, quota_exceeded, - quota_read_wake_scheduled: false, - quota_write_wake_scheduled: false, - quota_read_retry_sleep: None, - quota_write_retry_sleep: None, - quota_read_retry_attempt: 0, - quota_write_retry_attempt: 0, + quota_bytes_since_check: 0, epoch, } } @@ -281,193 +265,22 @@ fn is_quota_io_error(err: &io::Error) -> bool { .is_some() } -#[cfg(test)] -const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1); -#[cfg(not(test))] -const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2); -#[cfg(test)] -const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(16); -#[cfg(not(test))] -const QUOTA_CONTENTION_RETRY_MAX_INTERVAL: Duration = Duration::from_millis(64); +const QUOTA_NEAR_LIMIT_BYTES: u64 = 64 * 1024; +const QUOTA_LARGE_CHARGE_BYTES: u64 = 16 * 1024; +const QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES: u64 = 4 * 1024; +const QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES: u64 = 64 * 1024; -#[cfg(test)] -static QUOTA_RETRY_SLEEP_ALLOCS: AtomicU64 = AtomicU64::new(0); -#[cfg(test)] -static QUOTA_RETRY_SLEEP_ALLOCS_BY_USER: OnceLock> = OnceLock::new(); -#[cfg(test)] -static QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT: AtomicU64 = AtomicU64::new(0); - -#[cfg(test)] -pub(crate) fn reset_quota_retry_sleep_allocs_for_tests() { - QUOTA_RETRY_SLEEP_ALLOCS.store(0, Ordering::Relaxed); -} - -#[cfg(test)] -pub(crate) fn reset_quota_retry_sleep_allocs_for_user_for_tests(user: &str) { - let map = QUOTA_RETRY_SLEEP_ALLOCS_BY_USER.get_or_init(DashMap::new); - map.remove(user); -} - -#[cfg(test)] -pub(crate) fn quota_retry_sleep_allocs_for_tests() -> u64 { - QUOTA_RETRY_SLEEP_ALLOCS.load(Ordering::Relaxed) -} - -#[cfg(test)] -pub(crate) fn quota_retry_sleep_allocs_for_user_for_tests(user: &str) -> u64 { - let map = QUOTA_RETRY_SLEEP_ALLOCS_BY_USER.get_or_init(DashMap::new); - map.get(user).map(|v| *v.value()).unwrap_or(0) +#[inline] +fn quota_adaptive_interval_bytes(remaining_before: u64) -> u64 { + remaining_before.saturating_div(2).clamp( + QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES, + QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES, + ) } #[inline] -fn quota_contention_retry_delay(retry_attempt: u8) -> Duration { - let shift = u32::from(retry_attempt.min(5)); - let multiplier = 1_u32 << shift; - QUOTA_CONTENTION_RETRY_INTERVAL - .saturating_mul(multiplier) - .min(QUOTA_CONTENTION_RETRY_MAX_INTERVAL) -} - -#[inline] -fn reset_quota_retry_scheduler( - sleep_slot: &mut Option>>, - wake_scheduled: &mut bool, - retry_attempt: &mut u8, -) { - *wake_scheduled = false; - *sleep_slot = None; - *retry_attempt = 0; -} - -fn poll_quota_retry_sleep( - sleep_slot: &mut Option>>, - wake_scheduled: &mut bool, - retry_attempt: &mut u8, - user: &str, - cx: &mut Context<'_>, -) { - #[cfg(not(test))] - let _ = user; - - if !*wake_scheduled { - *wake_scheduled = true; - #[cfg(test)] - { - QUOTA_RETRY_SLEEP_ALLOCS.fetch_add(1, Ordering::Relaxed); - let map = QUOTA_RETRY_SLEEP_ALLOCS_BY_USER.get_or_init(DashMap::new); - map.entry(user.to_string()) - .and_modify(|count| *count = count.saturating_add(1)) - .or_insert(1); - } - *sleep_slot = Some(Box::pin(tokio::time::sleep(quota_contention_retry_delay( - *retry_attempt, - )))); - } - - if let Some(sleep) = sleep_slot.as_mut() - && sleep.as_mut().poll(cx).is_ready() - { - *sleep_slot = None; - *wake_scheduled = false; - *retry_attempt = retry_attempt.saturating_add(1); - cx.waker().wake_by_ref(); - } -} - -static QUOTA_USER_LOCKS: OnceLock>>> = OnceLock::new(); -static QUOTA_USER_OVERFLOW_LOCKS: OnceLock>>> = OnceLock::new(); - -#[cfg(test)] -const QUOTA_USER_LOCKS_MAX: usize = 64; -#[cfg(not(test))] -const QUOTA_USER_LOCKS_MAX: usize = 4_096; -#[cfg(test)] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16; -#[cfg(not(test))] -const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256; - -#[cfg(test)] -fn quota_user_lock_test_guard() -> &'static Mutex<()> { - static TEST_LOCK: OnceLock> = OnceLock::new(); - TEST_LOCK.get_or_init(|| Mutex::new(())) -} - -#[cfg(test)] -fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> { - quota_user_lock_test_guard() - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) -} - -fn quota_overflow_user_lock(user: &str) -> Arc> { - let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| { - (0..QUOTA_OVERFLOW_LOCK_STRIPES) - .map(|_| Arc::new(Mutex::new(()))) - .collect() - }); - - let hash = crc32fast::hash(user.as_bytes()) as usize; - Arc::clone(&stripes[hash % stripes.len()]) -} - -pub(crate) fn quota_user_lock_evict() { - if let Some(locks) = QUOTA_USER_LOCKS.get() { - locks.retain(|_, value| Arc::strong_count(value) > 1); - } -} - -pub(crate) fn spawn_quota_user_lock_evictor(interval: Duration) -> tokio::task::JoinHandle<()> { - let interval = interval.max(Duration::from_millis(1)); - #[cfg(test)] - QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.fetch_add(1, Ordering::Relaxed); - tokio::spawn(async move { - loop { - tokio::time::sleep(interval).await; - quota_user_lock_evict(); - } - }) -} - -#[cfg(test)] -pub(crate) fn spawn_quota_user_lock_evictor_for_tests( - interval: Duration, -) -> tokio::task::JoinHandle<()> { - spawn_quota_user_lock_evictor(interval) -} - -#[cfg(test)] -pub(crate) fn reset_quota_user_lock_evictor_spawn_count_for_tests() { - QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.store(0, Ordering::Relaxed); -} - -#[cfg(test)] -pub(crate) fn quota_user_lock_evictor_spawn_count_for_tests() -> u64 { - QUOTA_USER_LOCK_EVICTOR_SPAWN_COUNT.load(Ordering::Relaxed) -} - -fn quota_user_lock(user: &str) -> Arc> { - let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - if let Some(existing) = locks.get(user) { - return Arc::clone(existing.value()); - } - - if locks.len() >= QUOTA_USER_LOCKS_MAX { - return quota_overflow_user_lock(user); - } - - let created = Arc::new(Mutex::new(())); - match locks.entry(user.to_string()) { - dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()), - dashmap::mapref::entry::Entry::Vacant(entry) => { - entry.insert(Arc::clone(&created)); - created - } - } -} - -#[cfg(test)] -pub(crate) fn cross_mode_quota_user_lock_for_tests(user: &str) -> Arc> { - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(user) +fn should_immediate_quota_check(remaining_before: u64, charge_bytes: u64) -> bool { + remaining_before <= QUOTA_NEAR_LIMIT_BYTES || charge_bytes >= QUOTA_LARGE_CHARGE_BYTES } impl AsyncRead for StatsIo { @@ -477,95 +290,60 @@ impl AsyncRead for StatsIo { buf: &mut ReadBuf<'_>, ) -> Poll> { let this = self.get_mut(); - if this.quota_exceeded.load(Ordering::Relaxed) { + if this.quota_exceeded.load(Ordering::Acquire) { return Poll::Ready(Err(quota_io_error())); } - let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() { - match lock.try_lock() { - Ok(guard) => Some(guard), - Err(_) => { - poll_quota_retry_sleep( - &mut this.quota_read_retry_sleep, - &mut this.quota_read_wake_scheduled, - &mut this.quota_read_retry_attempt, - &this.user, - cx, - ); - return Poll::Pending; - } + let mut remaining_before = None; + if let Some(limit) = this.quota_limit { + let used_before = this.user_stats.quota_used(); + let remaining = limit.saturating_sub(used_before); + if remaining == 0 { + this.quota_exceeded.store(true, Ordering::Release); + return Poll::Ready(Err(quota_io_error())); } - } else { - None - }; - - let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() { - match lock.try_lock() { - Ok(guard) => Some(guard), - Err(_) => { - poll_quota_retry_sleep( - &mut this.quota_read_retry_sleep, - &mut this.quota_read_wake_scheduled, - &mut this.quota_read_retry_attempt, - &this.user, - cx, - ); - return Poll::Pending; - } - } - } else { - None - }; - - reset_quota_retry_scheduler( - &mut this.quota_read_retry_sleep, - &mut this.quota_read_wake_scheduled, - &mut this.quota_read_retry_attempt, - ); - - if let Some(limit) = this.quota_limit - && this.stats.get_user_total_octets(&this.user) >= limit - { - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); + remaining_before = Some(remaining); } + let before = buf.filled().len(); match Pin::new(&mut this.inner).poll_read(cx, buf) { Poll::Ready(Ok(())) => { let n = buf.filled().len() - before; if n > 0 { - let mut reached_quota_boundary = false; - if let Some(limit) = this.quota_limit { - let used = this.stats.get_user_total_octets(&this.user); - if used >= limit { - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); - } - - let remaining = limit - used; - if (n as u64) > remaining { - // Fail closed: when a single read chunk would cross quota, - // stop relay immediately without accounting beyond the cap. - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); - } - - reached_quota_boundary = (n as u64) == remaining; - } + let n_to_charge = n as u64; // C→S: client sent data this.counters .c2s_bytes - .fetch_add(n as u64, Ordering::Relaxed); + .fetch_add(n_to_charge, Ordering::Relaxed); this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed); this.counters.touch(Instant::now(), this.epoch); - this.stats.add_user_octets_from(&this.user, n as u64); - this.stats.increment_user_msgs_from(&this.user); + this.stats + .add_user_octets_from_handle(this.user_stats.as_ref(), n_to_charge); + this.stats + .increment_user_msgs_from_handle(this.user_stats.as_ref()); - if reached_quota_boundary { - this.quota_exceeded.store(true, Ordering::Relaxed); + if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) { + this.stats + .quota_charge_post_write(this.user_stats.as_ref(), n_to_charge); + if should_immediate_quota_check(remaining, n_to_charge) { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } else { + this.quota_bytes_since_check = + this.quota_bytes_since_check.saturating_add(n_to_charge); + let interval = quota_adaptive_interval_bytes(remaining); + if this.quota_bytes_since_check >= interval { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } + } } trace!(user = %this.user, bytes = n, "C->S"); @@ -584,89 +362,57 @@ impl AsyncWrite for StatsIo { buf: &[u8], ) -> Poll> { let this = self.get_mut(); - if this.quota_exceeded.load(Ordering::Relaxed) { + if this.quota_exceeded.load(Ordering::Acquire) { return Poll::Ready(Err(quota_io_error())); } - let _quota_guard = if let Some(lock) = this.quota_lock.as_ref() { - match lock.try_lock() { - Ok(guard) => Some(guard), - Err(_) => { - poll_quota_retry_sleep( - &mut this.quota_write_retry_sleep, - &mut this.quota_write_wake_scheduled, - &mut this.quota_write_retry_attempt, - &this.user, - cx, - ); - return Poll::Pending; - } - } - } else { - None - }; - - let _cross_mode_quota_guard = if let Some(lock) = this.cross_mode_quota_lock.as_ref() { - match lock.try_lock() { - Ok(guard) => Some(guard), - Err(_) => { - poll_quota_retry_sleep( - &mut this.quota_write_retry_sleep, - &mut this.quota_write_wake_scheduled, - &mut this.quota_write_retry_attempt, - &this.user, - cx, - ); - return Poll::Pending; - } - } - } else { - None - }; - - reset_quota_retry_scheduler( - &mut this.quota_write_retry_sleep, - &mut this.quota_write_wake_scheduled, - &mut this.quota_write_retry_attempt, - ); - - let write_buf = if let Some(limit) = this.quota_limit { - let used = this.stats.get_user_total_octets(&this.user); - if used >= limit { - this.quota_exceeded.store(true, Ordering::Relaxed); + let mut remaining_before = None; + if let Some(limit) = this.quota_limit { + let used_before = this.user_stats.quota_used(); + let remaining = limit.saturating_sub(used_before); + if remaining == 0 { + this.quota_exceeded.store(true, Ordering::Release); return Poll::Ready(Err(quota_io_error())); } + remaining_before = Some(remaining); + } - let remaining = (limit - used) as usize; - if buf.len() > remaining { - // Fail closed: do not emit partial S->C payload when remaining - // quota cannot accommodate the pending write request. - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); - } - buf - } else { - buf - }; - - match Pin::new(&mut this.inner).poll_write(cx, write_buf) { + match Pin::new(&mut this.inner).poll_write(cx, buf) { Poll::Ready(Ok(n)) => { if n > 0 { + let n_to_charge = n as u64; + // S→C: data written to client this.counters .s2c_bytes - .fetch_add(n as u64, Ordering::Relaxed); + .fetch_add(n_to_charge, Ordering::Relaxed); this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed); this.counters.touch(Instant::now(), this.epoch); - this.stats.add_user_octets_to(&this.user, n as u64); - this.stats.increment_user_msgs_to(&this.user); + this.stats + .add_user_octets_to_handle(this.user_stats.as_ref(), n_to_charge); + this.stats + .increment_user_msgs_to_handle(this.user_stats.as_ref()); - if let Some(limit) = this.quota_limit - && this.stats.get_user_total_octets(&this.user) >= limit - { - this.quota_exceeded.store(true, Ordering::Relaxed); - return Poll::Ready(Err(quota_io_error())); + if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) { + this.stats + .quota_charge_post_write(this.user_stats.as_ref(), n_to_charge); + if should_immediate_quota_check(remaining, n_to_charge) { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } else { + this.quota_bytes_since_check = + this.quota_bytes_since_check.saturating_add(n_to_charge); + let interval = quota_adaptive_interval_bytes(remaining); + if this.quota_bytes_since_check >= interval { + this.quota_bytes_since_check = 0; + if this.user_stats.quota_used() >= limit { + this.quota_exceeded.store(true, Ordering::Release); + } + } + } } trace!(user = %this.user, bytes = n, "S->C"); @@ -760,7 +506,7 @@ where let now = Instant::now(); let idle = wd_counters.idle_duration(now, epoch); - if wd_quota_exceeded.load(Ordering::Relaxed) { + if wd_quota_exceeded.load(Ordering::Acquire) { warn!(user = %wd_user, "User data quota reached, closing relay"); return; } @@ -898,18 +644,10 @@ where } } -#[cfg(test)] -#[path = "tests/relay_security_tests.rs"] -mod security_tests; - #[cfg(test)] #[path = "tests/relay_adversarial_tests.rs"] mod adversarial_tests; -#[cfg(test)] -#[path = "tests/relay_quota_lock_pressure_adversarial_tests.rs"] -mod relay_quota_lock_pressure_adversarial_tests; - #[cfg(test)] #[path = "tests/relay_quota_boundary_blackhat_tests.rs"] mod relay_quota_boundary_blackhat_tests; @@ -931,69 +669,5 @@ mod relay_quota_extended_attack_surface_security_tests; mod relay_watchdog_delta_security_tests; #[cfg(test)] -#[path = "tests/relay_quota_waker_storm_adversarial_tests.rs"] -mod relay_quota_waker_storm_adversarial_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"] -mod relay_quota_wake_liveness_regression_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_lock_identity_security_tests.rs"] -mod relay_quota_lock_identity_security_tests; - -#[cfg(test)] -#[path = "tests/relay_cross_mode_quota_lock_security_tests.rs"] -mod relay_cross_mode_quota_lock_security_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_retry_scheduler_tdd_tests.rs"] -mod relay_quota_retry_scheduler_tdd_tests; - -#[cfg(test)] -#[path = "tests/relay_cross_mode_quota_fairness_tdd_tests.rs"] -mod relay_cross_mode_quota_fairness_tdd_tests; - -#[cfg(test)] -#[path = "tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs"] -mod relay_cross_mode_pipeline_hol_integration_security_tests; - -#[cfg(test)] -#[path = "tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs"] -mod relay_cross_mode_pipeline_latency_benchmark_security_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_retry_backoff_security_tests.rs"] -mod relay_quota_retry_backoff_security_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_retry_backoff_benchmark_security_tests.rs"] -mod relay_quota_retry_backoff_benchmark_security_tests; - -#[cfg(test)] -#[path = "tests/relay_dual_lock_backoff_regression_security_tests.rs"] -mod relay_dual_lock_backoff_regression_security_tests; - -#[cfg(test)] -#[path = "tests/relay_dual_lock_contention_matrix_security_tests.rs"] -mod relay_dual_lock_contention_matrix_security_tests; - -#[cfg(test)] -#[path = "tests/relay_dual_lock_race_harness_security_tests.rs"] -mod relay_dual_lock_race_harness_security_tests; - -#[cfg(test)] -#[path = "tests/relay_dual_lock_alternating_contention_security_tests.rs"] -mod relay_dual_lock_alternating_contention_security_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_retry_allocation_latency_security_tests.rs"] -mod relay_quota_retry_allocation_latency_security_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs"] -mod relay_quota_lock_eviction_lifecycle_tdd_tests; - -#[cfg(test)] -#[path = "tests/relay_quota_lock_eviction_stress_security_tests.rs"] -mod relay_quota_lock_eviction_stress_security_tests; +#[path = "tests/relay_atomic_quota_invariant_tests.rs"] +mod relay_atomic_quota_invariant_tests; diff --git a/src/proxy/tests/client_clever_advanced_tests.rs b/src/proxy/tests/client_clever_advanced_tests.rs index da2e703..f462ed8 100644 --- a/src/proxy/tests/client_clever_advanced_tests.rs +++ b/src/proxy/tests/client_clever_advanced_tests.rs @@ -1,5 +1,5 @@ use super::*; -use crate::config::{UpstreamConfig, UpstreamType, ProxyConfig}; +use crate::config::{ProxyConfig, UpstreamConfig, UpstreamType}; use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE}; use crate::stats::Stats; use crate::transport::UpstreamManager; @@ -41,7 +41,9 @@ fn edge_handshake_timeout_with_mask_grace_saturating_add_prevents_overflow() { #[test] fn edge_tls_clienthello_len_in_bounds_exact_boundaries() { assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE)); - assert!(!tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE - 1)); + assert!(!tls_clienthello_len_in_bounds( + MIN_TLS_CLIENT_HELLO_SIZE - 1 + )); assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE)); assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1)); } @@ -87,7 +89,15 @@ async fn adversarial_tls_handshake_timeout_during_masking_delay() { "198.51.100.1:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -99,7 +109,10 @@ async fn adversarial_tls_handshake_timeout_during_masking_delay() { false, )); - client_side.write_all(&[0x16, 0x03, 0x01, 0xFF, 0xFF]).await.unwrap(); + client_side + .write_all(&[0x16, 0x03, 0x01, 0xFF, 0xFF]) + .await + .unwrap(); let result = tokio::time::timeout(Duration::from_secs(4), handle) .await @@ -123,7 +136,15 @@ async fn blackhat_proxy_protocol_slowloris_timeout() { "198.51.100.2:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -167,7 +188,15 @@ async fn negative_proxy_protocol_enabled_but_client_sends_tls_hello() { "198.51.100.3:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -179,7 +208,10 @@ async fn negative_proxy_protocol_enabled_but_client_sends_tls_hello() { true, )); - client_side.write_all(&[0x16, 0x03, 0x01, 0x02, 0x00]).await.unwrap(); + client_side + .write_all(&[0x16, 0x03, 0x01, 0x02, 0x00]) + .await + .unwrap(); let result = tokio::time::timeout(Duration::from_secs(2), handle) .await @@ -202,7 +234,15 @@ async fn edge_client_stream_exactly_4_bytes_eof() { "198.51.100.4:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -214,7 +254,10 @@ async fn edge_client_stream_exactly_4_bytes_eof() { false, )); - client_side.write_all(&[0x16, 0x03, 0x01, 0x00]).await.unwrap(); + client_side + .write_all(&[0x16, 0x03, 0x01, 0x00]) + .await + .unwrap(); client_side.shutdown().await.unwrap(); let _ = tokio::time::timeout(Duration::from_secs(2), handle).await; @@ -234,7 +277,15 @@ async fn edge_client_stream_tls_header_valid_but_body_1_byte_short_eof() { "198.51.100.5:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -246,7 +297,10 @@ async fn edge_client_stream_tls_header_valid_but_body_1_byte_short_eof() { false, )); - client_side.write_all(&[0x16, 0x03, 0x01, 0x00, 100]).await.unwrap(); + client_side + .write_all(&[0x16, 0x03, 0x01, 0x00, 100]) + .await + .unwrap(); client_side.write_all(&vec![0x41; 99]).await.unwrap(); client_side.shutdown().await.unwrap(); @@ -269,7 +323,15 @@ async fn integration_non_tls_modes_disabled_immediately_masks() { "198.51.100.6:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -372,11 +434,7 @@ async fn stress_user_connection_reservation_concurrent_same_ip_exhaustion() { let ip_tracker = ip_tracker.clone(); tasks.spawn(async move { RunningClientHandler::acquire_user_connection_reservation_static( - user, - &config, - stats, - peer, - ip_tracker, + user, &config, stats, peer, ip_tracker, ) .await }); diff --git a/src/proxy/tests/client_deep_invariants_tests.rs b/src/proxy/tests/client_deep_invariants_tests.rs index 97c55c6..e57f817 100644 --- a/src/proxy/tests/client_deep_invariants_tests.rs +++ b/src/proxy/tests/client_deep_invariants_tests.rs @@ -7,6 +7,11 @@ use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncWriteExt, duplex}; +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[test] fn invariant_wrap_tls_application_record_exact_multiples() { let chunk_size = u16::MAX as usize; @@ -37,7 +42,15 @@ async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking() "198.51.100.20:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -60,7 +73,9 @@ async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking() .unwrap(); client_side.shutdown().await.unwrap(); - let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap(); assert_eq!(stats.get_connects_bad(), 1); } @@ -68,7 +83,10 @@ async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking() async fn invariant_acquire_reservation_ip_limit_rollback() { let user = "rollback-test-user"; let mut config = ProxyConfig::default(); - config.access.user_max_tcp_conns.insert(user.to_string(), 10); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 10); let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); @@ -114,7 +132,7 @@ async fn invariant_quota_exact_boundary_inclusive() { let ip_tracker = Arc::new(UserIpTracker::new()); let peer = "198.51.100.23:55000".parse().unwrap(); - stats.add_user_octets_from(user, 999); + preload_user_quota(stats.as_ref(), user, 999); let res1 = RunningClientHandler::acquire_user_connection_reservation_static( user, &config, @@ -126,7 +144,7 @@ async fn invariant_quota_exact_boundary_inclusive() { assert!(res1.is_ok()); res1.unwrap().release().await; - stats.add_user_octets_from(user, 1); + preload_user_quota(stats.as_ref(), user, 1); let res2 = RunningClientHandler::acquire_user_connection_reservation_static( user, &config, @@ -154,7 +172,15 @@ async fn invariant_direct_mode_partial_header_eof_is_error_not_bad_connect() { "198.51.100.25:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), diff --git a/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs b/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs index fcf51ab..3036f95 100644 --- a/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs +++ b/src/proxy/tests/client_masking_http2_fragmented_preface_security_tests.rs @@ -100,14 +100,7 @@ async fn run_http2_fragment_case(split_at: usize, delay_ms: u64, peer: SocketAdd #[tokio::test] async fn http2_preface_fragmentation_matrix_is_classified_and_forwarded() { - let cases = [ - (2usize, 0u64), - (3, 0), - (4, 0), - (2, 7), - (3, 7), - (8, 1), - ]; + let cases = [(2usize, 0u64), (3, 0), (4, 0), (2, 7), (3, 7), (8, 1)]; for (i, (split_at, delay_ms)) in cases.into_iter().enumerate() { let peer: SocketAddr = format!("198.51.100.{}:58{}", 140 + i, 100 + i) diff --git a/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs b/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs index cdf2136..64e7a85 100644 --- a/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs +++ b/src/proxy/tests/client_masking_prefetch_config_runtime_security_tests.rs @@ -29,7 +29,10 @@ async fn configured_prefetch_budget_20ms_recovers_tail_delayed_15ms() { .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n") .await .expect("tail bytes must be writable"); - writer.shutdown().await.expect("writer shutdown must succeed"); + writer + .shutdown() + .await + .expect("writer shutdown must succeed"); }); let mut initial_data = b"C".to_vec(); @@ -60,7 +63,10 @@ async fn configured_prefetch_budget_5ms_misses_tail_delayed_15ms() { .write_all(b"ONNECT example.org:443 HTTP/1.1\r\n") .await .expect("tail bytes must be writable"); - writer.shutdown().await.expect("writer shutdown must succeed"); + writer + .shutdown() + .await + .expect("writer shutdown must succeed"); }); let mut initial_data = b"C".to_vec(); diff --git a/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs index 2e03ce9..b49db3c 100644 --- a/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs +++ b/src/proxy/tests/client_masking_prefetch_invariant_security_tests.rs @@ -245,7 +245,10 @@ async fn blackhat_integration_empty_initial_data_path_is_byte_exact_and_eof_clea assert_eq!(head[0], 0x16); read_and_discard_tls_record_body(&mut client_side, head).await; - client_side.write_all(&invalid_mtproto_record).await.unwrap(); + client_side + .write_all(&invalid_mtproto_record) + .await + .unwrap(); client_side.write_all(&trailing_record).await.unwrap(); client_side.shutdown().await.unwrap(); diff --git a/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs b/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs index 9ece258..cbb6603 100644 --- a/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs +++ b/src/proxy/tests/client_masking_prefetch_strict_boundary_security_tests.rs @@ -7,7 +7,9 @@ async fn run_strict_prefetch_case(prefetch_ms: u64, tail_delay_ms: u64) -> Vec= Duration::from_millis(40) - && replay_elapsed < Duration::from_millis(250), + replay_elapsed >= Duration::from_millis(40) && replay_elapsed < Duration::from_millis(250), "replay rejection path must still satisfy masking timing budget without unbounded DB/CPU delay" ); } diff --git a/src/proxy/tests/client_more_advanced_tests.rs b/src/proxy/tests/client_more_advanced_tests.rs index 021848a..8f9d832 100644 --- a/src/proxy/tests/client_more_advanced_tests.rs +++ b/src/proxy/tests/client_more_advanced_tests.rs @@ -6,6 +6,11 @@ use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn edge_mask_delay_bypassed_if_max_is_zero() { let mut config = ProxyConfig::default(); @@ -42,17 +47,13 @@ async fn boundary_user_data_quota_exact_match_rejects() { config.access.user_data_quota.insert(user.to_string(), 1024); let stats = Arc::new(Stats::new()); - stats.add_user_octets_from(user, 1024); + preload_user_quota(stats.as_ref(), user, 1024); let ip_tracker = Arc::new(UserIpTracker::new()); let peer = "198.51.100.10:55000".parse().unwrap(); let result = RunningClientHandler::acquire_user_connection_reservation_static( - user, - &config, - stats, - peer, - ip_tracker, + user, &config, stats, peer, ip_tracker, ) .await; @@ -74,11 +75,7 @@ async fn boundary_user_expiration_in_past_rejects() { let peer = "198.51.100.11:55000".parse().unwrap(); let result = RunningClientHandler::acquire_user_connection_reservation_static( - user, - &config, - stats, - peer, - ip_tracker, + user, &config, stats, peer, ip_tracker, ) .await; @@ -98,7 +95,15 @@ async fn blackhat_proxy_protocol_massive_garbage_rejected_quickly() { "198.51.100.12:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -136,7 +141,15 @@ async fn edge_tls_body_immediate_eof_triggers_masking_and_bad_connect() { "198.51.100.13:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -148,10 +161,15 @@ async fn edge_tls_body_immediate_eof_triggers_masking_and_bad_connect() { false, )); - client_side.write_all(&[0x16, 0x03, 0x01, 0x00, 100]).await.unwrap(); + client_side + .write_all(&[0x16, 0x03, 0x01, 0x00, 100]) + .await + .unwrap(); client_side.shutdown().await.unwrap(); - let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap(); assert_eq!(stats.get_connects_bad(), 1); } @@ -172,7 +190,15 @@ async fn security_classic_mode_disabled_masks_valid_length_payload() { "198.51.100.15:55000".parse().unwrap(), config, stats.clone(), - Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())), + Arc::new(UpstreamManager::new( + vec![], + 1, + 1, + 1, + 1, + false, + stats.clone(), + )), Arc::new(ReplayChecker::new(128, Duration::from_secs(60))), Arc::new(BufferPool::new()), Arc::new(SecureRandom::new()), @@ -187,7 +213,9 @@ async fn security_classic_mode_disabled_masks_valid_length_payload() { client_side.write_all(&vec![0xEF; 64]).await.unwrap(); client_side.shutdown().await.unwrap(); - let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap(); + let _ = tokio::time::timeout(Duration::from_secs(2), handler) + .await + .unwrap(); assert_eq!(stats.get_connects_bad(), 1); } @@ -195,7 +223,10 @@ async fn security_classic_mode_disabled_masks_valid_length_payload() { async fn concurrency_ip_tracker_strict_limit_one_rapid_churn() { let user = "rapid-churn-user"; let mut config = ProxyConfig::default(); - config.access.user_max_tcp_conns.insert(user.to_string(), 10); + config + .access + .user_max_tcp_conns + .insert(user.to_string(), 10); let stats = Arc::new(Stats::new()); let ip_tracker = Arc::new(UserIpTracker::new()); diff --git a/src/proxy/tests/client_security_tests.rs b/src/proxy/tests/client_security_tests.rs index 2b1fae6..1b46c6d 100644 --- a/src/proxy/tests/client_security_tests.rs +++ b/src/proxy/tests/client_security_tests.rs @@ -7,9 +7,9 @@ use crate::protocol::tls; use crate::proxy::handshake::HandshakeSuccess; use crate::stream::{CryptoReader, CryptoWriter}; use crate::transport::proxy_protocol::ProxyProtocolV1Builder; -use rand::rngs::StdRng; use rand::Rng; use rand::SeedableRng; +use rand::rngs::StdRng; use std::net::Ipv4Addr; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::{TcpListener, TcpStream}; @@ -34,7 +34,10 @@ fn handshake_timeout_with_mask_grace_includes_mask_margin() { config.timeouts.client_handshake = 2; config.censorship.mask = false; - assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_secs(2)); + assert_eq!( + handshake_timeout_with_mask_grace(&config), + Duration::from_secs(2) + ); config.censorship.mask = true; assert_eq!( @@ -86,7 +89,10 @@ impl tokio::io::AsyncRead for ErrorReader { _cx: &mut std::task::Context<'_>, _buf: &mut tokio::io::ReadBuf<'_>, ) -> std::task::Poll> { - std::task::Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "fake error"))) + std::task::Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "fake error", + ))) } } @@ -124,7 +130,10 @@ fn handshake_timeout_without_mask_is_exact_base() { config.timeouts.client_handshake = 7; config.censorship.mask = false; - assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_secs(7)); + assert_eq!( + handshake_timeout_with_mask_grace(&config), + Duration::from_secs(7) + ); } #[test] @@ -133,7 +142,10 @@ fn handshake_timeout_mask_enabled_adds_750ms() { config.timeouts.client_handshake = 3; config.censorship.mask = true; - assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_millis(3750)); + assert_eq!( + handshake_timeout_with_mask_grace(&config), + Duration::from_millis(3750) + ); } #[tokio::test] @@ -155,10 +167,12 @@ async fn read_with_progress_fragmented_io_works_over_multiple_calls() { let mut b = vec![0u8; chunk_size]; let n = read_with_progress(&mut cursor, &mut b).await.unwrap(); result.extend_from_slice(&b[..n]); - if n == 0 { break; } + if n == 0 { + break; + } } - assert_eq!(result, vec![1,2,3,4,5]); + assert_eq!(result, vec![1, 2, 3, 4, 5]); } #[tokio::test] @@ -174,7 +188,9 @@ async fn read_with_progress_stress_randomized_chunk_sizes() { let mut b = vec![0u8; chunk]; let read = read_with_progress(&mut cursor, &mut b).await.unwrap(); collected.extend_from_slice(&b[..read]); - if read == 0 { break; } + if read == 0 { + break; + } } assert_eq!(collected, input); @@ -215,10 +231,12 @@ fn wrap_tls_application_record_roundtrip_size_check() { let mut consumed = 0; while idx + 5 <= wrapped.len() { assert_eq!(wrapped[idx], 0x17); - let len = u16::from_be_bytes([wrapped[idx+3], wrapped[idx+4]]) as usize; + let len = u16::from_be_bytes([wrapped[idx + 3], wrapped[idx + 4]]) as usize; consumed += len; idx += 5 + len; - if idx >= wrapped.len() { break; } + if idx >= wrapped.len() { + break; + } } assert_eq!(consumed, payload_len); @@ -242,6 +260,11 @@ where CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) } +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() { let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new()); @@ -3040,7 +3063,7 @@ async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() { .insert("user".to_string(), 1024); let stats = Stats::new(); - stats.add_user_octets_from("user", 1024); + preload_user_quota(&stats, "user", 1024); let ip_tracker = UserIpTracker::new(); let peer_addr: SocketAddr = "203.0.113.211:50001".parse().unwrap(); diff --git a/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs b/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs index 08f52d1..7964cdd 100644 --- a/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs +++ b/src/proxy/tests/client_tls_record_wrap_hardening_security_tests.rs @@ -25,13 +25,26 @@ fn wrap_tls_application_record_oversized_payload_is_chunked_without_truncation() let len = u16::from_be_bytes([record[offset + 3], record[offset + 4]]) as usize; let body_start = offset + 5; let body_end = body_start + len; - assert!(body_end <= record.len(), "declared TLS record length must be in-bounds"); + assert!( + body_end <= record.len(), + "declared TLS record length must be in-bounds" + ); recovered.extend_from_slice(&record[body_start..body_end]); offset = body_end; frames += 1; } - assert_eq!(offset, record.len(), "record parser must consume exact output size"); - assert_eq!(frames, 2, "oversized payload should split into exactly two records"); - assert_eq!(recovered, payload, "chunked records must preserve full payload"); + assert_eq!( + offset, + record.len(), + "record parser must consume exact output size" + ); + assert_eq!( + frames, 2, + "oversized payload should split into exactly two records" + ); + assert_eq!( + recovered, payload, + "chunked records must preserve full payload" + ); } diff --git a/src/proxy/tests/direct_relay_security_tests.rs b/src/proxy/tests/direct_relay_security_tests.rs index 16fe8da..a731830 100644 --- a/src/proxy/tests/direct_relay_security_tests.rs +++ b/src/proxy/tests/direct_relay_security_tests.rs @@ -773,8 +773,7 @@ fn anchored_open_nix_path_writes_expected_lines() { "target/telemt-unknown-dc-anchored-open-ok-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let _ = fs::remove_file(&sanitized.resolved_path); let mut first = open_unknown_dc_log_append_anchored(&sanitized) @@ -787,7 +786,10 @@ fn anchored_open_nix_path_writes_expected_lines() { let content = fs::read_to_string(&sanitized.resolved_path).expect("anchored log file must be readable"); - let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); + let lines: Vec<&str> = content + .lines() + .filter(|line| !line.trim().is_empty()) + .collect(); assert_eq!(lines.len(), 2, "expected one line per anchored append call"); assert!( lines.contains(&"dc_idx=31200") && lines.contains(&"dc_idx=31201"), @@ -811,8 +813,7 @@ fn anchored_open_parallel_appends_preserve_line_integrity() { "target/telemt-unknown-dc-anchored-open-parallel-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let _ = fs::remove_file(&sanitized.resolved_path); let mut workers = Vec::new(); @@ -831,8 +832,15 @@ fn anchored_open_parallel_appends_preserve_line_integrity() { let content = fs::read_to_string(&sanitized.resolved_path).expect("parallel log file must be readable"); - let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); - assert_eq!(lines.len(), 64, "expected one complete line per worker append"); + let lines: Vec<&str> = content + .lines() + .filter(|line| !line.trim().is_empty()) + .collect(); + assert_eq!( + lines.len(), + 64, + "expected one complete line per worker append" + ); for line in lines { assert!( line.starts_with("dc_idx="), @@ -867,8 +875,7 @@ fn anchored_open_creates_private_0600_file_permissions() { "target/telemt-unknown-dc-anchored-perms-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let _ = fs::remove_file(&sanitized.resolved_path); let mut file = open_unknown_dc_log_append_anchored(&sanitized) @@ -905,8 +912,7 @@ fn anchored_open_rejects_existing_symlink_target() { "target/telemt-unknown-dc-anchored-symlink-target-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let outside = std::env::temp_dir().join(format!( "telemt-unknown-dc-anchored-symlink-outside-{}.log", @@ -943,8 +949,7 @@ fn anchored_open_high_contention_multi_write_preserves_complete_lines() { "target/telemt-unknown-dc-anchored-contention-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); let _ = fs::remove_file(&sanitized.resolved_path); let workers = 24usize; @@ -970,7 +975,10 @@ fn anchored_open_high_contention_multi_write_preserves_complete_lines() { let content = fs::read_to_string(&sanitized.resolved_path) .expect("contention output file must be readable"); - let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect(); + let lines: Vec<&str> = content + .lines() + .filter(|line| !line.trim().is_empty()) + .collect(); assert_eq!( lines.len(), workers * rounds, @@ -1014,8 +1022,7 @@ fn append_unknown_dc_line_returns_error_for_read_only_descriptor() { "target/telemt-unknown-dc-append-ro-{}/unknown-dc.log", std::process::id() ); - let sanitized = - sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); + let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize"); fs::write(&sanitized.resolved_path, "seed\n").expect("seed file must be writable"); let mut readonly = std::fs::OpenOptions::new() diff --git a/src/proxy/tests/handshake_advanced_clever_tests.rs b/src/proxy/tests/handshake_advanced_clever_tests.rs index 9b12f21..76347c4 100644 --- a/src/proxy/tests/handshake_advanced_clever_tests.rs +++ b/src/proxy/tests/handshake_advanced_clever_tests.rs @@ -1,5 +1,5 @@ use super::*; -use crate::crypto::{sha256, sha256_hmac, AesCtr}; +use crate::crypto::{AesCtr, sha256, sha256_hmac}; use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::Arc; @@ -175,7 +175,10 @@ async fn tls_minimum_viable_length_boundary() { None, ) .await; - assert!(matches!(res, HandshakeResult::Success(_)), "Exact minimum length TLS handshake must succeed"); + assert!( + matches!(res, HandshakeResult::Success(_)), + "Exact minimum length TLS handshake must succeed" + ); let short_handshake = vec![0x42u8; min_len - 1]; let res_short = handle_tls_handshake( @@ -189,7 +192,10 @@ async fn tls_minimum_viable_length_boundary() { None, ) .await; - assert!(matches!(res_short, HandshakeResult::BadClient { .. }), "Handshake 1 byte shorter than minimum must fail closed"); + assert!( + matches!(res_short, HandshakeResult::BadClient { .. }), + "Handshake 1 byte shorter than minimum must fail closed" + ); } #[tokio::test] @@ -219,9 +225,16 @@ async fn mtproto_extreme_dc_index_serialization() { match res { HandshakeResult::Success((_, _, success)) => { - assert_eq!(success.dc_idx, extreme_dc, "Extreme DC index {} must serialize/deserialize perfectly", extreme_dc); + assert_eq!( + success.dc_idx, extreme_dc, + "Extreme DC index {} must serialize/deserialize perfectly", + extreme_dc + ); } - _ => panic!("MTProto handshake with extreme DC index {} failed", extreme_dc), + _ => panic!( + "MTProto handshake with extreme DC index {} failed", + extreme_dc + ), } } } @@ -253,7 +266,11 @@ async fn alpn_strict_case_and_padding_rejection() { None, ) .await; - assert!(matches!(res, HandshakeResult::BadClient { .. }), "ALPN strict enforcement must reject {:?}", bad_alpn); + assert!( + matches!(res, HandshakeResult::BadClient { .. }), + "ALPN strict enforcement must reject {:?}", + bad_alpn + ); } } @@ -265,8 +282,15 @@ fn ipv4_mapped_ipv6_bucketing_anomaly() { let norm_1 = normalize_auth_probe_ip(ipv4_mapped_1); let norm_2 = normalize_auth_probe_ip(ipv4_mapped_2); - assert_eq!(norm_1, norm_2, "IPv4-mapped IPv6 addresses must collapse into the same /64 bucket (::0)"); - assert_eq!(norm_1, IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), "The bucket must be exactly ::0"); + assert_eq!( + norm_1, norm_2, + "IPv4-mapped IPv6 addresses must collapse into the same /64 bucket (::0)" + ); + assert_eq!( + norm_1, + IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), + "The bucket must be exactly ::0" + ); } // --- Category 2: Adversarial & Black Hat --- @@ -309,7 +333,10 @@ async fn mtproto_invalid_ciphertext_does_not_poison_replay_cache() { None, ) .await; - assert!(matches!(res_valid, HandshakeResult::Success(_)), "Invalid MTProto ciphertext must not poison the replay cache"); + assert!( + matches!(res_valid, HandshakeResult::Success(_)), + "Invalid MTProto ciphertext must not poison the replay cache" + ); } #[tokio::test] @@ -352,7 +379,10 @@ async fn tls_invalid_session_does_not_poison_replay_cache() { None, ) .await; - assert!(matches!(res_valid, HandshakeResult::Success(_)), "Invalid TLS payload must not poison the replay cache"); + assert!( + matches!(res_valid, HandshakeResult::Success(_)), + "Invalid TLS payload must not poison the replay cache" + ); } #[tokio::test] @@ -387,7 +417,10 @@ async fn server_hello_delay_timing_neutrality_on_hmac_failure() { let elapsed = start.elapsed(); assert!(matches!(res, HandshakeResult::BadClient { .. })); - assert!(elapsed >= Duration::from_millis(45), "Invalid HMAC must still incur the configured ServerHello delay to prevent timing side-channels"); + assert!( + elapsed >= Duration::from_millis(45), + "Invalid HMAC must still incur the configured ServerHello delay to prevent timing side-channels" + ); } #[tokio::test] @@ -421,7 +454,10 @@ async fn server_hello_delay_inversion_resilience() { let elapsed = start.elapsed(); assert!(matches!(res, HandshakeResult::Success(_))); - assert!(elapsed >= Duration::from_millis(90), "Delay logic must gracefully handle min > max inversions via max.max(min)"); + assert!( + elapsed >= Duration::from_millis(90), + "Delay logic must gracefully handle min > max inversions via max.max(min)" + ); } #[tokio::test] @@ -436,10 +472,16 @@ async fn mixed_valid_and_invalid_user_secrets_configuration() { for i in 0..9 { let bad_secret = if i % 2 == 0 { "badhex!" } else { "1122" }; - config.access.users.insert(format!("bad_user_{}", i), bad_secret.to_string()); + config + .access + .users + .insert(format!("bad_user_{}", i), bad_secret.to_string()); } let valid_secret_hex = "99999999999999999999999999999999"; - config.access.users.insert("good_user".to_string(), valid_secret_hex.to_string()); + config + .access + .users + .insert("good_user".to_string(), valid_secret_hex.to_string()); config.general.modes.secure = true; config.general.modes.classic = true; config.general.modes.tls = true; @@ -463,7 +505,10 @@ async fn mixed_valid_and_invalid_user_secrets_configuration() { ) .await; - assert!(matches!(res, HandshakeResult::Success(_)), "Proxy must gracefully skip invalid secrets and authenticate the valid one"); + assert!( + matches!(res, HandshakeResult::Success(_)), + "Proxy must gracefully skip invalid secrets and authenticate the valid one" + ); } #[tokio::test] @@ -494,7 +539,10 @@ async fn tls_emulation_fallback_when_cache_missing() { ) .await; - assert!(matches!(res, HandshakeResult::Success(_)), "TLS emulation must gracefully fall back to standard ServerHello if cache is missing"); + assert!( + matches!(res, HandshakeResult::Success(_)), + "TLS emulation must gracefully fall back to standard ServerHello if cache is missing" + ); } #[tokio::test] @@ -524,7 +572,10 @@ async fn classic_mode_over_tls_transport_protocol_confusion() { ) .await; - assert!(matches!(res, HandshakeResult::Success(_)), "Intermediate tag over TLS must succeed if classic mode is enabled, locking in cross-transport behavior"); + assert!( + matches!(res, HandshakeResult::Success(_)), + "Intermediate tag over TLS must succeed if classic mode is enabled, locking in cross-transport behavior" + ); } #[test] @@ -543,9 +594,15 @@ fn generate_tg_nonce_never_emits_reserved_bytes() { false, ); - assert!(!RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]), "Nonce must never start with reserved bytes"); + assert!( + !RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]), + "Nonce must never start with reserved bytes" + ); let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]]; - assert!(!RESERVED_NONCE_BEGINNINGS.contains(&first_four), "Nonce must never match reserved 4-byte beginnings"); + assert!( + !RESERVED_NONCE_BEGINNINGS.contains(&first_four), + "Nonce must never match reserved 4-byte beginnings" + ); } } @@ -568,11 +625,18 @@ async fn dashmap_concurrent_saturation_stress() { } for task in tasks { - task.await.expect("Task panicked during concurrent DashMap stress"); + task.await + .expect("Task panicked during concurrent DashMap stress"); } - assert!(auth_probe_is_throttled_for_testing(ip_a), "IP A must be throttled after concurrent stress"); - assert!(auth_probe_is_throttled_for_testing(ip_b), "IP B must be throttled after concurrent stress"); + assert!( + auth_probe_is_throttled_for_testing(ip_a), + "IP A must be throttled after concurrent stress" + ); + assert!( + auth_probe_is_throttled_for_testing(ip_b), + "IP B must be throttled after concurrent stress" + ); } #[test] @@ -586,7 +650,12 @@ fn prototag_invalid_bytes_fail_closed() { ]; for tag in invalid_tags { - assert_eq!(ProtoTag::from_bytes(tag), None, "Invalid ProtoTag bytes {:?} must fail closed", tag); + assert_eq!( + ProtoTag::from_bytes(tag), + None, + "Invalid ProtoTag bytes {:?} must fail closed", + tag + ); } } @@ -603,7 +672,10 @@ fn auth_probe_eviction_hash_collision_stress() { auth_probe_record_failure_with_state(state, ip, now); } - assert!(state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, "Eviction logic must successfully bound the map size under heavy insertion stress"); + assert!( + state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, + "Eviction logic must successfully bound the map size under heavy insertion stress" + ); } #[test] diff --git a/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs b/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs index 6c48cc1..77cea19 100644 --- a/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs +++ b/src/proxy/tests/handshake_auth_probe_eviction_bias_security_tests.rs @@ -88,6 +88,9 @@ fn light_fuzz_offset_always_stays_inside_state_len() { let now = base + Duration::from_nanos(seed & 0x0fff); let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit); - assert!(start < state_len, "scan offset must stay inside state length"); + assert!( + start < state_len, + "scan offset must stay inside state length" + ); } -} \ No newline at end of file +} diff --git a/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs index ece6ff5..c91a215 100644 --- a/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs +++ b/src/proxy/tests/handshake_auth_probe_scan_budget_security_tests.rs @@ -96,4 +96,4 @@ fn light_fuzz_scan_offset_budget_never_exceeds_effective_window() { "scan offset must stay inside state length" ); } -} \ No newline at end of file +} diff --git a/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs index 260a1b9..bf97990 100644 --- a/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs +++ b/src/proxy/tests/handshake_auth_probe_scan_offset_stress_tests.rs @@ -113,4 +113,4 @@ fn light_fuzz_scan_offset_stays_within_window_for_randomized_inputs() { "scan offset must always remain inside state length" ); } -} \ No newline at end of file +} diff --git a/src/proxy/tests/handshake_more_clever_tests.rs b/src/proxy/tests/handshake_more_clever_tests.rs index 77df442..9782469 100644 --- a/src/proxy/tests/handshake_more_clever_tests.rs +++ b/src/proxy/tests/handshake_more_clever_tests.rs @@ -1,8 +1,8 @@ use super::*; -use crate::crypto::{sha256, sha256_hmac, AesCtr}; +use crate::crypto::{AesCtr, sha256, sha256_hmac}; use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES}; -use rand::{Rng, SeedableRng}; use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; use std::collections::HashSet; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::sync::Arc; @@ -223,7 +223,10 @@ fn auth_probe_backoff_extreme_fail_streak_clamps_safely() { assert_eq!(updated.fail_streak, u32::MAX); let expected_blocked_until = now + Duration::from_millis(AUTH_PROBE_BACKOFF_MAX_MS); - assert_eq!(updated.blocked_until, expected_blocked_until, "Extreme fail streak must clamp cleanly to AUTH_PROBE_BACKOFF_MAX_MS"); + assert_eq!( + updated.blocked_until, expected_blocked_until, + "Extreme fail streak must clamp cleanly to AUTH_PROBE_BACKOFF_MAX_MS" + ); } #[test] @@ -250,12 +253,19 @@ fn generate_tg_nonce_cryptographic_uniqueness_and_entropy() { total_set_bits += byte.count_ones() as usize; } - assert!(nonces.insert(nonce), "generate_tg_nonce emitted a duplicate nonce! RNG is stuck."); + assert!( + nonces.insert(nonce), + "generate_tg_nonce emitted a duplicate nonce! RNG is stuck." + ); } let total_bits = iterations * HANDSHAKE_LEN * 8; let ratio = (total_set_bits as f64) / (total_bits as f64); - assert!(ratio > 0.48 && ratio < 0.52, "Nonce entropy is degraded. Set bit ratio: {}", ratio); + assert!( + ratio > 0.48 && ratio < 0.52, + "Nonce entropy is degraded. Set bit ratio: {}", + ratio + ); } #[tokio::test] @@ -267,10 +277,19 @@ async fn mtproto_multi_user_decryption_isolation() { config.general.modes.secure = true; config.access.ignore_time_skew = true; - config.access.users.insert("user_a".to_string(), "11111111111111111111111111111111".to_string()); - config.access.users.insert("user_b".to_string(), "22222222222222222222222222222222".to_string()); + config.access.users.insert( + "user_a".to_string(), + "11111111111111111111111111111111".to_string(), + ); + config.access.users.insert( + "user_b".to_string(), + "22222222222222222222222222222222".to_string(), + ); let good_secret_hex = "33333333333333333333333333333333"; - config.access.users.insert("user_c".to_string(), good_secret_hex.to_string()); + config + .access + .users + .insert("user_c".to_string(), good_secret_hex.to_string()); let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); let peer: SocketAddr = "192.0.2.104:12345".parse().unwrap(); @@ -291,9 +310,14 @@ async fn mtproto_multi_user_decryption_isolation() { match res { HandshakeResult::Success((_, _, success)) => { - assert_eq!(success.user, "user_c", "Decryption attempts on previous users must not corrupt the handshake buffer for the valid user"); + assert_eq!( + success.user, "user_c", + "Decryption attempts on previous users must not corrupt the handshake buffer for the valid user" + ); } - _ => panic!("Multi-user MTProto handshake failed. Decryption buffer might be mutating in place."), + _ => panic!( + "Multi-user MTProto handshake failed. Decryption buffer might be mutating in place." + ), } } @@ -325,7 +349,9 @@ async fn invalid_secret_warning_lock_contention_and_bound() { } let warned = INVALID_SECRET_WARNED.get().unwrap(); - let guard = warned.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let guard = warned + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()); assert_eq!( guard.len(), @@ -342,7 +368,11 @@ async fn mtproto_strict_concurrent_replay_race_condition() { let secret_hex = "4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A"; let config = Arc::new(test_config_with_secret_hex(secret_hex)); let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60))); - let valid_handshake = Arc::new(make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1)); + let valid_handshake = Arc::new(make_valid_mtproto_handshake( + secret_hex, + ProtoTag::Secure, + 1, + )); let tasks = 100; let barrier = Arc::new(Barrier::new(tasks)); @@ -355,7 +385,10 @@ async fn mtproto_strict_concurrent_replay_race_condition() { let hs = valid_handshake.clone(); handles.push(tokio::spawn(async move { - let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 250) as u8)), 10000 + i as u16); + let peer = SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 250) as u8)), + 10000 + i as u16, + ); b.wait().await; handle_mtproto_handshake( &hs, @@ -382,8 +415,15 @@ async fn mtproto_strict_concurrent_replay_race_condition() { } } - assert_eq!(successes, 1, "Replay cache race condition allowed multiple identical MTProto handshakes to succeed"); - assert_eq!(failures, tasks - 1, "Replay cache failed to forcefully reject concurrent duplicates"); + assert_eq!( + successes, 1, + "Replay cache race condition allowed multiple identical MTProto handshakes to succeed" + ); + assert_eq!( + failures, + tasks - 1, + "Replay cache failed to forcefully reject concurrent duplicates" + ); } #[tokio::test] @@ -398,7 +438,8 @@ async fn tls_alpn_zero_length_protocol_handled_safely() { let rng = SecureRandom::new(); let peer: SocketAddr = "192.0.2.107:12345".parse().unwrap(); - let handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b""]); + let handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b""]); let res = handle_tls_handshake( &handshake, @@ -412,7 +453,10 @@ async fn tls_alpn_zero_length_protocol_handled_safely() { ) .await; - assert!(matches!(res, HandshakeResult::BadClient { .. }), "0-length ALPN must be safely rejected without panicking"); + assert!( + matches!(res, HandshakeResult::BadClient { .. }), + "0-length ALPN must be safely rejected without panicking" + ); } #[tokio::test] @@ -427,7 +471,8 @@ async fn tls_sni_massive_hostname_does_not_panic() { let peer: SocketAddr = "192.0.2.108:12345".parse().unwrap(); let massive_hostname = String::from_utf8(vec![b'a'; 65000]).unwrap(); - let handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, &massive_hostname, &[]); + let handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, &massive_hostname, &[]); let res = handle_tls_handshake( &handshake, @@ -441,7 +486,13 @@ async fn tls_sni_massive_hostname_does_not_panic() { ) .await; - assert!(matches!(res, HandshakeResult::Success(_) | HandshakeResult::BadClient { .. }), "Massive SNI hostname must be processed or ignored without stack overflow or panic"); + assert!( + matches!( + res, + HandshakeResult::Success(_) | HandshakeResult::BadClient { .. } + ), + "Massive SNI hostname must be processed or ignored without stack overflow or panic" + ); } #[tokio::test] @@ -455,7 +506,8 @@ async fn tls_progressive_truncation_fuzzing_no_panics() { let rng = SecureRandom::new(); let peer: SocketAddr = "192.0.2.109:12345".parse().unwrap(); - let valid_handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b"h2"]); + let valid_handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b"h2"]); let full_len = valid_handshake.len(); // Truncated corpus only: full_len is a valid baseline and should not be @@ -473,7 +525,11 @@ async fn tls_progressive_truncation_fuzzing_no_panics() { None, ) .await; - assert!(matches!(res, HandshakeResult::BadClient { .. }), "Truncated TLS handshake at len {} must fail safely without panicking", i); + assert!( + matches!(res, HandshakeResult::BadClient { .. }), + "Truncated TLS handshake at len {} must fail safely without panicking", + i + ); } } @@ -504,7 +560,10 @@ async fn mtproto_pure_entropy_fuzzing_no_panics() { ) .await; - assert!(matches!(res, HandshakeResult::BadClient { .. }), "Pure entropy MTProto payload must fail closed and never panic"); + assert!( + matches!(res, HandshakeResult::BadClient { .. }), + "Pure entropy MTProto payload must fail closed and never panic" + ); } } @@ -517,10 +576,16 @@ fn decode_user_secret_odd_length_hex_rejection() { let mut config = ProxyConfig::default(); config.access.users.clear(); - config.access.users.insert("odd_user".to_string(), "1234567890123456789012345678901".to_string()); + config.access.users.insert( + "odd_user".to_string(), + "1234567890123456789012345678901".to_string(), + ); let decoded = decode_user_secrets(&config, None); - assert!(decoded.is_empty(), "Odd-length hex string must be gracefully rejected by hex::decode without unwrapping"); + assert!( + decoded.is_empty(), + "Odd-length hex string must be gracefully rejected by hex::decode without unwrapping" + ); } #[test] @@ -552,7 +617,10 @@ fn saturation_grace_pre_existing_high_fail_streak_immediate_throttle() { } let is_throttled = auth_probe_should_apply_preauth_throttle(peer_ip, now); - assert!(is_throttled, "A peer with a pre-existing high fail streak must be immediately throttled when saturation begins, receiving no unearned grace period"); + assert!( + is_throttled, + "A peer with a pre-existing high fail streak must be immediately throttled when saturation begins, receiving no unearned grace period" + ); } #[test] @@ -586,7 +654,11 @@ fn mtproto_classic_tags_rejected_when_only_secure_mode_enabled() { config.general.modes.tls = false; assert!(!mode_enabled_for_proto(&config, ProtoTag::Abridged, false)); - assert!(!mode_enabled_for_proto(&config, ProtoTag::Intermediate, false)); + assert!(!mode_enabled_for_proto( + &config, + ProtoTag::Intermediate, + false + )); } #[test] diff --git a/src/proxy/tests/handshake_real_bug_stress_tests.rs b/src/proxy/tests/handshake_real_bug_stress_tests.rs index d7234ff..1e27ed5 100644 --- a/src/proxy/tests/handshake_real_bug_stress_tests.rs +++ b/src/proxy/tests/handshake_real_bug_stress_tests.rs @@ -1,5 +1,5 @@ use super::*; -use crate::crypto::{sha256, sha256_hmac, AesCtr, SecureRandom}; +use crate::crypto::{AesCtr, SecureRandom, sha256, sha256_hmac}; use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION}; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::sync::Arc; @@ -80,8 +80,7 @@ fn make_valid_tls_client_hello_with_alpn( digest[28 + i] ^= ts[i]; } - record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] - .copy_from_slice(&digest); + record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest); record } @@ -331,7 +330,11 @@ async fn saturation_grace_exhaustion_under_concurrency_keeps_peer_throttled() { let final_state = state.get(&peer_ip).expect("state must exist"); assert!( - final_state.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS + final_state.fail_streak + >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS ); - assert!(auth_probe_should_apply_preauth_throttle(peer_ip, Instant::now())); + assert!(auth_probe_should_apply_preauth_throttle( + peer_ip, + Instant::now() + )); } diff --git a/src/proxy/tests/handshake_security_tests.rs b/src/proxy/tests/handshake_security_tests.rs index d06f63e..6796c5c 100644 --- a/src/proxy/tests/handshake_security_tests.rs +++ b/src/proxy/tests/handshake_security_tests.rs @@ -956,6 +956,89 @@ async fn stress_tls_sni_preferred_user_hint_scales_to_large_user_set() { } } +#[tokio::test] +async fn tls_unknown_sni_drop_policy_returns_hard_error() { + let secret = [0x48u8; 16]; + let mut config = test_config_with_secret_hex("48484848484848484848484848484848"); + config.censorship.unknown_sni_action = UnknownSniAction::Drop; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.190:44326".parse().unwrap(); + let handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "unknown.example", &[b"h2"]); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!( + result, + HandshakeResult::Error(ProxyError::UnknownTlsSni) + )); +} + +#[tokio::test] +async fn tls_unknown_sni_mask_policy_falls_back_to_bad_client() { + let secret = [0x49u8; 16]; + let mut config = test_config_with_secret_hex("49494949494949494949494949494949"); + config.censorship.unknown_sni_action = UnknownSniAction::Mask; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.191:44326".parse().unwrap(); + let handshake = + make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "unknown.example", &[b"h2"]); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::BadClient { .. })); +} + +#[tokio::test] +async fn tls_missing_sni_keeps_legacy_auth_path() { + let secret = [0x4Au8; 16]; + let mut config = test_config_with_secret_hex("4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a"); + config.censorship.unknown_sni_action = UnknownSniAction::Drop; + + let replay_checker = ReplayChecker::new(128, Duration::from_secs(60)); + let rng = SecureRandom::new(); + let peer: SocketAddr = "198.51.100.192:44326".parse().unwrap(); + let handshake = make_valid_tls_handshake(&secret, 0); + + let result = handle_tls_handshake( + &handshake, + tokio::io::empty(), + tokio::io::sink(), + peer, + &config, + &replay_checker, + &rng, + None, + ) + .await; + + assert!(matches!(result, HandshakeResult::Success(_))); +} + #[tokio::test] async fn alpn_enforce_rejects_unsupported_client_alpn() { let secret = [0x33u8; 16]; diff --git a/src/proxy/tests/handshake_timing_manual_bench_tests.rs b/src/proxy/tests/handshake_timing_manual_bench_tests.rs index 95e9f49..13d112c 100644 --- a/src/proxy/tests/handshake_timing_manual_bench_tests.rs +++ b/src/proxy/tests/handshake_timing_manual_bench_tests.rs @@ -1,5 +1,5 @@ use super::*; -use crate::crypto::{sha256, sha256_hmac, AesCtr, SecureRandom}; +use crate::crypto::{AesCtr, SecureRandom, sha256, sha256_hmac}; use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION}; use std::net::SocketAddr; use std::time::{Duration, Instant}; @@ -169,10 +169,10 @@ async fn mtproto_user_scan_timing_manual_benchmark() { ); } - config.access.users.insert( - preferred_user.to_string(), - target_secret_hex.to_string(), - ); + config + .access + .users + .insert(preferred_user.to_string(), target_secret_hex.to_string()); let replay_checker_preferred = ReplayChecker::new(65_536, Duration::from_secs(60)); let replay_checker_full_scan = ReplayChecker::new(65_536, Duration::from_secs(60)); diff --git a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs index 84c904f..a977409 100644 --- a/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs +++ b/src/proxy/tests/masking_ab_envelope_blur_integration_security_tests.rs @@ -544,7 +544,6 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u if hardened_acc + 0.05 <= baseline_acc { meaningful_improvement_seen = true; } - } assert!( diff --git a/src/proxy/tests/masking_additional_hardening_security_tests.rs b/src/proxy/tests/masking_additional_hardening_security_tests.rs index 29170c1..a6f6386 100644 --- a/src/proxy/tests/masking_additional_hardening_security_tests.rs +++ b/src/proxy/tests/masking_additional_hardening_security_tests.rs @@ -78,7 +78,11 @@ fn timing_normalization_zero_floor_safety_net_defaults_to_mask_timeout() { config.censorship.mask_timing_normalization_ceiling_ms = 0; let budget = mask_outcome_target_budget(&config); - assert_eq!(budget, MASK_TIMEOUT); + assert_eq!( + budget, + Duration::from_millis(0), + "zero floor/ceiling must produce zero extra normalization budget" + ); } #[tokio::test] diff --git a/src/proxy/tests/masking_aggressive_mode_security_tests.rs b/src/proxy/tests/masking_aggressive_mode_security_tests.rs index a77fc14..7356dc0 100644 --- a/src/proxy/tests/masking_aggressive_mode_security_tests.rs +++ b/src/proxy/tests/masking_aggressive_mode_security_tests.rs @@ -85,7 +85,10 @@ async fn aggressive_mode_shapes_backend_silent_non_eof_path() { let legacy = capture_forwarded_len_with_mode(body_sent, false, false, false, 0).await; let aggressive = capture_forwarded_len_with_mode(body_sent, false, true, false, 0).await; - assert!(legacy < floor, "legacy mode should keep timeout path unshaped"); + assert!( + legacy < floor, + "legacy mode should keep timeout path unshaped" + ); assert!( aggressive >= floor, "aggressive mode must shape backend-silent non-EOF paths (aggressive={aggressive}, floor={floor})" diff --git a/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs b/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs index 614af9b..718189c 100644 --- a/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs +++ b/src/proxy/tests/masking_connect_failure_close_matrix_security_tests.rs @@ -52,7 +52,10 @@ async fn run_connect_failure_case( .await .unwrap() .unwrap(); - assert_eq!(n, 0, "connect-failure path must close client-visible writer"); + assert_eq!( + n, 0, + "connect-failure path must close client-visible writer" + ); started.elapsed() } @@ -67,13 +70,9 @@ async fn connect_failure_refusal_close_behavior_matrix() { let peer: SocketAddr = format!("203.0.113.210:{}", 54100 + idx as u16) .parse() .unwrap(); - let elapsed = run_connect_failure_case( - "127.0.0.1", - unused_port, - timing_normalization_enabled, - peer, - ) - .await; + let elapsed = + run_connect_failure_case("127.0.0.1", unused_port, timing_normalization_enabled, peer) + .await; if timing_normalization_enabled { assert!( diff --git a/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs b/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs index b52af35..f2c39a2 100644 --- a/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs +++ b/src/proxy/tests/masking_consume_idle_timeout_security_tests.rs @@ -79,7 +79,10 @@ async fn io_error_terminates_cleanly() { } } - tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(ErrReader, usize::MAX)) - .await - .expect("consume_client_data did not return on I/O error"); + tokio::time::timeout( + MASK_RELAY_TIMEOUT, + consume_client_data(ErrReader, usize::MAX), + ) + .await + .expect("consume_client_data did not return on I/O error"); } diff --git a/src/proxy/tests/masking_extended_attack_surface_security_tests.rs b/src/proxy/tests/masking_extended_attack_surface_security_tests.rs index 040f567..650731c 100644 --- a/src/proxy/tests/masking_extended_attack_surface_security_tests.rs +++ b/src/proxy/tests/masking_extended_attack_surface_security_tests.rs @@ -32,8 +32,16 @@ async fn run_self_target_refusal( let (mut client, server) = duplex(1024); let started = Instant::now(); let task = tokio::spawn(async move { - handle_bad_client(server, tokio::io::sink(), initial, peer, local_addr, &config, &beobachten) - .await; + handle_bad_client( + server, + tokio::io::sink(), + initial, + peer, + local_addr, + &config, + &beobachten, + ) + .await; }); client @@ -214,4 +222,4 @@ async fn stress_high_fanout_self_target_refusal_no_deadlock_or_timeout() { }) .await .expect("high-fanout refusal workload must complete without deadlock"); -} \ No newline at end of file +} diff --git a/src/proxy/tests/masking_http_probe_boundary_security_tests.rs b/src/proxy/tests/masking_http_probe_boundary_security_tests.rs index 47b6dc6..c8f3ec0 100644 --- a/src/proxy/tests/masking_http_probe_boundary_security_tests.rs +++ b/src/proxy/tests/masking_http_probe_boundary_security_tests.rs @@ -2,7 +2,13 @@ use super::*; #[test] fn exact_four_byte_http_tokens_are_classified() { - for token in [b"GET ".as_ref(), b"POST".as_ref(), b"HEAD".as_ref(), b"PUT ".as_ref(), b"PRI ".as_ref()] { + for token in [ + b"GET ".as_ref(), + b"POST".as_ref(), + b"HEAD".as_ref(), + b"PUT ".as_ref(), + b"PRI ".as_ref(), + ] { assert!( is_http_probe(token), "exact 4-byte token must be classified as HTTP probe: {:?}", @@ -76,4 +82,4 @@ fn light_fuzz_four_byte_ascii_noise_not_misclassified() { token ); } -} \ No newline at end of file +} diff --git a/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs b/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs index 8d99b8f..ed6d1ab 100644 --- a/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs +++ b/src/proxy/tests/masking_interface_cache_concurrency_security_tests.rs @@ -38,4 +38,4 @@ async fn adversarial_parallel_cold_miss_performs_single_interface_refresh() { 1, "parallel cold misses must coalesce into a single interface enumeration" ); -} \ No newline at end of file +} diff --git a/src/proxy/tests/masking_interface_cache_security_tests.rs b/src/proxy/tests/masking_interface_cache_security_tests.rs index 6be99d0..17debb0 100644 --- a/src/proxy/tests/masking_interface_cache_security_tests.rs +++ b/src/proxy/tests/masking_interface_cache_security_tests.rs @@ -37,7 +37,10 @@ async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() { let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr"); let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await; - assert!(!is_local, "different port must not be treated as local listener"); + assert!( + !is_local, + "different port must not be treated as local listener" + ); assert_eq!( local_interface_enumerations_for_tests(), 0, diff --git a/src/proxy/tests/masking_production_cap_regression_security_tests.rs b/src/proxy/tests/masking_production_cap_regression_security_tests.rs index f2368a1..9ff51ba 100644 --- a/src/proxy/tests/masking_production_cap_regression_security_tests.rs +++ b/src/proxy/tests/masking_production_cap_regression_security_tests.rs @@ -63,17 +63,11 @@ impl AsyncWrite for CountingWriter { Poll::Ready(Ok(buf.len())) } - fn poll_flush( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } - fn poll_shutdown( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } } diff --git a/src/proxy/tests/masking_self_target_loop_security_tests.rs b/src/proxy/tests/masking_self_target_loop_security_tests.rs index 18cb0d7..7f6cb29 100644 --- a/src/proxy/tests/masking_self_target_loop_security_tests.rs +++ b/src/proxy/tests/masking_self_target_loop_security_tests.rs @@ -1,6 +1,6 @@ use super::*; -use std::net::TcpListener as StdTcpListener; use std::net::SocketAddr; +use std::net::TcpListener as StdTcpListener; use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::net::TcpListener; use tokio::time::{Duration, Instant, timeout}; @@ -15,74 +15,38 @@ fn closed_local_port() -> u16 { #[tokio::test] async fn self_target_detection_matches_literal_ipv4_listener() { let local: SocketAddr = "198.51.100.40:443".parse().unwrap(); - assert!(is_mask_target_local_listener_async( - "198.51.100.40", - 443, - local, - None, - ) - .await); + assert!(is_mask_target_local_listener_async("198.51.100.40", 443, local, None,).await); } #[tokio::test] async fn self_target_detection_matches_bracketed_ipv6_listener() { let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap(); - assert!(is_mask_target_local_listener_async( - "[2001:db8::44]", - 8443, - local, - None, - ) - .await); + assert!(is_mask_target_local_listener_async("[2001:db8::44]", 8443, local, None,).await); } #[tokio::test] async fn self_target_detection_keeps_same_ip_different_port_forwardable() { let local: SocketAddr = "203.0.113.44:443".parse().unwrap(); - assert!(!is_mask_target_local_listener_async( - "203.0.113.44", - 8443, - local, - None, - ) - .await); + assert!(!is_mask_target_local_listener_async("203.0.113.44", 8443, local, None,).await); } #[tokio::test] async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() { let local: SocketAddr = "127.0.0.1:443".parse().unwrap(); - assert!(is_mask_target_local_listener_async( - "::ffff:127.0.0.1", - 443, - local, - None, - ) - .await); + assert!(is_mask_target_local_listener_async("::ffff:127.0.0.1", 443, local, None,).await); } #[tokio::test] async fn self_target_detection_unspecified_bind_blocks_loopback_target() { let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); - assert!(is_mask_target_local_listener_async( - "127.0.0.1", - 443, - local, - None, - ) - .await); + assert!(is_mask_target_local_listener_async("127.0.0.1", 443, local, None,).await); } #[tokio::test] async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() { let local: SocketAddr = "0.0.0.0:443".parse().unwrap(); let remote: SocketAddr = "198.51.100.44:443".parse().unwrap(); - assert!(!is_mask_target_local_listener_async( - "mask.example", - 443, - local, - Some(remote), - ) - .await); + assert!(!is_mask_target_local_listener_async("mask.example", 443, local, Some(remote),).await); } #[tokio::test] @@ -306,7 +270,10 @@ async fn offline_mask_target_refusal_respects_timing_normalization_budget() { }); client.shutdown().await.unwrap(); - timeout(Duration::from_secs(2), task).await.unwrap().unwrap(); + timeout(Duration::from_secs(2), task) + .await + .unwrap() + .unwrap(); let elapsed = started.elapsed(); assert!( @@ -350,7 +317,10 @@ async fn offline_mask_target_refusal_with_idle_client_is_bounded_by_consume_time .await .expect("connection should still be open before consume timeout expires"); - timeout(Duration::from_secs(2), task).await.unwrap().unwrap(); + timeout(Duration::from_secs(2), task) + .await + .unwrap() + .unwrap(); let elapsed = started.elapsed(); assert!( diff --git a/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs b/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs index 1c342ea..fda6de7 100644 --- a/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs +++ b/src/proxy/tests/masking_timing_budget_coupling_security_tests.rs @@ -40,7 +40,10 @@ async fn adversarial_delayed_interface_lookup_does_not_consume_outcome_floor_bud tokio::time::sleep(Duration::from_millis(80)).await; drop(held_refresh_guard); - client.shutdown().await.expect("client shutdown must succeed"); + client + .shutdown() + .await + .expect("client shutdown must succeed"); timeout(Duration::from_secs(2), task) .await @@ -52,4 +55,4 @@ async fn adversarial_delayed_interface_lookup_does_not_consume_outcome_floor_bud elapsed >= Duration::from_millis(180) && elapsed < Duration::from_millis(350), "timing normalization floor must start after pre-outcome self-target checks" ); -} \ No newline at end of file +} diff --git a/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs b/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs new file mode 100644 index 0000000..7c176bc --- /dev/null +++ b/src/proxy/tests/middle_relay_atomic_quota_invariant_tests.rs @@ -0,0 +1,189 @@ +use super::*; +use crate::crypto::AesCtr; +use bytes::Bytes; +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::task::{Context, Poll}; +use tokio::io::AsyncWrite; + +struct CountedWriter { + write_calls: Arc, + fail_writes: bool, +} + +impl CountedWriter { + fn new(write_calls: Arc, fail_writes: bool) -> Self { + Self { + write_calls, + fail_writes, + } + } +} + +impl AsyncWrite for CountedWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + this.write_calls.fetch_add(1, Ordering::Relaxed); + if this.fail_writes { + Poll::Ready(Err(io::Error::new( + io::ErrorKind::BrokenPipe, + "forced write failure", + ))) + } else { + Poll::Ready(Ok(buf.len())) + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +fn make_crypto_writer(inner: CountedWriter) -> CryptoWriter { + let key = [0u8; 32]; + let iv = 0u128; + CryptoWriter::new(inner, AesCtr::new(&key, iv), 8 * 1024) +} + +#[tokio::test] +async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() { + let stats = Stats::new(); + let user = "middle-me-writer-no-rollback-user"; + let user_stats = stats.get_or_create_user_stats_handle(user); + let write_calls = Arc::new(AtomicUsize::new(0)); + let mut writer = make_crypto_writer(CountedWriter::new(write_calls.clone(), true)); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + let payload = Bytes::from_static(&[0x11, 0x22, 0x33, 0x44, 0x55]); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: payload.clone(), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(user_stats.as_ref()), + Some(64), + 0, + &bytes_me2c, + 11, + true, + false, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::Io(_))), + "write failure must propagate as I/O error" + ); + assert!( + write_calls.load(Ordering::Relaxed) > 0, + "writer must be attempted after successful quota reservation" + ); + assert_eq!( + stats.get_user_quota_used(user), + payload.len() as u64, + "reserved quota must not roll back on write failure" + ); + assert_eq!( + stats.get_quota_write_fail_bytes_total(), + payload.len() as u64, + "write-fail byte metric must include failed payload size" + ); + assert_eq!( + stats.get_quota_write_fail_events_total(), + 1, + "write-fail events metric must increment once" + ); + assert_eq!( + stats.get_user_total_octets(user), + 0, + "telemetry octets_to should not advance when write fails" + ); + assert_eq!( + bytes_me2c.load(Ordering::Relaxed), + 0, + "ME->C committed byte counter must not advance on write failure" + ); +} + +#[tokio::test] +async fn me_writer_pre_write_quota_reject_happens_before_writer_poll() { + let stats = Stats::new(); + let user = "middle-me-writer-precheck-user"; + let limit = 8u64; + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), limit); + + let write_calls = Arc::new(AtomicUsize::new(0)); + let mut writer = make_crypto_writer(CountedWriter::new(write_calls.clone(), false)); + let mut frame_buf = Vec::new(); + let bytes_me2c = AtomicU64::new(0); + + let result = process_me_writer_response( + MeResponse::Data { + flags: 0, + data: Bytes::from_static(&[0xAA, 0xBB, 0xCC]), + }, + &mut writer, + ProtoTag::Intermediate, + &SecureRandom::new(), + &mut frame_buf, + &stats, + user, + Some(user_stats.as_ref()), + Some(limit), + 0, + &bytes_me2c, + 12, + true, + false, + ) + .await; + + assert!( + matches!(result, Err(ProxyError::DataQuotaExceeded { .. })), + "pre-write quota rejection must return typed quota error" + ); + assert_eq!( + write_calls.load(Ordering::Relaxed), + 0, + "writer must not be polled when pre-write quota reservation fails" + ); + assert_eq!( + stats.get_me_d2c_quota_reject_pre_write_total(), + 1, + "pre-write quota reject metric must increment" + ); + assert_eq!( + stats.get_user_quota_used(user), + limit, + "failed pre-write reservation must keep previous quota usage unchanged" + ); + assert_eq!( + stats.get_quota_write_fail_bytes_total(), + 0, + "write-fail bytes metric must stay unchanged on pre-write reject" + ); + assert_eq!( + stats.get_quota_write_fail_events_total(), + 0, + "write-fail events metric must stay unchanged on pre-write reject" + ); + assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); +} diff --git a/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs b/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs deleted file mode 100644 index 6f0e91a..0000000 --- a/src/proxy/tests/middle_relay_blackhat_campaign_integration_tests.rs +++ /dev/null @@ -1,113 +0,0 @@ -use super::*; -use crate::stats::Stats; -use dashmap::DashMap; -use std::sync::Arc; -use std::sync::atomic::{AtomicUsize, Ordering}; -use tokio::sync::Barrier; -use tokio::time::{Duration, timeout}; - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn blackhat_campaign_saturation_quota_race_with_queue_pressure_stays_fail_closed() { - let _guard = super::quota_user_lock_test_scope(); - let _pressure_guard = super::relay_idle_pressure_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "middle-blackhat-held-{}-{idx}", - std::process::id() - ))); - } - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "precondition: bounded lock cache must be saturated" - ); - - let (tx, _rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Close) - .await - .expect("queue prefill should succeed"); - - let pressure_seq_before = relay_pressure_event_seq(); - let pressure_errors = Arc::new(AtomicUsize::new(0)); - let mut pressure_workers = Vec::new(); - for _ in 0..16 { - let tx = tx.clone(); - let pressure_errors = Arc::clone(&pressure_errors); - pressure_workers.push(tokio::spawn(async move { - if enqueue_c2me_command(&tx, C2MeCommand::Close).await.is_err() { - pressure_errors.fetch_add(1, Ordering::Relaxed); - } - })); - } - - let stats = Arc::new(Stats::new()); - let user = format!("middle-blackhat-quota-race-{}", std::process::id()); - let gate = Arc::new(Barrier::new(16)); - - let mut quota_workers = Vec::new(); - for _ in 0..16u8 { - let stats = Arc::clone(&stats); - let user = user.clone(); - let gate = Arc::clone(&gate); - quota_workers.push(tokio::spawn(async move { - gate.wait().await; - let user_lock = quota_user_lock(&user); - let _quota_guard = user_lock.lock().await; - - if quota_would_be_exceeded_for_user(&stats, &user, Some(1), 1) { - return false; - } - stats.add_user_octets_to(&user, 1); - true - })); - } - - let mut ok_count = 0usize; - let mut denied_count = 0usize; - for worker in quota_workers { - let result = timeout(Duration::from_secs(2), worker) - .await - .expect("quota worker must finish") - .expect("quota worker must not panic"); - if result { - ok_count += 1; - } else { - denied_count += 1; - } - } - - for worker in pressure_workers { - timeout(Duration::from_secs(2), worker) - .await - .expect("pressure worker must finish") - .expect("pressure worker must not panic"); - } - - assert_eq!( - stats.get_user_total_octets(&user), - 1, - "black-hat campaign must not overshoot same-user quota under saturation" - ); - assert!(ok_count <= 1, "at most one quota contender may succeed"); - assert!( - denied_count >= 15, - "all remaining contenders must be quota-denied" - ); - - let pressure_seq_after = relay_pressure_event_seq(); - assert!( - pressure_seq_after > pressure_seq_before, - "queue pressure leg must trigger pressure accounting" - ); - assert!( - pressure_errors.load(Ordering::Relaxed) >= 1, - "at least one pressure worker should fail from persistent backpressure" - ); - - drop(retained); -} diff --git a/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs b/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs deleted file mode 100644 index 44c201f..0000000 --- a/src/proxy/tests/middle_relay_coverage_high_risk_security_tests.rs +++ /dev/null @@ -1,777 +0,0 @@ -use super::*; -use crate::crypto::AesCtr; -use crate::crypto::SecureRandom; -use crate::stats::Stats; -use crate::stream::{BufferPool, PooledBuffer}; -use std::sync::Arc; -use tokio::io::AsyncReadExt; -use tokio::io::duplex; -use tokio::sync::mpsc; -use tokio::time::{Duration as TokioDuration, timeout}; - -fn make_pooled_payload(data: &[u8]) -> PooledBuffer { - let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); - let mut payload = pool.get(); - payload.resize(data.len(), 0); - payload[..data.len()].copy_from_slice(data); - payload -} - -#[tokio::test] -async fn write_client_payload_abridged_short_quickack_sets_flag_and_preserves_payload() { - let (mut read_side, write_side) = duplex(4096); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = vec![0xA1, 0xB2, 0xC3, 0xD4, 0x10, 0x20, 0x30, 0x40]; - - write_client_payload( - &mut writer, - ProtoTag::Abridged, - RPC_FLAG_QUICKACK, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("abridged quickack payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 1 + payload.len()]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read serialized abridged frame"); - let plaintext = decryptor.decrypt(&encrypted); - - assert_eq!(plaintext[0], 0x80 | ((payload.len() / 4) as u8)); - assert_eq!(&plaintext[1..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_abridged_extended_header_is_encoded_correctly() { - let (mut read_side, write_side) = duplex(16 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - // Boundary where abridged switches to extended length encoding. - let payload = vec![0x5Au8; 0x7f * 4]; - - write_client_payload( - &mut writer, - ProtoTag::Abridged, - RPC_FLAG_QUICKACK, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("extended abridged payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 4 + payload.len()]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read serialized extended abridged frame"); - let plaintext = decryptor.decrypt(&encrypted); - - assert_eq!(plaintext[0], 0xff, "0x7f with quickack bit must be set"); - assert_eq!(&plaintext[1..4], &[0x7f, 0x00, 0x00]); - assert_eq!(&plaintext[4..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_abridged_misaligned_is_rejected_fail_closed() { - let (_read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - let err = write_client_payload( - &mut writer, - ProtoTag::Abridged, - 0, - &[1, 2, 3], - &rng, - &mut frame_buf, - ) - .await - .expect_err("misaligned abridged payload must be rejected"); - - let msg = format!("{err}"); - assert!( - msg.contains("4-byte aligned"), - "error should explain alignment contract, got: {msg}" - ); -} - -#[tokio::test] -async fn write_client_payload_secure_misaligned_is_rejected_fail_closed() { - let (_read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - let err = write_client_payload( - &mut writer, - ProtoTag::Secure, - 0, - &[9, 8, 7, 6, 5], - &rng, - &mut frame_buf, - ) - .await - .expect_err("misaligned secure payload must be rejected"); - - let msg = format!("{err}"); - assert!( - msg.contains("Secure payload must be 4-byte aligned"), - "error should be explicit for fail-closed triage, got: {msg}" - ); -} - -#[tokio::test] -async fn write_client_payload_intermediate_quickack_sets_length_msb() { - let (mut read_side, write_side) = duplex(4096); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = b"hello-middle-relay"; - - write_client_payload( - &mut writer, - ProtoTag::Intermediate, - RPC_FLAG_QUICKACK, - payload, - &rng, - &mut frame_buf, - ) - .await - .expect("intermediate quickack payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 4 + payload.len()]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read intermediate frame"); - let plaintext = decryptor.decrypt(&encrypted); - - let mut len_bytes = [0u8; 4]; - len_bytes.copy_from_slice(&plaintext[..4]); - let len_with_flags = u32::from_le_bytes(len_bytes); - assert_ne!(len_with_flags & 0x8000_0000, 0, "quickack bit must be set"); - assert_eq!((len_with_flags & 0x7fff_ffff) as usize, payload.len()); - assert_eq!(&plaintext[4..], payload); -} - -#[tokio::test] -async fn write_client_payload_secure_quickack_prefix_and_padding_bounds_hold() { - let (mut read_side, write_side) = duplex(4096); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = vec![0x33u8; 100]; // 4-byte aligned as required by secure mode. - - write_client_payload( - &mut writer, - ProtoTag::Secure, - RPC_FLAG_QUICKACK, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("secure quickack payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - // Secure mode adds 1..=3 bytes of randomized tail padding. - let mut encrypted_header = [0u8; 4]; - read_side - .read_exact(&mut encrypted_header) - .await - .expect("must read secure header"); - let decrypted_header = decryptor.decrypt(&encrypted_header); - let header: [u8; 4] = decrypted_header - .try_into() - .expect("decrypted secure header must be 4 bytes"); - let wire_len_raw = u32::from_le_bytes(header); - - assert_ne!( - wire_len_raw & 0x8000_0000, - 0, - "secure quickack bit must be set" - ); - - let wire_len = (wire_len_raw & 0x7fff_ffff) as usize; - assert!(wire_len >= payload.len()); - let padding_len = wire_len - payload.len(); - assert!( - (1..=3).contains(&padding_len), - "secure writer must add bounded random tail padding, got {padding_len}" - ); - - let mut encrypted_body = vec![0u8; wire_len]; - read_side - .read_exact(&mut encrypted_body) - .await - .expect("must read secure body"); - let decrypted_body = decryptor.decrypt(&encrypted_body); - assert_eq!(&decrypted_body[..payload.len()], payload.as_slice()); -} - -#[tokio::test] -#[ignore = "heavy: allocates >64MiB to validate abridged too-large fail-closed branch"] -async fn write_client_payload_abridged_too_large_is_rejected_fail_closed() { - let (_read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - // Exactly one 4-byte word above the encodable 24-bit abridged length range. - let payload = vec![0x00u8; (1 << 24) * 4]; - let err = write_client_payload( - &mut writer, - ProtoTag::Abridged, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect_err("oversized abridged payload must be rejected"); - - let msg = format!("{err}"); - assert!( - msg.contains("Abridged frame too large"), - "error must clearly indicate oversize fail-close path, got: {msg}" - ); -} - -#[tokio::test] -async fn write_client_ack_intermediate_is_little_endian() { - let (mut read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - - write_client_ack(&mut writer, ProtoTag::Intermediate, 0x11_22_33_44) - .await - .expect("ack serialization should succeed"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = [0u8; 4]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read ack bytes"); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain.as_slice(), &0x11_22_33_44u32.to_le_bytes()); -} - -#[tokio::test] -async fn write_client_ack_abridged_is_big_endian() { - let (mut read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - - write_client_ack(&mut writer, ProtoTag::Abridged, 0xDE_AD_BE_EF) - .await - .expect("ack serialization should succeed"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = [0u8; 4]; - read_side - .read_exact(&mut encrypted) - .await - .expect("must read ack bytes"); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain.as_slice(), &0xDE_AD_BE_EFu32.to_be_bytes()); -} - -#[tokio::test] -async fn write_client_payload_abridged_short_boundary_0x7e_is_single_byte_header() { - let (mut read_side, write_side) = duplex(1024 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = vec![0xABu8; 0x7e * 4]; - - write_client_payload( - &mut writer, - ProtoTag::Abridged, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("boundary payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 1 + payload.len()]; - read_side.read_exact(&mut encrypted).await.unwrap(); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain[0], 0x7e); - assert_eq!(&plain[1..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_abridged_extended_without_quickack_has_clean_prefix() { - let (mut read_side, write_side) = duplex(16 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = vec![0x42u8; 0x80 * 4]; - - write_client_payload( - &mut writer, - ProtoTag::Abridged, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("extended payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = vec![0u8; 4 + payload.len()]; - read_side.read_exact(&mut encrypted).await.unwrap(); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain[0], 0x7f); - assert_eq!(&plain[1..4], &[0x80, 0x00, 0x00]); - assert_eq!(&plain[4..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_intermediate_zero_length_emits_header_only() { - let (mut read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - write_client_payload( - &mut writer, - ProtoTag::Intermediate, - 0, - &[], - &rng, - &mut frame_buf, - ) - .await - .expect("zero-length intermediate payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = [0u8; 4]; - read_side.read_exact(&mut encrypted).await.unwrap(); - let plain = decryptor.decrypt(&encrypted); - assert_eq!(plain.as_slice(), &[0, 0, 0, 0]); -} - -#[tokio::test] -async fn write_client_payload_intermediate_ignores_unrelated_flags() { - let (mut read_side, write_side) = duplex(1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = [7u8; 12]; - - write_client_payload( - &mut writer, - ProtoTag::Intermediate, - 0x4000_0000, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted = [0u8; 16]; - read_side.read_exact(&mut encrypted).await.unwrap(); - let plain = decryptor.decrypt(&encrypted); - let len = u32::from_le_bytes(plain[0..4].try_into().unwrap()); - assert_eq!(len, payload.len() as u32, "only quickack bit may affect header"); - assert_eq!(&plain[4..], payload.as_slice()); -} - -#[tokio::test] -async fn write_client_payload_secure_without_quickack_keeps_msb_clear() { - let (mut read_side, write_side) = duplex(4096); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = [0x1Du8; 64]; - - write_client_payload( - &mut writer, - ProtoTag::Secure, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted_header = [0u8; 4]; - read_side.read_exact(&mut encrypted_header).await.unwrap(); - let plain_header = decryptor.decrypt(&encrypted_header); - let h: [u8; 4] = plain_header.as_slice().try_into().unwrap(); - let wire_len_raw = u32::from_le_bytes(h); - assert_eq!(wire_len_raw & 0x8000_0000, 0, "quickack bit must stay clear"); -} - -#[tokio::test] -async fn secure_padding_light_fuzz_distribution_has_multiple_outcomes() { - let (mut read_side, write_side) = duplex(256 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let payload = [0x55u8; 100]; - let mut seen = [false; 4]; - - for _ in 0..96 { - write_client_payload( - &mut writer, - ProtoTag::Secure, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("secure payload should serialize"); - writer.flush().await.expect("flush must succeed"); - - let mut encrypted_header = [0u8; 4]; - read_side.read_exact(&mut encrypted_header).await.unwrap(); - let plain_header = decryptor.decrypt(&encrypted_header); - let h: [u8; 4] = plain_header.as_slice().try_into().unwrap(); - let wire_len = (u32::from_le_bytes(h) & 0x7fff_ffff) as usize; - let padding_len = wire_len - payload.len(); - assert!((1..=3).contains(&padding_len)); - seen[padding_len] = true; - - let mut encrypted_body = vec![0u8; wire_len]; - read_side.read_exact(&mut encrypted_body).await.unwrap(); - let _ = decryptor.decrypt(&encrypted_body); - } - - let distinct = (1..=3).filter(|idx| seen[*idx]).count(); - assert!( - distinct >= 2, - "padding generator should not collapse to a single outcome under campaign" - ); -} - -#[tokio::test] -async fn write_client_payload_mixed_proto_sequence_preserves_stream_sync() { - let (mut read_side, write_side) = duplex(128 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024); - let mut decryptor = AesCtr::new(&key, iv); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - let p1 = vec![1u8; 8]; - let p2 = vec![2u8; 16]; - let p3 = vec![3u8; 20]; - - write_client_payload(&mut writer, ProtoTag::Abridged, 0, &p1, &rng, &mut frame_buf) - .await - .unwrap(); - write_client_payload( - &mut writer, - ProtoTag::Intermediate, - RPC_FLAG_QUICKACK, - &p2, - &rng, - &mut frame_buf, - ) - .await - .unwrap(); - write_client_payload(&mut writer, ProtoTag::Secure, 0, &p3, &rng, &mut frame_buf) - .await - .unwrap(); - writer.flush().await.unwrap(); - - // Frame 1: abridged short. - let mut e1 = vec![0u8; 1 + p1.len()]; - read_side.read_exact(&mut e1).await.unwrap(); - let d1 = decryptor.decrypt(&e1); - assert_eq!(d1[0], (p1.len() / 4) as u8); - assert_eq!(&d1[1..], p1.as_slice()); - - // Frame 2: intermediate with quickack. - let mut e2 = vec![0u8; 4 + p2.len()]; - read_side.read_exact(&mut e2).await.unwrap(); - let d2 = decryptor.decrypt(&e2); - let l2 = u32::from_le_bytes(d2[0..4].try_into().unwrap()); - assert_ne!(l2 & 0x8000_0000, 0); - assert_eq!((l2 & 0x7fff_ffff) as usize, p2.len()); - assert_eq!(&d2[4..], p2.as_slice()); - - // Frame 3: secure with bounded tail. - let mut e3h = [0u8; 4]; - read_side.read_exact(&mut e3h).await.unwrap(); - let d3h = decryptor.decrypt(&e3h); - let l3 = (u32::from_le_bytes(d3h.as_slice().try_into().unwrap()) & 0x7fff_ffff) as usize; - assert!(l3 >= p3.len()); - assert!((1..=3).contains(&(l3 - p3.len()))); - let mut e3b = vec![0u8; l3]; - read_side.read_exact(&mut e3b).await.unwrap(); - let d3b = decryptor.decrypt(&e3b); - assert_eq!(&d3b[..p3.len()], p3.as_slice()); -} - -#[test] -fn should_yield_sender_boundary_matrix_blackhat() { - assert!(!should_yield_c2me_sender(0, false)); - assert!(!should_yield_c2me_sender(0, true)); - assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true)); - assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false)); - assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); - assert!(should_yield_c2me_sender( - C2ME_SENDER_FAIRNESS_BUDGET.saturating_add(1024), - true - )); -} - -#[test] -fn should_yield_sender_light_fuzz_matches_oracle() { - let mut s: u64 = 0xD00D_BAAD_F00D_CAFE; - for _ in 0..5000 { - s ^= s << 7; - s ^= s >> 9; - s ^= s << 8; - let sent = (s as usize) & 0x1fff; - let backlog = (s & 1) != 0; - - let expected = backlog && sent >= C2ME_SENDER_FAIRNESS_BUDGET; - assert_eq!(should_yield_c2me_sender(sent, backlog), expected); - } -} - -#[test] -fn quota_would_be_exceeded_exact_remaining_one_byte() { - let stats = Stats::new(); - let user = "quota-edge"; - let quota = 100u64; - stats.add_user_octets_to(user, 99); - - assert!( - !quota_would_be_exceeded_for_user(&stats, user, Some(quota), 1), - "exactly remaining budget should be allowed" - ); - assert!( - quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2), - "one byte beyond remaining budget must be rejected" - ); -} - -#[test] -fn quota_would_be_exceeded_saturating_edge_remains_fail_closed() { - let stats = Stats::new(); - let user = "quota-saturating-edge"; - let quota = u64::MAX - 3; - stats.add_user_octets_to(user, u64::MAX - 4); - - assert!( - quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2), - "saturating arithmetic edge must stay fail-closed" - ); -} - -#[test] -fn quota_exceeded_boundary_is_inclusive() { - let stats = Stats::new(); - let user = "quota-inclusive-boundary"; - stats.add_user_octets_to(user, 50); - - assert!(quota_exceeded_for_user(&stats, user, Some(50))); - assert!(!quota_exceeded_for_user(&stats, user, Some(51))); -} - -#[test] -fn quota_soft_helper_matches_capped_generic_helper_matrix() { - let stats = Stats::new(); - let user = "quota-soft-parity"; - - for used in [0u64, 1, 7, 63, 127, 255] { - stats.sub_user_octets_to(user, stats.get_user_total_octets(user)); - stats.add_user_octets_to(user, used); - - for quota in [8u64, 64, 128, 256] { - for overshoot in [0u64, 1, 5, 32] { - for bytes in [0u64, 1, 2, 7, 31, 64] { - let soft = quota_would_be_exceeded_for_user_soft( - &stats, - user, - Some(quota), - bytes, - overshoot, - ); - let capped = quota_would_be_exceeded_for_user( - &stats, - user, - Some(quota_soft_cap(quota, overshoot)), - bytes, - ); - assert_eq!( - soft, capped, - "soft helper parity mismatch: used={used} quota={quota} overshoot={overshoot} bytes={bytes}" - ); - } - } - } - } -} - -#[test] -fn quota_soft_helper_none_limit_never_rejects() { - let stats = Stats::new(); - let user = "quota-soft-none"; - stats.add_user_octets_to(user, u64::MAX); - - assert!(!quota_would_be_exceeded_for_user_soft( - &stats, - user, - None, - u64::MAX, - u64::MAX, - )); -} - -#[test] -fn quota_soft_cap_saturates_and_stays_fail_closed() { - let stats = Stats::new(); - let user = "quota-soft-saturating"; - let quota = u64::MAX - 2; - let overshoot = 100; - - assert_eq!(quota_soft_cap(quota, overshoot), u64::MAX); - - stats.add_user_octets_to(user, u64::MAX - 1); - assert!(quota_would_be_exceeded_for_user_soft( - &stats, - user, - Some(quota), - 2, - overshoot, - )); -} - -#[tokio::test] -async fn enqueue_c2me_close_fast_path_succeeds_without_backpressure() { - let (tx, mut rx) = mpsc::channel::(4); - enqueue_c2me_command(&tx, C2MeCommand::Close) - .await - .expect("close should enqueue on fast path"); - - let recv = timeout(TokioDuration::from_millis(50), rx.recv()) - .await - .expect("must receive close command") - .expect("close command should be present"); - assert!(matches!(recv, C2MeCommand::Close)); -} - -#[tokio::test] -async fn enqueue_c2me_data_full_then_drain_preserves_order() { - let (tx, mut rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[1]), - flags: 10, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let producer = tokio::spawn(async move { - enqueue_c2me_command( - &tx2, - C2MeCommand::Data { - payload: make_pooled_payload(&[2, 2]), - flags: 20, - }, - ) - .await - }); - - tokio::time::sleep(TokioDuration::from_millis(10)).await; - - let first = rx.recv().await.expect("first item should exist"); - match first { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[1]); - assert_eq!(flags, 10); - } - C2MeCommand::Close => panic!("unexpected close as first item"), - } - - producer.await.unwrap().expect("producer should complete"); - - let second = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .expect("second item should exist"); - match second { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[2, 2]); - assert_eq!(flags, 20); - } - C2MeCommand::Close => panic!("unexpected close as second item"), - } -} diff --git a/src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs deleted file mode 100644 index a787aa6..0000000 --- a/src/proxy/tests/middle_relay_cross_mode_lock_release_regression_security_tests.rs +++ /dev/null @@ -1,295 +0,0 @@ -use super::*; -use crate::crypto::{AesCtr, SecureRandom}; -use crate::stats::Stats; -use crate::stream::CryptoWriter; -use bytes::Bytes; -use std::pin::Pin; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::sync::{Arc, Mutex}; -use std::task::{Context, Poll, Waker}; -use tokio::io::AsyncWrite; -use tokio::sync::Notify; -use tokio::task::JoinSet; -use tokio::time::{Duration, timeout}; - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -#[derive(Default)] -struct BlockingWriteState { - write_entered: AtomicBool, - released: AtomicBool, - write_waker: Mutex>, - write_entered_notify: Notify, -} - -struct BlockingWrite { - state: Arc, -} - -impl BlockingWrite { - fn new(state: Arc) -> Self { - Self { state } - } -} - -impl AsyncWrite for BlockingWrite { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.state.write_entered.store(true, Ordering::Release); - self.state.write_entered_notify.notify_waiters(); - - if self.state.released.load(Ordering::Acquire) { - return Poll::Ready(Ok(buf.len())); - } - - if let Ok(mut slot) = self.state.write_waker.lock() { - *slot = Some(cx.waker().clone()); - } - - Poll::Pending - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} - -async fn wait_until_blocking_write_entered(state: &Arc) { - for _ in 0..8 { - if state.write_entered.load(Ordering::Acquire) { - return; - } - let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await; - } - - panic!("blocking writer did not enter poll_write in bounded time"); -} - -fn release_blocking_write(state: &Arc) { - state.released.store(true, Ordering::Release); - if let Ok(mut slot) = state.write_waker.lock() - && let Some(waker) = slot.take() - { - waker.wake(); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn adversarial_blocked_write_releases_cross_mode_lock_and_preserves_fail_closed_quota() { - let stats = Arc::new(Stats::new()); - let user = format!("middle-cross-release-regression-{}", std::process::id()); - let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user)); - let bytes_me2c = Arc::new(AtomicU64::new(0)); - let writer_state = Arc::new(BlockingWriteState::default()); - - let first = { - let stats = Arc::clone(&stats); - let user = user.clone(); - let cross_mode_lock = Arc::clone(&cross_mode_lock); - let bytes_me2c = Arc::clone(&bytes_me2c); - let writer_state = Arc::clone(&writer_state); - tokio::spawn(async move { - let mut writer = make_crypto_writer(BlockingWrite::new(writer_state)); - let mut frame_buf = Vec::new(); - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xAA, 0xBB, 0xCC, 0xDD]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats.as_ref(), - &user, - Some(4), - 0, - Some(&cross_mode_lock), - bytes_me2c.as_ref(), - 41_000, - false, - false, - ) - .await - }) - }; - - wait_until_blocking_write_entered(&writer_state).await; - - let guard = timeout(Duration::from_millis(40), cross_mode_lock.lock()) - .await - .expect("cross-mode lock must be released while first write is pending"); - drop(guard); - - let second = { - let stats = Arc::clone(&stats); - let user = user.clone(); - let cross_mode_lock = Arc::clone(&cross_mode_lock); - let bytes_me2c = Arc::clone(&bytes_me2c); - tokio::spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - timeout( - Duration::from_millis(150), - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xEE]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats.as_ref(), - &user, - Some(4), - 0, - Some(&cross_mode_lock), - bytes_me2c.as_ref(), - 41_001, - false, - false, - ), - ) - .await - }) - }; - - let second_result = second - .await - .expect("second task must not panic") - .expect("second write must not block on cross-mode lock"); - assert!( - matches!(second_result, Err(ProxyError::DataQuotaExceeded { .. })), - "second write must fail closed due to first write reservation" - ); - - release_blocking_write(&writer_state); - - let first_result = timeout(Duration::from_millis(300), first) - .await - .expect("first task timed out") - .expect("first task must not panic"); - assert!(first_result.is_ok()); - - assert_eq!(stats.get_user_total_octets(&user), 4); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_pending_write_does_not_starve_same_user_waiters_after_quota_boundary() { - let stats = Arc::new(Stats::new()); - let user = format!("middle-cross-release-stress-{}", std::process::id()); - let cross_mode_lock = Arc::new(cross_mode_quota_user_lock_for_tests(&user)); - let bytes_me2c = Arc::new(AtomicU64::new(0)); - let writer_state = Arc::new(BlockingWriteState::default()); - - let first = { - let stats = Arc::clone(&stats); - let user = user.clone(); - let cross_mode_lock = Arc::clone(&cross_mode_lock); - let bytes_me2c = Arc::clone(&bytes_me2c); - let writer_state = Arc::clone(&writer_state); - tokio::spawn(async move { - let mut writer = make_crypto_writer(BlockingWrite::new(writer_state)); - let mut frame_buf = Vec::new(); - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x01, 0x02]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats.as_ref(), - &user, - Some(3), - 0, - Some(&cross_mode_lock), - bytes_me2c.as_ref(), - 41_100, - false, - false, - ) - .await - }) - }; - - wait_until_blocking_write_entered(&writer_state).await; - - let mut set = JoinSet::new(); - for idx in 0..48u64 { - let stats = Arc::clone(&stats); - let user = user.clone(); - let cross_mode_lock = Arc::clone(&cross_mode_lock); - let bytes_me2c = Arc::clone(&bytes_me2c); - set.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - timeout( - Duration::from_millis(200), - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x10]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats.as_ref(), - &user, - Some(3), - 0, - Some(&cross_mode_lock), - bytes_me2c.as_ref(), - 41_200 + idx, - false, - false, - ), - ) - .await - }); - } - - let mut ok = 0usize; - let mut quota_exceeded = 0usize; - while let Some(done) = set.join_next().await { - let timed = done.expect("waiter task must not panic"); - let result = timed.expect("waiter must not block behind pending first write"); - match result { - Ok(_) => ok += 1, - Err(ProxyError::DataQuotaExceeded { .. }) => quota_exceeded += 1, - Err(other) => panic!("unexpected error in waiter: {other:?}"), - } - } - - assert_eq!(ok, 1, "exactly one waiter should consume remaining one-byte quota"); - assert_eq!(quota_exceeded, 47); - - release_blocking_write(&writer_state); - - let first_result = timeout(Duration::from_millis(300), first) - .await - .expect("first task timed out") - .expect("first task must not panic"); - assert!(first_result.is_ok()); - - assert_eq!(stats.get_user_total_octets(&user), 3); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3); -} diff --git a/src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs deleted file mode 100644 index 37e1b87..0000000 --- a/src/proxy/tests/middle_relay_cross_mode_lookup_efficiency_security_tests.rs +++ /dev/null @@ -1,116 +0,0 @@ -use super::*; -use crate::crypto::{AesCtr, SecureRandom}; -use crate::stats::Stats; -use crate::stream::CryptoWriter; -use bytes::Bytes; -use std::sync::Arc; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Mutex, OnceLock}; - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -fn lookup_counter_test_lock() -> &'static Mutex<()> { - static LOCK: OnceLock> = OnceLock::new(); - LOCK.get_or_init(|| Mutex::new(())) -} - -#[tokio::test] -async fn tdd_prefetched_cross_mode_lock_avoids_per_frame_registry_lookup_in_me_to_client_writer() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Stats::new(); - let user = format!("middle-cross-mode-lookup-{}", std::process::id()); - let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - - crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - for idx in 0..8u64 { - let outcome = process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xAB]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1024), - 0, - Some(&cross_mode_lock), - &bytes_me2c, - 20_000 + idx, - false, - false, - ) - .await; - - assert!(outcome.is_ok()); - } - - assert_eq!( - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), - 0, - "prefetched lock path must not re-query lock registry per frame" - ); - assert_eq!(stats.get_user_total_octets(&user), 8); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 8); -} - -#[tokio::test] -async fn control_without_prefetched_lock_still_uses_registry_lookup_path() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Stats::new(); - let user = format!("middle-cross-mode-lookup-control-{}", std::process::id()); - - crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let outcome = process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xCD]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1024), - 0, - None, - &bytes_me2c, - 20_100, - false, - false, - ) - .await; - - assert!(outcome.is_ok()); - assert_eq!( - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), - 1, - "fallback path without prefetched lock should perform a registry lookup" - ); -} diff --git a/src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs deleted file mode 100644 index bc7c857..0000000 --- a/src/proxy/tests/middle_relay_cross_mode_quota_lock_matrix_security_tests.rs +++ /dev/null @@ -1,376 +0,0 @@ -use super::*; -use crate::crypto::{AesCtr, SecureRandom}; -use crate::stats::Stats; -use crate::stream::CryptoWriter; -use bytes::Bytes; -use std::sync::Arc; -use std::sync::atomic::{AtomicU64, Ordering}; -use tokio::time::{Duration, timeout}; - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -#[tokio::test] -async fn positive_quota_limited_me_to_client_write_updates_counters_exactly_once() { - let stats = Stats::new(); - let user = format!("middle-cross-matrix-positive-{}", std::process::id()); - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3, 4]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(128), - 0, - &bytes_me2c, - 10_001, - false, - false, - ) - .await; - - assert!(result.is_ok()); - assert_eq!(stats.get_user_total_octets(&user), 4); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4); -} - -#[tokio::test] -async fn negative_held_cross_mode_lock_blocks_quota_limited_me_to_client_path() { - let stats = Stats::new(); - let user = format!("middle-cross-matrix-negative-{}", std::process::id()); - let held = cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold lock before ME->C call"); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let blocked = timeout( - Duration::from_millis(25), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x41]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(256), - 0, - &bytes_me2c, - 10_002, - false, - false, - ), - ) - .await; - - assert!(blocked.is_err()); - drop(held_guard); -} - -#[tokio::test] -async fn edge_quota_none_bypasses_cross_mode_lock_guard_in_me_to_client_path() { - let stats = Stats::new(); - let user = format!("middle-cross-matrix-edge-none-{}", std::process::id()); - let held = cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold lock while quota is disabled"); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let outcome = timeout( - Duration::from_millis(80), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x11, 0x22]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - None, - 0, - &bytes_me2c, - 10_003, - false, - false, - ), - ) - .await - .expect("quota-none path must not wait on cross-mode lock"); - - assert!(outcome.is_ok()); - drop(held_guard); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_same_user_parallel_quota_limited_writes_stay_hard_capped() { - let stats = Arc::new(Stats::new()); - let user = format!("middle-cross-matrix-adversarial-{}", std::process::id()); - let limit = 64u64; - let bytes_me2c = Arc::new(AtomicU64::new(0)); - let mut tasks = Vec::new(); - - for idx in 0..256u64 { - let stats = Arc::clone(&stats); - let bytes_me2c = Arc::clone(&bytes_me2c); - let user = user.clone(); - tasks.push(tokio::spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xEE]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats.as_ref(), - &user, - Some(limit), - 0, - bytes_me2c.as_ref(), - 11_000 + idx, - false, - false, - ) - .await - })); - } - - let mut ok = 0usize; - for task in tasks { - match task.await.expect("task must not panic") { - Ok(_) => ok += 1, - Err(ProxyError::DataQuotaExceeded { .. }) => {} - Err(other) => panic!("unexpected error in adversarial parallel case: {other:?}"), - } - } - - assert_eq!(ok, limit as usize); - assert_eq!(stats.get_user_total_octets(&user), limit); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), limit); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_shared_lock_blocks_direct_relay_and_middle_relay_for_same_user() { - let user = format!("middle-cross-matrix-integration-{}", std::process::id()); - let relay_lock = crate::proxy::relay::cross_mode_quota_user_lock_for_tests(&user); - let middle_lock = cross_mode_quota_user_lock_for_tests(&user); - assert!( - Arc::ptr_eq(&relay_lock, &middle_lock), - "relay and middle-relay must share the same cross-mode lock identity" - ); - - let held_guard = relay_lock - .try_lock() - .expect("test must hold shared cross-mode lock"); - - let stats = Stats::new(); - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let middle_blocked = timeout( - Duration::from_millis(25), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x92]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1024), - 0, - &bytes_me2c, - 12_001, - false, - false, - ), - ) - .await; - assert!(middle_blocked.is_err()); - - drop(held_guard); - - let middle_ready = timeout( - Duration::from_millis(250), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x94]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1024), - 0, - &bytes_me2c, - 12_002, - false, - false, - ), - ) - .await - .expect("middle path must complete after release"); - - assert!(middle_ready.is_ok()); -} - -#[tokio::test] -async fn light_fuzz_mixed_payload_sizes_with_periodic_lock_holds_keeps_accounting_consistent() { - let stats = Stats::new(); - let user = format!("middle-cross-matrix-fuzz-{}", std::process::id()); - let bytes_me2c = AtomicU64::new(0); - let mut seed = 0xC0DE_1234_55AA_9988u64; - - for case in 0..96u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let hold = (seed & 0x03) == 0; - let mut held_lock = None; - let maybe_guard = if hold { - held_lock = Some(cross_mode_quota_user_lock_for_tests(&user)); - Some( - held_lock - .as_ref() - .expect("held lock should be present") - .try_lock() - .expect("cross-mode lock should be acquirable in fuzz round"), - ) - } else { - None - }; - - let payload_len = ((seed >> 8) as usize % 8) + 1; - let payload = vec![(seed & 0xff) as u8; payload_len]; - let before = stats.get_user_total_octets(&user); - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - - let timed = timeout( - Duration::from_millis(20), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1024), - 0, - &bytes_me2c, - 13_000 + case as u64, - false, - false, - ), - ) - .await; - - if hold { - assert!(timed.is_err(), "held-lock fuzz round must block within timeout"); - assert_eq!(stats.get_user_total_octets(&user), before); - } else { - let done = timed.expect("unheld fuzz round must complete in time"); - assert!(done.is_ok()); - } - - drop(maybe_guard); - drop(held_lock); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), stats.get_user_total_octets(&user)); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_held_user_lock_does_not_block_other_users_me_to_client_writes() { - let held_user = format!("middle-cross-matrix-stress-held-{}", std::process::id()); - let free_user = format!("middle-cross-matrix-stress-free-{}", std::process::id()); - - let held = cross_mode_quota_user_lock_for_tests(&held_user); - let held_guard = held - .try_lock() - .expect("test must hold lock for blocked user"); - - let mut tasks = Vec::new(); - for idx in 0..64u64 { - let user = free_user.clone(); - tasks.push(tokio::spawn(async move { - let stats = Stats::new(); - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xA0]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1), - 0, - &bytes_me2c, - 14_000 + idx, - false, - false, - ) - .await - })); - } - - timeout(Duration::from_secs(2), async { - for task in tasks { - let done = task.await.expect("free-user task must not panic"); - assert!(done.is_ok()); - } - }) - .await - .expect("free-user tasks should complete without waiting for held user's lock"); - - drop(held_guard); -} diff --git a/src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs b/src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs deleted file mode 100644 index 51092bd..0000000 --- a/src/proxy/tests/middle_relay_cross_mode_quota_reservation_security_tests.rs +++ /dev/null @@ -1,254 +0,0 @@ -use super::*; -use crate::crypto::{AesCtr, SecureRandom}; -use crate::stats::Stats; -use crate::stream::CryptoWriter; -use bytes::Bytes; -use std::pin::Pin; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::sync::{Arc, Mutex}; -use std::task::{Context, Poll, Waker}; -use tokio::io::AsyncWrite; -use tokio::sync::Notify; -use tokio::time::{Duration, timeout}; - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -#[derive(Default)] -struct BlockingWriteState { - write_entered: AtomicBool, - released: AtomicBool, - write_waker: Mutex>, - write_entered_notify: Notify, -} - -struct BlockingWrite { - state: Arc, -} - -impl BlockingWrite { - fn new(state: Arc) -> Self { - Self { state } - } -} - -impl AsyncWrite for BlockingWrite { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - self.state.write_entered.store(true, Ordering::Release); - self.state.write_entered_notify.notify_waiters(); - - if self.state.released.load(Ordering::Acquire) { - return Poll::Ready(Ok(buf.len())); - } - - if let Ok(mut slot) = self.state.write_waker.lock() { - *slot = Some(cx.waker().clone()); - } - Poll::Pending - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} - -async fn wait_until_blocking_write_entered(state: &Arc) { - for _ in 0..8 { - if state.write_entered.load(Ordering::Acquire) { - return; - } - let _ = timeout(Duration::from_millis(25), state.write_entered_notify.notified()).await; - } - - panic!("blocking writer did not enter poll_write in bounded time"); -} - -fn release_blocking_write(state: &Arc) { - state.released.store(true, Ordering::Release); - if let Ok(mut slot) = state.write_waker.lock() - && let Some(waker) = slot.take() - { - waker.wake(); - } -} - -#[tokio::test] -async fn adversarial_held_cross_mode_lock_blocks_me_to_client_quota_reservation_path() { - let stats = Stats::new(); - let user = format!("middle-me2c-cross-mode-held-{}", std::process::id()); - let held = cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold shared cross-mode lock before ME->C write path"); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let blocked = timeout( - Duration::from_millis(25), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x41]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1024), - 0, - &bytes_me2c, - 9901, - false, - false, - ), - ) - .await; - - assert!( - blocked.is_err(), - "ME->C quota reservation path must be serialized by held shared cross-mode lock" - ); - - drop(held_guard); - - let released = timeout( - Duration::from_millis(250), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x42]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1024), - 0, - &bytes_me2c, - 9902, - false, - false, - ), - ) - .await - .expect("ME->C write must complete after cross-mode lock release"); - - assert!(released.is_ok()); -} - -#[tokio::test] -async fn business_uncontended_cross_mode_lock_allows_me_to_client_quota_reservation() { - let stats = Stats::new(); - let user = format!("middle-me2c-cross-mode-free-{}", std::process::id()); - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let outcome = timeout( - Duration::from_millis(250), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x55, 0x66]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1024), - 0, - &bytes_me2c, - 9903, - false, - false, - ), - ) - .await - .expect("uncontended ME->C path should not stall"); - - assert!(outcome.is_ok()); - assert_eq!(stats.get_user_total_octets(&user), 2); - assert_eq!(bytes_me2c.load(std::sync::atomic::Ordering::Relaxed), 2); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn adversarial_cross_mode_lock_is_released_before_me_to_client_write_await() { - let stats = Arc::new(Stats::new()); - let user = format!("middle-me2c-lock-drop-before-write-{}", std::process::id()); - let cross_mode_lock = cross_mode_quota_user_lock_for_tests(&user); - let bytes_me2c = Arc::new(AtomicU64::new(0)); - let writer_state = Arc::new(BlockingWriteState::default()); - - let worker = { - let stats = Arc::clone(&stats); - let user = user.clone(); - let cross_mode_lock = Arc::clone(&cross_mode_lock); - let bytes_me2c = Arc::clone(&bytes_me2c); - let writer_state = Arc::clone(&writer_state); - tokio::spawn(async move { - let mut writer = make_crypto_writer(BlockingWrite::new(writer_state)); - let mut frame_buf = Vec::new(); - let rng = SecureRandom::new(); - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xDE, 0xAD, 0xBE, 0xEF]), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - stats.as_ref(), - &user, - Some(1024), - 0, - Some(&cross_mode_lock), - bytes_me2c.as_ref(), - 9910, - false, - false, - ) - .await - }) - }; - - wait_until_blocking_write_entered(&writer_state).await; - - let acquired_guard = timeout(Duration::from_millis(40), cross_mode_lock.lock()) - .await - .expect("cross-mode lock must be free while ME->C write is pending"); - drop(acquired_guard); - - release_blocking_write(&writer_state); - - let result = timeout(Duration::from_millis(300), worker) - .await - .expect("ME->C worker timed out after releasing blocking writer") - .expect("ME->C worker must not panic"); - - assert!(result.is_ok()); - assert_eq!(stats.get_user_total_octets(&user), 4); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 4); -} diff --git a/src/proxy/tests/middle_relay_hol_quota_security_tests.rs b/src/proxy/tests/middle_relay_hol_quota_security_tests.rs deleted file mode 100644 index 3ce0235..0000000 --- a/src/proxy/tests/middle_relay_hol_quota_security_tests.rs +++ /dev/null @@ -1,232 +0,0 @@ -use super::*; -use crate::crypto::{AesCtr, SecureRandom}; -use crate::stats::Stats; -use crate::stream::CryptoWriter; -use bytes::Bytes; -use std::io; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; -use std::task::{Context, Poll, Waker}; -use tokio::io::AsyncWrite; -use tokio::time::{Duration, timeout}; - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -#[derive(Default)] -struct GateState { - open: AtomicBool, - parked_waker: std::sync::Mutex>, -} - -impl GateState { - fn open(&self) { - self.open.store(true, Ordering::Relaxed); - if let Ok(mut guard) = self.parked_waker.lock() - && let Some(w) = guard.take() - { - w.wake(); - } - } - - fn has_waiter(&self) -> bool { - self.parked_waker - .lock() - .map(|guard| guard.is_some()) - .unwrap_or(false) - } -} - -#[derive(Default)] -struct GateWriter { - gate: Arc, -} - -impl GateWriter { - fn new(gate: Arc) -> Self { - Self { gate } - } -} - -impl AsyncWrite for GateWriter { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if self.gate.open.load(Ordering::Relaxed) { - return Poll::Ready(Ok(buf.len())); - } - - if let Ok(mut guard) = self.gate.parked_waker.lock() { - *guard = Some(cx.waker().clone()); - } - Poll::Pending - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} - -struct FailingWriter; - -impl AsyncWrite for FailingWriter { - fn poll_write( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - _buf: &[u8], - ) -> Poll> { - Poll::Ready(Err(io::Error::new( - io::ErrorKind::BrokenPipe, - "injected writer failure", - ))) - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn adversarial_same_user_slow_writer_must_not_hol_block_peer_connection() { - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - let rng = SecureRandom::new(); - let quota_limit = Some(1024); - let user = "hol-quota-user"; - - let gate = Arc::new(GateState::default()); - - let mut blocked_writer = make_crypto_writer(GateWriter::new(Arc::clone(&gate))); - let slow_task = tokio::spawn(async move { - let mut frame_buf = Vec::new(); - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x10, 0x20, 0x30, 0x40]), - }, - &mut blocked_writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - user, - quota_limit, - 0, - &bytes_me2c, - 7001, - false, - false, - ) - .await - }); - - timeout(Duration::from_millis(100), async { - loop { - if gate.has_waiter() { - break; - } - tokio::task::yield_now().await; - } - }) - .await - .expect("first writer must reach backpressure and park"); - - let stats_fast = Stats::new(); - let bytes_fast = AtomicU64::new(0); - let rng_fast = SecureRandom::new(); - let mut fast_writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf_fast = Vec::new(); - - timeout( - Duration::from_millis(50), - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x41]), - }, - &mut fast_writer, - ProtoTag::Intermediate, - &rng_fast, - &mut frame_buf_fast, - &stats_fast, - user, - quota_limit, - 0, - &bytes_fast, - 7002, - false, - false, - ), - ) - .await - .expect("peer connection must not be blocked by same-user stalled write") - .expect("fast peer write must succeed"); - - gate.open(); - let slow_result = timeout(Duration::from_secs(1), slow_task) - .await - .expect("stalled task must complete once gate opens") - .expect("stalled task must not panic"); - assert!(slow_result.is_ok()); -} - -#[tokio::test] -async fn negative_write_failure_rolls_back_pre_accounted_quota_and_forensics_bytes() { - let stats = Stats::new(); - let user = "rollback-user"; - stats.add_user_octets_from(user, 7); - - let bytes_me2c = AtomicU64::new(0); - let rng = SecureRandom::new(); - let mut writer = make_crypto_writer(FailingWriter); - let mut frame_buf = Vec::new(); - - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3, 4]), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - user, - Some(64), - 0, - &bytes_me2c, - 7003, - false, - false, - ) - .await; - - assert!(matches!(result, Err(ProxyError::Io(_)))); - assert_eq!( - stats.get_user_total_octets(user), - 7, - "failed client write must not overcharge user quota accounting" - ); - assert_eq!( - bytes_me2c.load(Ordering::Relaxed), - 0, - "failed client write must not inflate ME->C forensic byte counter" - ); -} \ No newline at end of file diff --git a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs index 6ea182b..fd3243d 100644 --- a/src/proxy/tests/middle_relay_idle_policy_security_tests.rs +++ b/src/proxy/tests/middle_relay_idle_policy_security_tests.rs @@ -2,8 +2,8 @@ use super::*; use crate::crypto::AesCtr; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader}; -use std::sync::atomic::AtomicU64; use std::sync::Arc; +use std::sync::atomic::AtomicU64; use tokio::io::AsyncWriteExt; use tokio::io::duplex; use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout}; diff --git a/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs b/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs index 112d926..b43825c 100644 --- a/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs +++ b/src/proxy/tests/middle_relay_idle_registry_poison_security_tests.rs @@ -29,7 +29,10 @@ fn blackhat_registry_poison_recovers_with_fail_closed_reset_and_pressure_account let before = relay_pressure_event_seq(); note_relay_pressure_event(); let after = relay_pressure_event_seq(); - assert!(after > before, "pressure accounting must still advance after poison"); + assert!( + after > before, + "pressure accounting must still advance after poison" + ); clear_relay_idle_pressure_state_for_testing(); } diff --git a/src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs b/src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs deleted file mode 100644 index 29384e0..0000000 --- a/src/proxy/tests/middle_relay_quota_extended_attack_surface_security_tests.rs +++ /dev/null @@ -1,372 +0,0 @@ -use super::*; -use crate::crypto::{AesCtr, SecureRandom}; -use crate::error::ProxyError; -use crate::stats::Stats; -use crate::stream::CryptoWriter; -use bytes::Bytes; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, OnceLock, Mutex}; -use tokio::sync::Mutex as AsyncMutex; -use tokio::task::JoinSet; -use tokio::time::{Duration, timeout}; - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -fn lookup_test_lock() -> &'static Mutex<()> { - static LOCK: OnceLock> = OnceLock::new(); - LOCK.get_or_init(|| Mutex::new(())) -} - -#[tokio::test] -async fn positive_me2c_quota_counts_bytes_exactly_once() { - let _guard = lookup_test_lock().lock().unwrap(); - let stats = Stats::new(); - let user = format!("quota-middle-ext-positive-{}", std::process::id()); - let lock = Arc::new(AsyncMutex::new(())); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let result = process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3, 4, 5]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(64), - 0, - Some(&lock), - &bytes_me2c, - 70_001, - false, - false, - ) - .await; - - assert!(result.is_ok()); - assert_eq!(stats.get_user_total_octets(&user), 5); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 5); -} - -#[tokio::test] -async fn negative_held_crossmode_lock_blocks_me2c_write() { - let _guard = lookup_test_lock().lock().unwrap(); - let stats = Stats::new(); - let user = format!("quota-middle-ext-negative-{}", std::process::id()); - - let lock = Arc::new(AsyncMutex::new(())); - let _held = lock.try_lock().expect("lock must be held"); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let blocked = timeout( - Duration::from_millis(25), - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xFE]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(16), - 0, - Some(&lock), - &bytes_me2c, - 70_101, - false, - false, - ), - ) - .await; - - assert!(blocked.is_err()); - assert_eq!(stats.get_user_total_octets(&user), 0); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); -} - -#[tokio::test] -async fn edge_zero_quota_zero_payload_is_fail_closed() { - let _guard = lookup_test_lock().lock().unwrap(); - let stats = Stats::new(); - let user = format!("quota-middle-ext-edge-{}", std::process::id()); - - let lock = Arc::new(AsyncMutex::new(())); - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let result = process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::new(), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(0), - 0, - Some(&lock), - &bytes_me2c, - 70_201, - false, - false, - ) - .await; - - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - assert_eq!(stats.get_user_total_octets(&user), 0); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_parallel_me2c_race_falls_back_to_quota_error() { - let _guard = lookup_test_lock().lock().unwrap(); - let stats = Arc::new(Stats::new()); - let user = format!("quota-middle-ext-blackhat-{}", std::process::id()); - let quota = 64u64; - let lock = Arc::new(AsyncMutex::new(())); - let bytes_me2c = Arc::new(AtomicU64::new(0)); - - let mut set = JoinSet::new(); - for i in 0..256u64 { - let stats = Arc::clone(&stats); - let user = user.clone(); - let lock = Arc::clone(&lock); - let bytes_me2c = Arc::clone(&bytes_me2c); - - set.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let payload = vec![((i & 0xFF) as u8); (i % 4 + 1) as usize]; - - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats.as_ref(), - &user, - Some(quota), - 0, - Some(&lock), - bytes_me2c.as_ref(), - 70_301 + i, - false, - false, - ) - .await - }); - } - - let mut succeeded = 0usize; - while let Some(done) = set.join_next().await { - match done.expect("task must not panic") { - Ok(_) => succeeded += 1, - Err(ProxyError::DataQuotaExceeded { .. }) => {} - Err(other) => panic!("unexpected error {other:?}"), - } - } - - assert_eq!(stats.get_user_total_octets(&user), bytes_me2c.load(Ordering::Relaxed)); - assert!(stats.get_user_total_octets(&user) <= quota); - assert!(succeeded <= quota as usize); -} - -#[tokio::test] -async fn integration_shared_prefetched_lock_blocks_then_releases_writer() { - let stats = Stats::new(); - let user = format!("quota-middle-ext-integration-{}", std::process::id()); - let lock = Arc::new(AsyncMutex::new(())); - let held = lock - .try_lock() - .expect("integration test must hold prefetched lock first"); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let blocked = timeout( - Duration::from_millis(25), - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xA1]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(8), - 0, - Some(&lock), - &bytes_me2c, - 70_360, - false, - false, - ), - ) - .await; - assert!(blocked.is_err()); - - drop(held); - - let after_release = timeout( - Duration::from_millis(150), - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xA2]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(8), - 0, - Some(&lock), - &bytes_me2c, - 70_361, - false, - false, - ), - ) - .await - .expect("writer should progress once the shared lock is released"); - - assert!(after_release.is_ok()); -} - -#[tokio::test] -async fn light_fuzz_small_payloads_toggle_lock_state_stays_consistent() { - let _guard = lookup_test_lock().lock().unwrap(); - let stats = Stats::new(); - let user = format!("quota-middle-ext-fuzz-{}", std::process::id()); - let mut seed = 0xCAFE_BABE_1234u64; - let bytes_me2c = AtomicU64::new(0); - - for case in 0..48u32 { - seed ^= seed << 5; - seed ^= seed >> 12; - seed ^= seed << 13; - let hold = (seed & 0x1) == 0; - - let lock = Arc::new(AsyncMutex::new(())); - let maybe_guard = if hold { - Some(lock.try_lock().unwrap()) - } else { - None - }; - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - - let result = timeout( - Duration::from_millis(30), - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![(seed & 0xFF) as u8; ((seed as usize % 5) + 1)]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(128), - 0, - Some(&lock), - &bytes_me2c, - 70_401 + case as u64, - false, - false, - ), - ) - .await; - - if hold { - assert!(result.is_err()); - } else { - assert!(result.unwrap().is_ok()); - } - - drop(maybe_guard); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_parallel_free_users_during_held_user_lock_maintains_liveness() { - let _guard = lookup_test_lock().lock().unwrap(); - let held = Arc::new(AsyncMutex::new(())); - let _held_guard = held.try_lock().unwrap(); - - let mut set = JoinSet::new(); - for i in 0..48u64 { - set.spawn(async move { - let stats = Stats::new(); - let user = format!("quota-middle-ext-stress-free-{i}"); - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - let free_lock = Arc::new(AsyncMutex::new(())); - - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xEE]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(1), - 0, - Some(&free_lock), - &bytes_me2c, - 70_500 + i, - false, - false, - ) - .await - }); - } - - timeout(Duration::from_secs(2), async { - while let Some(task) = set.join_next().await { - task.unwrap().unwrap(); - } - }) - .await - .unwrap(); -} diff --git a/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs b/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs deleted file mode 100644 index d06e103..0000000 --- a/src/proxy/tests/middle_relay_quota_overflow_lock_security_tests.rs +++ /dev/null @@ -1,131 +0,0 @@ -use super::*; -use dashmap::DashMap; -use std::sync::Arc; - -#[test] -fn saturation_uses_stable_overflow_lock_without_cache_growth() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let prefix = format!("middle-quota-held-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX); - - let user = format!("middle-quota-overflow-{}", std::process::id()); - let first = quota_user_lock(&user); - let second = quota_user_lock(&user); - - assert!( - Arc::ptr_eq(&first, &second), - "overflow user must get deterministic same lock while cache is saturated" - ); - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "overflow path must not grow bounded lock map" - ); - assert!( - map.get(&user).is_none(), - "overflow user should stay outside bounded lock map under saturation" - ); - - drop(retained); -} - -#[test] -fn overflow_striping_keeps_different_users_distributed() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let prefix = format!("middle-quota-dist-held-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - let a = quota_user_lock("middle-overflow-user-a"); - let b = quota_user_lock("middle-overflow-user-b"); - let c = quota_user_lock("middle-overflow-user-c"); - - let distinct = [ - Arc::as_ptr(&a) as usize, - Arc::as_ptr(&b) as usize, - Arc::as_ptr(&c) as usize, - ] - .iter() - .copied() - .collect::>() - .len(); - - assert!( - distinct >= 2, - "striped overflow lock set should avoid collapsing all users to one lock" - ); - - drop(retained); -} - -#[test] -fn reclaim_path_caches_new_user_after_stale_entries_drop() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let prefix = format!("middle-quota-reclaim-held-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - drop(retained); - - let user = format!("middle-quota-reclaim-user-{}", std::process::id()); - let got = quota_user_lock(&user); - assert!(map.get(&user).is_some()); - assert!( - Arc::strong_count(&got) >= 2, - "after reclaim, lock should be held both by caller and map" - ); -} - -#[test] -fn overflow_path_same_user_is_stable_across_parallel_threads() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "middle-quota-thread-held-{}-{idx}", - std::process::id() - ))); - } - - let user = format!("middle-quota-overflow-thread-user-{}", std::process::id()); - let mut workers = Vec::new(); - for _ in 0..32 { - let user = user.clone(); - workers.push(std::thread::spawn(move || quota_user_lock(&user))); - } - - let first = workers - .remove(0) - .join() - .expect("thread must return lock handle"); - for worker in workers { - let got = worker.join().expect("thread must return lock handle"); - assert!( - Arc::ptr_eq(&first, &got), - "same overflow user should resolve to one striped lock even under contention" - ); - } - - drop(retained); -} diff --git a/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs b/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs deleted file mode 100644 index 963b3e0..0000000 --- a/src/proxy/tests/middle_relay_quota_reservation_adversarial_tests.rs +++ /dev/null @@ -1,1066 +0,0 @@ -use super::*; -use crate::crypto::{AesCtr, SecureRandom}; -use crate::stats::Stats; -use crate::stream::CryptoWriter; -use bytes::Bytes; -use std::sync::Arc; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::task::{Context, Poll}; -use tokio::io::AsyncWrite; -use tokio::task::JoinSet; - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -struct FailingWriter; - -impl AsyncWrite for FailingWriter { - fn poll_write( - self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - _buf: &[u8], - ) -> Poll> { - Poll::Ready(Err(std::io::Error::other("forced writer failure"))) - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } -} - -struct FailAfterBudgetWriter { - remaining: usize, - written: usize, -} - -impl FailAfterBudgetWriter { - fn new(remaining: usize) -> Self { - Self { - remaining, - written: 0, - } - } -} - -impl AsyncWrite for FailAfterBudgetWriter { - fn poll_write( - mut self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if self.remaining == 0 { - return Poll::Ready(Err(std::io::Error::other("forced short-write exhaustion"))); - } - - let n = self.remaining.min(buf.len()); - self.remaining -= n; - self.written += n; - Poll::Ready(Ok(n)) - } - - fn poll_flush( - self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown( - self: std::pin::Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } -} - -#[tokio::test] -async fn positive_exact_quota_boundary_allows_last_frame_and_blocks_next() { - let stats = Stats::new(); - let user = "quota-boundary-user"; - let bytes_me2c = AtomicU64::new(0); - - stats.add_user_octets_from(user, 5); - - let mut writer_one = make_crypto_writer(tokio::io::sink()); - let mut frame_buf_one = Vec::new(); - let first = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3]), - }, - &mut writer_one, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf_one, - &stats, - user, - Some(8), - 0, - &bytes_me2c, - 7101, - false, - false, - ) - .await; - - assert!(first.is_ok(), "frame that reaches boundary must be allowed"); - assert_eq!(stats.get_user_total_octets(user), 8); - - let mut writer_two = make_crypto_writer(tokio::io::sink()); - let mut frame_buf_two = Vec::new(); - let second = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[9]), - }, - &mut writer_two, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf_two, - &stats, - user, - Some(8), - 0, - &bytes_me2c, - 7102, - false, - false, - ) - .await; - - assert!( - matches!(second, Err(ProxyError::DataQuotaExceeded { .. })), - "frame after boundary must be rejected" - ); - assert_eq!(stats.get_user_total_octets(user), 8); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_parallel_reservation_stress_never_overshoots_quota_or_counters() { - let stats = Arc::new(Stats::new()); - let user = "reservation-stress-user"; - let quota_limit = 64u64; - let bytes_me2c = Arc::new(AtomicU64::new(0)); - let mut tasks = JoinSet::new(); - - for idx in 0..256u64 { - let user_owned = user.to_string(); - let stats_ref = Arc::clone(&stats); - let bytes_ref = Arc::clone(&bytes_me2c); - - tasks.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xAB]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats_ref.as_ref(), - &user_owned, - Some(quota_limit), - 0, - bytes_ref.as_ref(), - 7200 + idx, - false, - false, - ) - .await - }); - } - - let mut ok = 0usize; - let mut denied = 0usize; - while let Some(joined) = tasks.join_next().await { - match joined.expect("reservation stress task must not panic") { - Ok(_) => ok += 1, - Err(ProxyError::DataQuotaExceeded { .. }) => denied += 1, - Err(other) => panic!("unexpected error in stress case: {other:?}"), - } - } - - let total = stats.get_user_total_octets(user); - assert_eq!( - total, quota_limit, - "quota must be exactly exhausted without overshoot" - ); - assert_eq!( - bytes_me2c.load(Ordering::Relaxed), - total, - "ME->C forensic bytes must track committed quota usage" - ); - assert_eq!(ok, quota_limit as usize, "exactly quota_limit tasks must succeed"); - assert_eq!( - denied, - 256usize - (quota_limit as usize), - "remaining tasks must be exactly denied without silently swallowing state" - ); -} - -#[tokio::test] -async fn light_fuzz_random_frame_sizes_preserve_quota_and_counter_consistency() { - let stats = Stats::new(); - let user = "reservation-fuzz-user"; - let quota_limit = 128u64; - let bytes_me2c = AtomicU64::new(0); - let mut seed = 0xC0FE_EE11_8899_2211u64; - - for conn in 0..512u64 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - let len = ((seed & 0x0f) + 1) as usize; - let payload = vec![0x5A; len]; - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(quota_limit), - 0, - &bytes_me2c, - 7300 + conn, - false, - false, - ) - .await; - - if let Err(err) = result { - assert!( - matches!(err, ProxyError::DataQuotaExceeded { .. }), - "fuzz run produced unexpected error variant: {err:?}" - ); - } - } - - let total = stats.get_user_total_octets(user); - assert!(total <= quota_limit); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); -} - -#[tokio::test] -async fn positive_soft_overshoot_allows_burst_inside_soft_cap_then_blocks() { - let stats = Stats::new(); - let user = "soft-cap-boundary-user"; - let bytes_me2c = AtomicU64::new(0); - let quota_limit = 10u64; - let overshoot = 3u64; - - stats.add_user_octets_from(user, 10); - - let mut writer_one = make_crypto_writer(tokio::io::sink()); - let mut frame_buf_one = Vec::new(); - let first = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3]), - }, - &mut writer_one, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf_one, - &stats, - user, - Some(quota_limit), - overshoot, - &bytes_me2c, - 7401, - false, - false, - ) - .await; - assert!(first.is_ok(), "soft-cap buffer should allow reaching limit+overshoot"); - assert_eq!(stats.get_user_total_octets(user), 13); - - let mut writer_two = make_crypto_writer(tokio::io::sink()); - let mut frame_buf_two = Vec::new(); - let second = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[9]), - }, - &mut writer_two, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf_two, - &stats, - user, - Some(quota_limit), - overshoot, - &bytes_me2c, - 7402, - false, - false, - ) - .await; - assert!(matches!(second, Err(ProxyError::DataQuotaExceeded { .. }))); - assert_eq!(stats.get_user_total_octets(user), 13); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 3); -} - -#[tokio::test] -async fn negative_soft_overshoot_rejects_when_payload_exceeds_remaining_soft_budget() { - let stats = Stats::new(); - let user = "soft-cap-remaining-user"; - let bytes_me2c = AtomicU64::new(0); - let quota_limit = 10u64; - let overshoot = 4u64; - - stats.add_user_octets_from(user, 12); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(quota_limit), - overshoot, - &bytes_me2c, - 7501, - false, - false, - ) - .await; - - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - assert_eq!(stats.get_user_total_octets(user), 12); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); -} - -#[tokio::test] -async fn negative_write_failure_rolls_back_reservation_under_soft_cap_mode() { - let stats = Stats::new(); - let user = "soft-cap-rollback-user"; - let bytes_me2c = AtomicU64::new(0); - let mut writer = make_crypto_writer(FailingWriter); - let mut frame_buf = Vec::new(); - - stats.add_user_octets_from(user, 9); - - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(10), - 8, - &bytes_me2c, - 7601, - false, - false, - ) - .await; - - assert!(matches!(result, Err(ProxyError::Io(_)))); - assert_eq!(stats.get_user_total_octets(user), 9); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_parallel_soft_cap_stress_never_exceeds_soft_limit() { - let stats = Arc::new(Stats::new()); - let user = "soft-cap-stress-user"; - let quota_limit = 40u64; - let overshoot = 5u64; - let soft_limit = quota_limit + overshoot; - let bytes_me2c = Arc::new(AtomicU64::new(0)); - let mut tasks = JoinSet::new(); - - for idx in 0..256u64 { - let user_owned = user.to_string(); - let stats_ref = Arc::clone(&stats); - let bytes_ref = Arc::clone(&bytes_me2c); - tasks.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x42]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats_ref.as_ref(), - &user_owned, - Some(quota_limit), - overshoot, - bytes_ref.as_ref(), - 7700 + idx, - false, - false, - ) - .await - }); - } - - while let Some(joined) = tasks.join_next().await { - match joined.expect("soft-cap stress task must not panic") { - Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }) => {} - Err(other) => panic!("unexpected error in soft-cap stress case: {other:?}"), - } - } - - let total = stats.get_user_total_octets(user); - assert!(total <= soft_limit, "soft-cap stress must never overshoot soft limit"); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); -} - -#[tokio::test] -async fn light_fuzz_soft_cap_matrix_keeps_counters_and_limits_consistent() { - let stats = Stats::new(); - let user = "soft-cap-fuzz-user"; - let bytes_me2c = AtomicU64::new(0); - let mut seed = 0x9E37_79B9_7F4A_7C15u64; - - for conn in 0..1024u64 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let quota_limit = 32 + (seed & 0x3f); - let overshoot = seed.rotate_left(13) & 0x0f; - let len = ((seed >> 3) & 0x07) + 1; - let payload = vec![0xA5; len as usize]; - let before = stats.get_user_total_octets(user); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(quota_limit), - overshoot, - &bytes_me2c, - 7800 + conn, - false, - false, - ) - .await; - - if let Err(ref err) = result { - assert!( - matches!(err, ProxyError::DataQuotaExceeded { .. }), - "soft-cap fuzz produced unexpected error variant: {err:?}" - ); - } - - let after = stats.get_user_total_octets(user); - let soft_limit = quota_limit.saturating_add(overshoot); - match result { - Ok(_) => { - assert_eq!(after, before.saturating_add(len)); - assert!(after <= soft_limit, "accepted write must stay within active soft cap"); - } - Err(_) => { - assert_eq!(after, before, "rejected write must not mutate quota state"); - } - } - assert_eq!( - bytes_me2c.load(Ordering::Relaxed), - after, - "soft-cap fuzz must keep counters synchronized" - ); - } -} - -#[tokio::test] -async fn positive_no_quota_limit_accumulates_data_octets_exactly() { - let stats = Stats::new(); - let user = "no-quota-user"; - let bytes_me2c = AtomicU64::new(0); - let mut expected = 0u64; - - for (idx, len) in [1usize, 2, 3, 5, 8, 13, 21].iter().copied().enumerate() { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let payload = vec![0x41; len]; - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - None, - 0, - &bytes_me2c, - 7900 + idx as u64, - false, - false, - ) - .await; - - assert!(result.is_ok()); - expected += len as u64; - } - - assert_eq!(stats.get_user_total_octets(user), expected); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), expected); -} - -#[tokio::test] -async fn negative_zero_quota_rejects_non_empty_payload() { - let stats = Stats::new(); - let user = "zero-quota-user"; - let bytes_me2c = AtomicU64::new(0); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xAA]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(0), - 0, - &bytes_me2c, - 8001, - false, - false, - ) - .await; - - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - assert_eq!(stats.get_user_total_octets(user), 0); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); -} - -#[tokio::test] -async fn edge_zero_length_payload_with_zero_quota_is_fail_closed() { - let stats = Stats::new(); - let user = "zero-len-zero-quota-user"; - let bytes_me2c = AtomicU64::new(0); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::new(), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(0), - 0, - &bytes_me2c, - 8002, - false, - false, - ) - .await; - - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - assert_eq!(stats.get_user_total_octets(user), 0); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); -} - -#[tokio::test] -async fn positive_ack_response_does_not_touch_quota_counters() { - let stats = Stats::new(); - let user = "ack-accounting-user"; - let bytes_me2c = AtomicU64::new(11); - stats.add_user_octets_to(user, 23); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - MeResponse::Ack(0x33445566), - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(24), - 0, - &bytes_me2c, - 8003, - true, - true, - ) - .await; - - assert!(result.is_ok()); - assert_eq!(stats.get_user_total_octets(user), 23); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 11); -} - -#[tokio::test] -async fn edge_close_response_is_accounting_noop() { - let stats = Stats::new(); - let user = "close-accounting-user"; - let bytes_me2c = AtomicU64::new(19); - stats.add_user_octets_to(user, 31); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - MeResponse::Close, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(40), - 3, - &bytes_me2c, - 8004, - false, - true, - ) - .await; - - assert!(result.is_ok()); - assert_eq!(stats.get_user_total_octets(user), 31); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 19); -} - -#[tokio::test] -async fn negative_preloaded_above_soft_cap_rejects_even_single_byte() { - let stats = Stats::new(); - let user = "preloaded-over-soft-cap-user"; - let bytes_me2c = AtomicU64::new(0); - let quota_limit = 20u64; - let overshoot = 2u64; - stats.add_user_octets_to(user, quota_limit + overshoot + 1); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(quota_limit), - overshoot, - &bytes_me2c, - 8005, - false, - false, - ) - .await; - - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); - assert_eq!(stats.get_user_total_octets(user), quota_limit + overshoot + 1); -} - -#[tokio::test] -async fn adversarial_fail_writer_path_never_desynchronizes_quota_accounting() { - let stats = Stats::new(); - let user = "partial-write-rollback-user"; - let bytes_me2c = AtomicU64::new(0); - let mut writer = make_crypto_writer(FailAfterBudgetWriter::new(7)); - let mut frame_buf = Vec::new(); - let payload_len = 16 * 1024u64; - - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![0x42; 16 * 1024]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(payload_len), - 0, - &bytes_me2c, - 8006, - false, - false, - ) - .await; - - let total_after = stats.get_user_total_octets(user); - let forensic_after = bytes_me2c.load(Ordering::Relaxed); - assert_eq!(forensic_after, total_after); - assert!( - total_after == 0 || total_after == payload_len, - "writer failure path must either roll back fully or commit exactly one payload" - ); - - // Regardless of whether I/O failure surfaced immediately or was deferred, - // accounting must remain fail-closed and prevent silent overshoot. - let mut writer_two = make_crypto_writer(tokio::io::sink()); - let mut frame_buf_two = Vec::new(); - let second = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x99]), - }, - &mut writer_two, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf_two, - &stats, - user, - Some(payload_len), - 0, - &bytes_me2c, - 8007, - false, - false, - ) - .await; - - if total_after == payload_len { - assert!(matches!(second, Err(ProxyError::DataQuotaExceeded { .. }))); - } else { - assert!(second.is_ok()); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_parallel_oversized_frames_fail_closed_without_counter_leak() { - let stats = Arc::new(Stats::new()); - let user = "parallel-fail-rollback-user"; - let bytes_me2c = Arc::new(AtomicU64::new(0)); - let mut tasks = JoinSet::new(); - - for idx in 0..256u64 { - let user_owned = user.to_string(); - let stats_ref = Arc::clone(&stats); - let bytes_ref = Arc::clone(&bytes_me2c); - tasks.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![0xEE; 12 * 1024]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats_ref.as_ref(), - &user_owned, - Some(512), - 0, - bytes_ref.as_ref(), - 8100 + idx, - false, - false, - ) - .await - }); - } - - while let Some(joined) = tasks.join_next().await { - let result = joined.expect("parallel fail writer task must not panic"); - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - } - - assert_eq!(stats.get_user_total_octets(user), 0); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); -} - -#[tokio::test] -async fn integration_mixed_data_ack_close_sequence_preserves_data_only_accounting() { - let stats = Stats::new(); - let user = "mixed-sequence-user"; - let bytes_me2c = AtomicU64::new(0); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - - let data_one = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(32), - 0, - &bytes_me2c, - 8201, - false, - false, - ) - .await; - assert!(data_one.is_ok()); - - let ack = process_me_writer_response( - MeResponse::Ack(0x0102_0304), - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(32), - 0, - &bytes_me2c, - 8202, - true, - true, - ) - .await; - assert!(ack.is_ok()); - - let data_two = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[4, 5]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(32), - 0, - &bytes_me2c, - 8203, - false, - true, - ) - .await; - assert!(data_two.is_ok()); - - let close = process_me_writer_response( - MeResponse::Close, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(32), - 0, - &bytes_me2c, - 8204, - false, - true, - ) - .await; - assert!(close.is_ok()); - - assert_eq!(stats.get_user_total_octets(user), 5); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 5); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_parallel_multi_user_quota_isolation_no_cross_user_leakage() { - let stats = Arc::new(Stats::new()); - let user_a = "quota-isolation-a"; - let user_b = "quota-isolation-b"; - let limit_a = 50u64; - let limit_b = 80u64; - let bytes_a = Arc::new(AtomicU64::new(0)); - let bytes_b = Arc::new(AtomicU64::new(0)); - - let mut tasks = JoinSet::new(); - for idx in 0..200u64 { - let stats_ref = Arc::clone(&stats); - let bytes_ref = Arc::clone(&bytes_a); - tasks.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xA1]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats_ref.as_ref(), - user_a, - Some(limit_a), - 0, - bytes_ref.as_ref(), - 8300 + idx, - false, - false, - ) - .await - }); - } - - for idx in 0..220u64 { - let stats_ref = Arc::clone(&stats); - let bytes_ref = Arc::clone(&bytes_b); - tasks.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xB2]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats_ref.as_ref(), - user_b, - Some(limit_b), - 0, - bytes_ref.as_ref(), - 8500 + idx, - false, - false, - ) - .await - }); - } - - while let Some(joined) = tasks.join_next().await { - let result = joined.expect("quota isolation task must not panic"); - assert!(result.is_ok() || matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - } - - assert_eq!(stats.get_user_total_octets(user_a), limit_a); - assert_eq!(stats.get_user_total_octets(user_b), limit_b); - assert_eq!(bytes_a.load(Ordering::Relaxed), limit_a); - assert_eq!(bytes_b.load(Ordering::Relaxed), limit_b); -} - -#[tokio::test] -async fn light_fuzz_mixed_me_responses_preserve_quota_and_counter_invariants() { - let stats = Stats::new(); - let user = "mixed-fuzz-user"; - let bytes_me2c = AtomicU64::new(0); - let quota_limit = 96u64; - let mut seed = 0xDEAD_BEEF_2026_0323u64; - - for idx in 0..2048u64 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let choice = (seed & 0x03) as u8; - let response = if choice == 0 { - MeResponse::Ack((seed >> 8) as u32) - } else if choice == 1 { - MeResponse::Close - } else { - let len = ((seed >> 16) & 0x07) as usize; - let mut payload = vec![0u8; len]; - payload.fill((seed & 0xff) as u8); - MeResponse::Data { - flags: 0, - data: Bytes::from(payload), - } - }; - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let result = process_me_writer_response( - response, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - user, - Some(quota_limit), - 0, - &bytes_me2c, - 8800 + idx, - (idx & 1) == 0, - (idx & 2) == 0, - ) - .await; - - if let Err(err) = result { - assert!( - matches!(err, ProxyError::DataQuotaExceeded { .. }), - "mixed fuzz produced unexpected error variant: {err:?}" - ); - } - - let total = stats.get_user_total_octets(user); - assert!( - total <= quota_limit, - "mixed fuzz must keep usage at or below quota limit" - ); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); - } -} \ No newline at end of file diff --git a/src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs b/src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs deleted file mode 100644 index e4d0c6e..0000000 --- a/src/proxy/tests/middle_relay_quota_reservation_extreme_security_tests.rs +++ /dev/null @@ -1,399 +0,0 @@ -use super::*; -use crate::crypto::{AesCtr, SecureRandom}; -use crate::stats::Stats; -use crate::stream::CryptoWriter; -use bytes::Bytes; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, Mutex, OnceLock}; -use tokio::sync::Mutex as AsyncMutex; -use tokio::task::JoinSet; -use tokio::time::{Duration, timeout}; - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -fn lookup_counter_test_lock() -> &'static Mutex<()> { - static LOCK: OnceLock> = OnceLock::new(); - LOCK.get_or_init(|| Mutex::new(())) -} - -#[tokio::test] -async fn positive_prefetched_cross_mode_lock_multi_frame_accounting_is_exact() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Stats::new(); - let user = format!("quota-extreme-positive-{}", std::process::id()); - let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - for idx in 0..12u64 { - let payload = vec![0x5A; ((idx % 4) + 1) as usize]; - let result = process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(512), - 0, - Some(&lock), - &bytes_me2c, - 31_000 + idx, - false, - false, - ) - .await; - - assert!(result.is_ok()); - } - - assert_eq!( - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), - 0, - "prefetched lock path must avoid hot-path registry lookups" - ); - assert_eq!( - stats.get_user_total_octets(&user), - bytes_me2c.load(Ordering::Relaxed), - "forensics and quota accounting must remain synchronized" - ); -} - -#[tokio::test] -async fn negative_held_prefetched_lock_blocks_writer_without_accounting_mutation() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Stats::new(); - let user = format!("quota-extreme-negative-{}", std::process::id()); - let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold lock before calling ME->C writer"); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let blocked = timeout( - Duration::from_millis(25), - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[1, 2, 3]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(64), - 0, - Some(&lock), - &bytes_me2c, - 31_100, - false, - false, - ), - ) - .await; - - assert!(blocked.is_err()); - assert_eq!(stats.get_user_total_octets(&user), 0); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); - - drop(held_guard); -} - -#[tokio::test] -async fn edge_zero_quota_and_zero_payload_is_fail_closed() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Stats::new(); - let user = format!("quota-extreme-edge-{}", std::process::id()); - let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - let result = process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::new(), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(0), - 0, - Some(&lock), - &bytes_me2c, - 31_200, - false, - false, - ) - .await; - - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - assert_eq!(stats.get_user_total_octets(&user), 0); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_blackhat_parallel_quota_race_never_overshoots_soft_cap() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Arc::new(Stats::new()); - let user = format!("quota-extreme-blackhat-{}", std::process::id()); - let quota = 80u64; - let overshoot = 7u64; - let soft_limit = quota + overshoot; - let lock = Arc::new(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user)); - let bytes_me2c = Arc::new(AtomicU64::new(0)); - - let mut set = JoinSet::new(); - for idx in 0..256u64 { - let stats = Arc::clone(&stats); - let user = user.clone(); - let lock = Arc::clone(&lock); - let bytes_me2c = Arc::clone(&bytes_me2c); - - set.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let len = ((idx % 5) + 1) as usize; - let payload = vec![0xAA; len]; - - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats.as_ref(), - &user, - Some(quota), - overshoot, - Some(&lock), - bytes_me2c.as_ref(), - 31_300 + idx, - false, - false, - ) - .await - }); - } - - while let Some(done) = set.join_next().await { - match done.expect("task must not panic") { - Ok(_) | Err(ProxyError::DataQuotaExceeded { .. }) => {} - Err(other) => panic!("unexpected error variant under black-hat race: {other:?}"), - } - } - - let total = stats.get_user_total_octets(&user); - assert!( - total <= soft_limit, - "parallel adversarial race must stay under soft cap" - ); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), total); -} - -#[tokio::test] -async fn integration_without_prefetched_lock_uses_registry_lookup_path() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Stats::new(); - let user = format!("quota-extreme-integration-{}", std::process::id()); - crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let bytes_me2c = AtomicU64::new(0); - - for idx in 0..3u64 { - let result = process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x41]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(16), - 0, - None, - &bytes_me2c, - 31_400 + idx, - false, - false, - ) - .await; - - assert!(result.is_ok()); - } - - assert_eq!( - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), - 3, - "control path should perform one lock-registry lookup per call" - ); -} - -#[tokio::test] -async fn light_fuzz_quota_matrix_preserves_fail_closed_accounting() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Stats::new(); - let user = format!("quota-extreme-fuzz-{}", std::process::id()); - let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let bytes_me2c = AtomicU64::new(0); - let mut seed = 0xA11C_55EE_2026_0323u64; - - for idx in 0..512u64 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let quota = 24 + (seed & 0x3f); - let overshoot = (seed >> 13) & 0x0f; - let len = ((seed >> 19) & 0x07) + 1; - - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - let before = stats.get_user_total_octets(&user); - - let result = process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![0x11; len as usize]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - &stats, - &user, - Some(quota), - overshoot, - Some(&lock), - &bytes_me2c, - 31_500 + idx, - false, - false, - ) - .await; - - let after = stats.get_user_total_octets(&user); - if result.is_ok() { - assert!(after >= before); - } else { - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); - assert_eq!(after, before); - } - assert_eq!(bytes_me2c.load(Ordering::Relaxed), after); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_prefetched_lock_high_fanout_exact_quota_success_count() { - let _guard = lookup_counter_test_lock() - .lock() - .unwrap_or_else(|poison| poison.into_inner()); - - let stats = Arc::new(Stats::new()); - let user = format!("quota-extreme-stress-{}", std::process::id()); - let quota = 96u64; - let lock: Arc> = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let bytes_me2c = Arc::new(AtomicU64::new(0)); - - crate::proxy::quota_lock_registry::reset_cross_mode_quota_user_lock_lookup_count_for_tests(); - - let mut set = JoinSet::new(); - for idx in 0..384u64 { - let stats = Arc::clone(&stats); - let user = user.clone(); - let lock = Arc::clone(&lock); - let bytes_me2c = Arc::clone(&bytes_me2c); - - set.spawn(async move { - let mut writer = make_crypto_writer(tokio::io::sink()); - let mut frame_buf = Vec::new(); - process_me_writer_response_with_cross_mode_lock( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0xFF]), - }, - &mut writer, - ProtoTag::Intermediate, - &SecureRandom::new(), - &mut frame_buf, - stats.as_ref(), - &user, - Some(quota), - 0, - Some(&lock), - bytes_me2c.as_ref(), - 31_600 + idx, - false, - false, - ) - .await - }); - } - - let mut success = 0usize; - while let Some(done) = set.join_next().await { - match done.expect("task must not panic") { - Ok(_) => success += 1, - Err(ProxyError::DataQuotaExceeded { .. }) => {} - Err(other) => panic!("unexpected error variant in stress fanout: {other:?}"), - } - } - - assert_eq!(success, quota as usize); - assert_eq!(stats.get_user_total_octets(&user), quota); - assert_eq!(bytes_me2c.load(Ordering::Relaxed), quota); - assert_eq!( - crate::proxy::quota_lock_registry::cross_mode_quota_user_lock_lookup_count_for_user_for_tests(&user), - 0, - "stress prefetched path must not use lock registry lookups" - ); -} diff --git a/src/proxy/tests/middle_relay_security_tests.rs b/src/proxy/tests/middle_relay_security_tests.rs deleted file mode 100644 index 1d3b736..0000000 --- a/src/proxy/tests/middle_relay_security_tests.rs +++ /dev/null @@ -1,2517 +0,0 @@ -use super::*; -use crate::config::{GeneralConfig, MeRouteNoWriterMode, MeSocksKdfPolicy, MeWriterPickMode}; -use crate::crypto::AesCtr; -use crate::crypto::SecureRandom; -use crate::network::probe::NetworkDecision; -use crate::proxy::handshake::HandshakeSuccess; -use crate::proxy::route_mode::{RelayRouteMode, RouteRuntimeController}; -use crate::stats::Stats; -use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer}; -use crate::transport::middle_proxy::MePool; -use bytes::Bytes; -use rand::rngs::StdRng; -use rand::{RngExt, SeedableRng}; -use std::collections::{HashMap, HashSet}; -use std::net::SocketAddr; -use std::sync::Arc; -use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; -use std::sync::Mutex; -use std::thread; -use tokio::io::AsyncReadExt; -use tokio::io::AsyncWriteExt; -use tokio::io::duplex; -use tokio::sync::Barrier; -use tokio::time::{Duration as TokioDuration, timeout}; - -fn make_pooled_payload(data: &[u8]) -> PooledBuffer { - let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4)); - let mut payload = pool.get(); - payload.resize(data.len(), 0); - payload[..data.len()].copy_from_slice(data); - payload -} - -fn make_pooled_payload_from(pool: &Arc, data: &[u8]) -> PooledBuffer { - let mut payload = pool.get(); - payload.resize(data.len(), 0); - payload[..data.len()].copy_from_slice(data); - payload -} - -#[test] -fn should_yield_sender_only_on_budget_with_backlog() { - assert!(!should_yield_c2me_sender(0, true)); - assert!(!should_yield_c2me_sender( - C2ME_SENDER_FAIRNESS_BUDGET - 1, - true - )); - assert!(!should_yield_c2me_sender( - C2ME_SENDER_FAIRNESS_BUDGET, - false - )); - assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true)); -} - -#[tokio::test] -async fn enqueue_c2me_command_uses_try_send_fast_path() { - let (tx, mut rx) = mpsc::channel::(2); - enqueue_c2me_command( - &tx, - C2MeCommand::Data { - payload: make_pooled_payload(&[1, 2, 3]), - flags: 0, - }, - ) - .await - .unwrap(); - - let recv = timeout(TokioDuration::from_millis(50), rx.recv()) - .await - .unwrap() - .unwrap(); - match recv { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[1, 2, 3]); - assert_eq!(flags, 0); - } - C2MeCommand::Close => panic!("unexpected close command"), - } -} - -#[tokio::test] -async fn enqueue_c2me_command_falls_back_to_send_when_queue_is_full() { - let (tx, mut rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[9]), - flags: 9, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let producer = tokio::spawn(async move { - enqueue_c2me_command( - &tx2, - C2MeCommand::Data { - payload: make_pooled_payload(&[7, 7]), - flags: 7, - }, - ) - .await - .unwrap(); - }); - - let _ = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap(); - producer.await.unwrap(); - - let recv = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .unwrap(); - match recv { - C2MeCommand::Data { payload, flags } => { - assert_eq!(payload.as_ref(), &[7, 7]); - assert_eq!(flags, 7); - } - C2MeCommand::Close => panic!("unexpected close command"), - } -} - -#[tokio::test] -async fn enqueue_c2me_command_closed_channel_recycles_payload() { - let pool = Arc::new(BufferPool::with_config(64, 4)); - let payload = make_pooled_payload_from(&pool, &[1, 2, 3, 4]); - let (tx, rx) = mpsc::channel::(1); - drop(rx); - - let result = enqueue_c2me_command(&tx, C2MeCommand::Data { payload, flags: 0 }).await; - - assert!(result.is_err(), "closed queue must fail enqueue"); - drop(result); - assert!( - pool.stats().pooled >= 1, - "payload must return to pool when enqueue fails on closed channel" - ); -} - -#[tokio::test] -async fn enqueue_c2me_command_full_then_closed_recycles_waiting_payload() { - let pool = Arc::new(BufferPool::with_config(64, 4)); - let (tx, rx) = mpsc::channel::(1); - - tx.send(C2MeCommand::Data { - payload: make_pooled_payload_from(&pool, &[9]), - flags: 1, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let pool2 = pool.clone(); - let blocked_send = tokio::spawn(async move { - enqueue_c2me_command( - &tx2, - C2MeCommand::Data { - payload: make_pooled_payload_from(&pool2, &[7, 7, 7]), - flags: 2, - }, - ) - .await - }); - - tokio::time::sleep(TokioDuration::from_millis(10)).await; - drop(rx); - - let result = timeout(TokioDuration::from_secs(1), blocked_send) - .await - .expect("blocked send task must finish") - .expect("blocked send task must not panic"); - - assert!( - result.is_err(), - "closing receiver while sender is blocked must fail enqueue" - ); - drop(result); - assert!( - pool.stats().pooled >= 2, - "both queued and blocked payloads must return to pool after channel close" - ); -} - -#[tokio::test] -async fn enqueue_c2me_command_full_queue_times_out_without_receiver_progress() { - let (tx, _rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[1]), - flags: 0, - }) - .await - .unwrap(); - - let started = Instant::now(); - let result = enqueue_c2me_command( - &tx, - C2MeCommand::Data { - payload: make_pooled_payload(&[2, 2]), - flags: 1, - }, - ) - .await; - - assert!( - result.is_err(), - "enqueue must fail when queue stays full beyond bounded timeout" - ); - assert!( - started.elapsed() < TokioDuration::from_millis(400), - "full-queue timeout must resolve promptly" - ); -} - -#[test] -fn desync_dedup_cache_is_bounded() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - assert!( - should_emit_full_desync(key, false, now), - "unique keys up to cap must be tracked" - ); - } - - assert!( - should_emit_full_desync(u64::MAX, false, now), - "new key above cap must emit once after bounded eviction for forensic visibility" - ); - - assert!( - !should_emit_full_desync(u64::MAX, false, now), - "already tracked key inside dedup window must stay suppressed" - ); -} - -#[test] -fn quota_user_lock_cache_reuses_entry_for_same_user() { - let _guard = super::quota_user_lock_test_scope(); - - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let a = quota_user_lock("quota-user-a"); - let b = quota_user_lock("quota-user-a"); - assert!(Arc::ptr_eq(&a, &b), "same user must reuse same quota lock"); -} - -#[test] -fn quota_user_lock_cache_is_bounded_under_unique_churn() { - let _guard = super::quota_user_lock_test_scope(); - - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - for idx in 0..(QUOTA_USER_LOCKS_MAX + 128) { - let user = format!("quota-user-{idx}"); - let lock = quota_user_lock(&user); - drop(lock); - } - - assert!( - map.len() <= QUOTA_USER_LOCKS_MAX, - "quota lock cache must stay within configured bound" - ); -} - -#[test] -fn quota_user_lock_cache_saturation_returns_stable_overflow_lock_without_growth() { - let _guard = super::quota_user_lock_test_scope(); - - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - for attempt in 0..8u32 { - map.clear(); - - let prefix = format!("quota-held-user-{}-{attempt}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - let user = format!("{prefix}-{idx}"); - retained.push(quota_user_lock(&user)); - } - - if map.len() != QUOTA_USER_LOCKS_MAX { - drop(retained); - continue; - } - - let overflow_user = format!("quota-overflow-user-{}-{attempt}", std::process::id()); - let overflow_a = quota_user_lock(&overflow_user); - let overflow_b = quota_user_lock(&overflow_user); - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "overflow acquisition must not grow cache past hard limit" - ); - assert!( - map.get(&overflow_user).is_none(), - "overflow path should not cache new user lock when map is saturated and all entries are retained" - ); - assert!( - Arc::ptr_eq(&overflow_a, &overflow_b), - "overflow user lock should use deterministic striping under saturation" - ); - - drop(retained); - return; - } - - panic!("unable to observe stable saturated lock-cache precondition after bounded retries"); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_quota_race_under_lock_cache_saturation_still_allows_only_one_winner() { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - let user = format!("quota-saturated-user-{idx}"); - retained.push(quota_user_lock(&user)); - } - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "precondition: cache must be saturated for overflow-user race test" - ); - - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - let user = "gap-t04-saturated-lock-race-user"; - let barrier = Arc::new(Barrier::new(2)); - - let one = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x55, 9101, barrier.clone()); - let two = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x66, 9102, barrier); - let (r1, r2) = tokio::join!(one, two); - - assert!( - matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) - && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "both racers must resolve cleanly without unexpected errors" - ); - assert!( - matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) - || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })), - "at least one racer must be quota-rejected even when lock cache is saturated" - ); - assert_eq!( - stats.get_user_total_octets(user), - 1, - "saturated lock cache must not permit double-success quota overshoot" - ); - - drop(retained); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_quota_race_under_lock_cache_saturation_never_allows_double_success() { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - let user = format!("quota-saturated-stress-holder-{idx}"); - retained.push(quota_user_lock(&user)); - } - - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - for round in 0..128u64 { - let user = format!("gap-t04-saturated-race-round-{round}"); - let barrier = Arc::new(Barrier::new(2)); - - let one = run_quota_race_attempt( - &stats, - &bytes_me2c, - &user, - 0x71, - 12_000 + round, - barrier.clone(), - ); - let two = run_quota_race_attempt(&stats, &bytes_me2c, &user, 0x72, 13_000 + round, barrier); - - let (r1, r2) = tokio::join!(one, two); - assert!( - matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) - && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "round {round}: racers must resolve cleanly" - ); - assert!( - matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) - || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })), - "round {round}: at least one racer must be quota-rejected" - ); - assert_eq!( - stats.get_user_total_octets(&user), - 1, - "round {round}: saturated cache must still enforce exactly one forwarded byte" - ); - } - - drop(retained); -} - -#[test] -fn adversarial_forensics_trace_id_should_not_alias_conn_id() { - let now = Instant::now(); - let trace_id = 0x1122_3344_5566_7788; - let conn_id = 0x8877_6655_4433_2211; - let state = RelayForensicsState { - trace_id, - conn_id, - user: "trace-user".to_string(), - peer: "198.51.100.17:443".parse().unwrap(), - peer_hash: 0x8877_6655_4433_2211, - started_at: now, - bytes_c2me: 0, - bytes_me2c: Arc::new(AtomicU64::new(0)), - desync_all_full: false, - }; - - assert_ne!( - state.trace_id, state.conn_id, - "security expectation: trace correlation should be independent of connection identity" - ); - assert_eq!(state.trace_id, trace_id); - assert_eq!(state.conn_id, conn_id); -} - -#[tokio::test] -async fn abridged_ack_uses_big_endian_confirm_bytes_after_decryption() { - let (mut writer_side, reader_side) = duplex(8); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(reader_side, AesCtr::new(&key, iv), 8 * 1024); - - write_client_ack(&mut writer, ProtoTag::Abridged, 0x11_22_33_44) - .await - .expect("ack write must succeed"); - - let mut observed = [0u8; 4]; - writer_side - .read_exact(&mut observed) - .await - .expect("ack bytes must be readable"); - let mut decryptor = AesCtr::new(&key, iv); - let decrypted = decryptor.decrypt(&observed); - - assert_eq!( - decrypted, - 0x11_22_33_44u32.to_be_bytes(), - "abridged ACK should encode confirm bytes in big-endian order" - ); -} - -#[test] -fn desync_dedup_full_cache_churn_stays_suppressed() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - assert!(should_emit_full_desync(key, false, now)); - } - - for offset in 0..2048u64 { - let emitted = should_emit_full_desync(u64::MAX - offset, false, now); - if offset == 0 { - assert!( - emitted, - "first full-cache newcomer should emit for forensic visibility" - ); - } else { - assert!( - !emitted, - "full-cache newcomer churn inside emit interval must stay suppressed" - ); - } - } -} - -#[test] -fn dedup_hash_is_stable_for_same_input_within_process() { - let sample = ( - "scope_user", - hash_ip("198.51.100.7".parse().unwrap()), - ProtoTag::Secure, - ); - let first = hash_value(&sample); - let second = hash_value(&sample); - assert_eq!( - first, second, - "dedup hash must be stable within a process for cache lookups" - ); -} - -#[test] -fn dedup_hash_resists_simple_collision_bursts_for_peer_ip_space() { - let mut seen = HashSet::new(); - - for octet in 1u16..=2048 { - let third = ((octet / 256) & 0xff) as u8; - let fourth = (octet & 0xff) as u8; - let ip = IpAddr::V4(std::net::Ipv4Addr::new(198, 51, third, fourth)); - let key = hash_value(&( - "scope_user", - hash_ip(ip), - ProtoTag::Secure, - DESYNC_ERROR_CLASS, - )); - seen.insert(key); - } - - assert_eq!( - seen.len(), - 2048, - "adversarial peer-IP burst should not collapse dedup keys via trivial collisions" - ); -} - -#[test] -fn light_fuzz_dedup_hash_collision_rate_stays_negligible() { - let mut rng = StdRng::seed_from_u64(0x9E37_79B9_A1B2_C3D4); - let mut seen = HashSet::new(); - let samples = 8192usize; - - for _ in 0..samples { - let user_seed: u64 = rng.random(); - let peer_seed: u64 = rng.random(); - let proto = if (peer_seed & 1) == 0 { - ProtoTag::Secure - } else { - ProtoTag::Intermediate - }; - let key = hash_value(&(user_seed, peer_seed, proto, DESYNC_ERROR_CLASS)); - seen.insert(key); - } - - let collisions = samples - seen.len(); - assert!( - collisions <= 1, - "light fuzz collision count should remain negligible for 64-bit dedup keys" - ); -} - -#[test] -fn stress_desync_dedup_churn_keeps_cache_hard_bounded() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let now = Instant::now(); - let total = DESYNC_DEDUP_MAX_ENTRIES + 8192; - - let mut emitted_count = 0usize; - for key in 0..total as u64 { - let emitted = should_emit_full_desync(key, false, now); - if emitted { - emitted_count += 1; - } - } - - assert_eq!( - emitted_count, - DESYNC_DEDUP_MAX_ENTRIES + 1, - "after capacity is reached, same-tick newcomer churn must be rate-limited" - ); - - let len = DESYNC_DEDUP - .get() - .expect("dedup cache must be initialized by stress run") - .len(); - assert!( - len <= DESYNC_DEDUP_MAX_ENTRIES, - "dedup cache must stay bounded under stress churn" - ); -} - -#[test] -fn full_cache_newcomer_emission_is_rate_limited_but_periodic() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - // Same-tick newcomer storm: only the first should emit full forensic record. - let mut burst_emits = 0usize; - for i in 0..1024u64 { - if should_emit_full_desync(10_000_000 + i, false, base_now) { - burst_emits += 1; - } - } - assert_eq!( - burst_emits, 1, - "full-cache newcomer burst must be bounded to a single full emit per interval" - ); - - // After each interval elapses, one newcomer may emit again. - for step in 1..=6u64 { - let t = base_now + DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL * step as u32; - assert!( - should_emit_full_desync(20_000_000 + step, false, t), - "full-cache newcomer should re-emit once interval has elapsed" - ); - assert!( - !should_emit_full_desync(30_000_000 + step, false, t), - "additional newcomers in the same interval tick must remain suppressed" - ); - } -} - -#[test] -fn full_cache_mode_override_emits_every_event() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let now = Instant::now(); - for i in 0..10_000u64 { - assert!( - should_emit_full_desync(100_000_000 + i, true, now), - "desync_all_full override must bypass dedup and rate-limit suppression" - ); - } -} - -#[test] -fn report_desync_stats_follow_rate_limited_full_cache_policy() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - let stats = Stats::new(); - let mut state = make_forensics_state(); - state.started_at = base_now; - - for i in 0..128u64 { - state.peer_hash = 0xABC0_0000_0000_0000u64 ^ i; - let _ = report_desync_frame_too_large( - &state, - ProtoTag::Secure, - 3, - 1024, - 4096, - Some([0x16, 0x03, 0x03, 0x00]), - &stats, - ); - } - - assert_eq!( - stats.get_desync_total(), - 128, - "every detected desync must increment total counter" - ); - assert_eq!( - stats.get_desync_full_logged(), - 1, - "same-interval full-cache newcomer storm must allow only one full forensic emit" - ); - assert_eq!( - stats.get_desync_suppressed(), - 127, - "remaining same-interval full-cache newcomer events must be suppressed" - ); - - // After one full interval in real wall clock, a newcomer should emit again. - thread::sleep(DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL + TokioDuration::from_millis(20)); - state.peer_hash = 0xDEAD_BEEF_DEAD_BEEFu64; - let _ = report_desync_frame_too_large( - &state, - ProtoTag::Secure, - 4, - 1024, - 4097, - Some([0x16, 0x03, 0x03, 0x01]), - &stats, - ); - - assert_eq!( - stats.get_desync_full_logged(), - 2, - "full forensic emission must recover after rate-limit interval" - ); -} - -#[test] -fn concurrent_full_cache_newcomer_storm_is_single_emit_per_interval() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - let emits = Arc::new(AtomicUsize::new(0)); - let mut workers = Vec::new(); - for worker_id in 0..32u64 { - let emits = Arc::clone(&emits); - workers.push(thread::spawn(move || { - for i in 0..512u64 { - let key = 0x7000_0000_0000_0000u64 ^ (worker_id << 20) ^ i; - if should_emit_full_desync(key, false, base_now) { - emits.fetch_add(1, Ordering::Relaxed); - } - } - })); - } - - for worker in workers { - worker.join().expect("worker thread must not panic"); - } - - assert_eq!( - emits.load(Ordering::Relaxed), - 1, - "concurrent same-interval full-cache storm must allow only one full forensic emit" - ); -} - -#[test] -fn light_fuzz_full_cache_rate_limit_oracle_matches_model() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - let mut rng = StdRng::seed_from_u64(0xD15EA5E5_F00DBAAD); - let mut model_last_emit: Option = None; - - for i in 0..4096u64 { - let jitter_ms: u64 = rng.random_range(0..=3000); - let t = base_now + TokioDuration::from_millis(jitter_ms); - let key = 0x55AA_0000_0000_0000u64 ^ i ^ rng.random::(); - let actual = should_emit_full_desync(key, false, t); - - let expected = match model_last_emit { - None => { - model_last_emit = Some(t); - true - } - Some(last) => { - match t.checked_duration_since(last) { - Some(elapsed) if elapsed >= DESYNC_FULL_CACHE_EMIT_MIN_INTERVAL => { - model_last_emit = Some(t); - true - } - Some(_) => false, - None => { - // Match production fail-open behavior for non-monotonic synthetic input. - model_last_emit = Some(t); - true - } - } - } - }; - - assert_eq!( - actual, expected, - "full-cache rate-limit gate diverged from reference model under light fuzz" - ); - } -} - -#[test] -fn full_cache_gate_lock_poison_is_fail_closed_without_panic() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - // Poison the full-cache gate lock intentionally. - let gate = DESYNC_FULL_CACHE_LAST_EMIT_AT.get_or_init(|| Mutex::new(None)); - let _ = std::panic::catch_unwind(|| { - let _lock = gate - .lock() - .expect("gate lock must be lockable before poison"); - panic!("intentional gate poison for fail-closed regression"); - }); - - let emitted = should_emit_full_desync(0xFACE_0000_0000_0001, false, base_now); - assert!( - !emitted, - "poisoned full-cache gate must fail-closed (suppress) instead of panic or fail-open" - ); - assert!( - dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES, - "dedup cache must remain bounded even when gate lock is poisoned" - ); -} - -#[test] -fn full_cache_non_monotonic_time_emits_and_resets_gate_safely() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - // First event seeds the gate. - assert!(should_emit_full_desync( - 0xABCD_0000_0000_0001, - false, - base_now + TokioDuration::from_millis(900) - )); - - // Synthetic earlier timestamp must not panic; it should fail-open and reset gate. - assert!(should_emit_full_desync( - 0xABCD_0000_0000_0002, - false, - base_now + TokioDuration::from_millis(100) - )); - - // Same instant again remains suppressed after reset. - assert!(!should_emit_full_desync( - 0xABCD_0000_0000_0003, - false, - base_now + TokioDuration::from_millis(100) - )); -} - -#[test] -fn desync_dedup_full_cache_inserts_new_key_with_bounded_single_key_churn() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let dedup = DESYNC_DEDUP.get_or_init(DashMap::new); - let base_now = Instant::now(); - - // Fill with fresh entries so stale-pruning does not apply. - for key in 0..DESYNC_DEDUP_MAX_ENTRIES as u64 { - dedup.insert(key, base_now - TokioDuration::from_millis(10)); - } - - let before_keys: std::collections::HashSet = dedup.iter().map(|e| *e.key()).collect(); - - let newcomer_key = u64::MAX; - let emitted = should_emit_full_desync(newcomer_key, false, base_now); - assert!( - emitted, - "new entry under full fresh cache must emit after bounded eviction" - ); - assert!( - dedup.get(&newcomer_key).is_some(), - "new key must be inserted after bounded eviction" - ); - - let after_keys: std::collections::HashSet = dedup.iter().map(|e| *e.key()).collect(); - let removed_count = before_keys.difference(&after_keys).count(); - let added_count = after_keys.difference(&before_keys).count(); - - assert_eq!( - removed_count, 1, - "full-cache insertion must evict exactly one prior key" - ); - assert_eq!( - added_count, 1, - "full-cache insertion must add exactly one newcomer key" - ); - assert!( - dedup.len() <= DESYNC_DEDUP_MAX_ENTRIES, - "dedup cache must remain hard-bounded after full-cache churn" - ); -} - -#[test] -fn light_fuzz_desync_dedup_temporal_gate_behavior_is_stable() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("desync dedup test lock must be available"); - clear_desync_dedup_for_testing(); - - let key = 0xC0DE_CAFE_u64; - let start = Instant::now(); - - assert!( - should_emit_full_desync(key, false, start), - "first event for key must emit full forensic record" - ); - - // Deterministic pseudo-random time deltas around dedup window edge. - let mut s: u64 = 0x1234_5678_9ABC_DEF0; - for _ in 0..2048 { - s ^= s << 7; - s ^= s >> 9; - s ^= s << 8; - - let delta_ms = s % (DESYNC_DEDUP_WINDOW.as_millis() as u64 * 2 + 1); - let now = start + TokioDuration::from_millis(delta_ms); - let emitted = should_emit_full_desync(key, false, now); - - if delta_ms < DESYNC_DEDUP_WINDOW.as_millis() as u64 { - assert!( - !emitted, - "events inside dedup window must remain suppressed" - ); - } else { - // Once window elapsed for this key, at least one sample should re-emit and refresh. - if emitted { - return; - } - } - } - - panic!("expected at least one post-window sample to re-emit forensic record"); -} - -fn make_forensics_state() -> RelayForensicsState { - RelayForensicsState { - trace_id: 1, - conn_id: 2, - user: "test-user".to_string(), - peer: "127.0.0.1:50000".parse::().unwrap(), - peer_hash: 3, - started_at: Instant::now(), - bytes_c2me: 0, - bytes_me2c: Arc::new(AtomicU64::new(0)), - desync_all_full: false, - } -} - -fn make_crypto_reader(reader: R) -> CryptoReader -where - R: tokio::io::AsyncRead + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoReader::new(reader, AesCtr::new(&key, iv)) -} - -fn make_crypto_writer(writer: W) -> CryptoWriter -where - W: tokio::io::AsyncWrite + Unpin, -{ - let key = [0u8; 32]; - let iv = 0u128; - CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024) -} - -async fn make_me_pool_for_abort_test(stats: Arc) -> Arc { - let general = GeneralConfig::default(); - - MePool::new( - None, - vec![1u8; 32], - None, - false, - None, - Vec::new(), - 1, - None, - 12, - 1200, - HashMap::new(), - HashMap::new(), - None, - NetworkDecision::default(), - None, - Arc::new(SecureRandom::new()), - stats, - general.me_keepalive_enabled, - general.me_keepalive_interval_secs, - general.me_keepalive_jitter_secs, - general.me_keepalive_payload_random, - general.rpc_proxy_req_every, - general.me_warmup_stagger_enabled, - general.me_warmup_step_delay_ms, - general.me_warmup_step_jitter_ms, - general.me_reconnect_max_concurrent_per_dc, - general.me_reconnect_backoff_base_ms, - general.me_reconnect_backoff_cap_ms, - general.me_reconnect_fast_retry_count, - general.me_single_endpoint_shadow_writers, - general.me_single_endpoint_outage_mode_enabled, - general.me_single_endpoint_outage_disable_quarantine, - general.me_single_endpoint_outage_backoff_min_ms, - general.me_single_endpoint_outage_backoff_max_ms, - general.me_single_endpoint_shadow_rotate_every_secs, - general.me_floor_mode, - general.me_adaptive_floor_idle_secs, - general.me_adaptive_floor_min_writers_single_endpoint, - general.me_adaptive_floor_min_writers_multi_endpoint, - general.me_adaptive_floor_recover_grace_secs, - general.me_adaptive_floor_writers_per_core_total, - general.me_adaptive_floor_cpu_cores_override, - general.me_adaptive_floor_max_extra_writers_single_per_core, - general.me_adaptive_floor_max_extra_writers_multi_per_core, - general.me_adaptive_floor_max_active_writers_per_core, - general.me_adaptive_floor_max_warm_writers_per_core, - general.me_adaptive_floor_max_active_writers_global, - general.me_adaptive_floor_max_warm_writers_global, - general.hardswap, - general.me_pool_drain_ttl_secs, - general.me_instadrain, - general.me_pool_drain_threshold, - general.me_pool_drain_soft_evict_enabled, - general.me_pool_drain_soft_evict_grace_secs, - general.me_pool_drain_soft_evict_per_writer, - general.me_pool_drain_soft_evict_budget_per_core, - general.me_pool_drain_soft_evict_cooldown_ms, - general.effective_me_pool_force_close_secs(), - general.me_pool_min_fresh_ratio, - general.me_hardswap_warmup_delay_min_ms, - general.me_hardswap_warmup_delay_max_ms, - general.me_hardswap_warmup_extra_passes, - general.me_hardswap_warmup_pass_backoff_base_ms, - general.me_bind_stale_mode, - general.me_bind_stale_ttl_secs, - general.me_secret_atomic_snapshot, - general.me_deterministic_writer_sort, - MeWriterPickMode::default(), - general.me_writer_pick_sample_size, - MeSocksKdfPolicy::default(), - general.me_writer_cmd_channel_capacity, - general.me_route_channel_capacity, - general.me_route_backpressure_base_timeout_ms, - general.me_route_backpressure_high_timeout_ms, - general.me_route_backpressure_high_watermark_pct, - general.me_reader_route_data_wait_ms, - general.me_health_interval_ms_unhealthy, - general.me_health_interval_ms_healthy, - general.me_warn_rate_limit_ms, - MeRouteNoWriterMode::default(), - general.me_route_no_writer_wait_ms, - general.me_route_inline_recovery_attempts, - general.me_route_inline_recovery_wait_ms, - ) -} - -fn encrypt_for_reader(plaintext: &[u8]) -> Vec { - let key = [0u8; 32]; - let iv = 0u128; - let mut cipher = AesCtr::new(&key, iv); - cipher.encrypt(plaintext) -} - -#[tokio::test] -async fn read_client_payload_times_out_on_header_stall() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - let (reader, _writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let result = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - 1024, - TokioDuration::from_millis(25), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut), - "stalled header read must time out" - ); -} - -#[tokio::test] -async fn read_client_payload_times_out_on_payload_stall() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - let (reader, mut writer) = duplex(1024); - let encrypted_len = encrypt_for_reader(&[8, 0, 0, 0]); - writer.write_all(&encrypted_len).await.unwrap(); - - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let result = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - 1024, - TokioDuration::from_millis(25), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::Io(ref e)) if e.kind() == std::io::ErrorKind::TimedOut), - "stalled payload body read must time out" - ); -} - -#[tokio::test] -async fn read_client_payload_large_intermediate_frame_is_exact() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(262_144); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload_len = buffer_pool.buffer_size().saturating_mul(3).max(65_537); - let mut plaintext = Vec::with_capacity(4 + payload_len); - plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes()); - plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(31))); - - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let read = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - payload_len + 16, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("payload read must succeed") - .expect("frame must be present"); - - let (frame, quickack) = read; - assert!(!quickack, "quickack flag must be unset"); - assert_eq!( - frame.len(), - payload_len, - "payload size must match wire length" - ); - for (idx, byte) in frame.iter().enumerate() { - assert_eq!(*byte, (idx as u8).wrapping_mul(31)); - } - assert_eq!(frame_counter, 1, "exactly one frame must be counted"); -} - -#[tokio::test] -async fn read_client_payload_secure_strips_tail_padding_bytes() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload = [0x11u8, 0x22, 0x33, 0x44, 0xaa, 0xbb, 0xcc, 0xdd]; - let tail = [0xeeu8, 0xff, 0x99]; - let wire_len = payload.len() + tail.len(); - - let mut plaintext = Vec::with_capacity(4 + wire_len); - plaintext.extend_from_slice(&(wire_len as u32).to_le_bytes()); - plaintext.extend_from_slice(&payload); - plaintext.extend_from_slice(&tail); - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let read = read_client_payload( - &mut crypto_reader, - ProtoTag::Secure, - 1024, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("secure payload read must succeed") - .expect("secure frame must be present"); - - let (frame, quickack) = read; - assert!(!quickack, "quickack flag must be unset"); - assert_eq!(frame.as_ref(), &payload); - assert_eq!(frame_counter, 1, "one secure frame must be counted"); -} - -#[tokio::test] -async fn read_client_payload_secure_rejects_wire_len_below_4() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let mut plaintext = Vec::with_capacity(7); - plaintext.extend_from_slice(&3u32.to_le_bytes()); - plaintext.extend_from_slice(&[1u8, 2, 3]); - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let result = read_client_payload( - &mut crypto_reader, - ProtoTag::Secure, - 1024, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::Proxy(ref msg)) if msg.contains("Frame too small: 3")), - "secure wire length below 4 must be fail-closed by the frame-too-small guard" - ); -} - -#[tokio::test] -async fn read_client_payload_intermediate_skips_zero_len_frame() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload = [7u8, 6, 5, 4, 3, 2, 1, 0]; - let mut plaintext = Vec::with_capacity(4 + 4 + payload.len()); - plaintext.extend_from_slice(&0u32.to_le_bytes()); - plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - plaintext.extend_from_slice(&payload); - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let read = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - 1024, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("intermediate payload read must succeed") - .expect("frame must be present"); - - let (frame, quickack) = read; - assert!(!quickack, "quickack flag must be unset"); - assert_eq!(frame.as_ref(), &payload); - assert_eq!(frame_counter, 1, "zero-length frame must be skipped"); -} - -#[tokio::test] -async fn read_client_payload_abridged_extended_len_sets_quickack() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(4096); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload_len = 4 * 130; - let len_words = (payload_len / 4) as u32; - let mut plaintext = Vec::with_capacity(1 + 3 + payload_len); - plaintext.push(0xff | 0x80); - let lw = len_words.to_le_bytes(); - plaintext.extend_from_slice(&lw[..3]); - plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_add(17))); - - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let read = read_client_payload( - &mut crypto_reader, - ProtoTag::Abridged, - payload_len + 16, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("abridged payload read must succeed") - .expect("frame must be present"); - - let (frame, quickack) = read; - assert!( - quickack, - "quickack bit must be propagated from abridged header" - ); - assert_eq!(frame.len(), payload_len); - assert_eq!(frame_counter, 1, "one abridged frame must be counted"); -} - -#[tokio::test] -async fn read_client_payload_returns_buffer_to_pool_after_emit() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let pool = Arc::new(BufferPool::with_config(64, 8)); - pool.preallocate(1); - assert_eq!(pool.stats().pooled, 1, "precondition: one pooled buffer"); - - let (reader, mut writer) = duplex(4096); - let mut crypto_reader = make_crypto_reader(reader); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - // Force growth beyond default pool buffer size to catch ownership-take regressions. - let payload_len = 257usize; - let mut plaintext = Vec::with_capacity(4 + payload_len); - plaintext.extend_from_slice(&(payload_len as u32).to_le_bytes()); - plaintext.extend((0..payload_len).map(|idx| (idx as u8).wrapping_mul(13))); - - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let _ = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - payload_len + 8, - TokioDuration::from_secs(1), - &pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("payload read must succeed") - .expect("frame must be present"); - - assert_eq!(frame_counter, 1); - let pool_stats = pool.stats(); - assert!( - pool_stats.pooled >= 1, - "emitted payload buffer must be returned to pool to avoid pool drain" - ); -} - -#[tokio::test] -async fn read_client_payload_keeps_pool_buffer_checked_out_until_frame_drop() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let pool = Arc::new(BufferPool::with_config(64, 2)); - pool.preallocate(1); - assert_eq!( - pool.stats().pooled, - 1, - "one pooled buffer must be available" - ); - - let (reader, mut writer) = duplex(1024); - let mut crypto_reader = make_crypto_reader(reader); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let payload = [0x41u8, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48]; - let mut plaintext = Vec::with_capacity(4 + payload.len()); - plaintext.extend_from_slice(&(payload.len() as u32).to_le_bytes()); - plaintext.extend_from_slice(&payload); - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let (frame, quickack) = read_client_payload( - &mut crypto_reader, - ProtoTag::Intermediate, - 1024, - TokioDuration::from_secs(1), - &pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await - .expect("payload read must succeed") - .expect("frame must be present"); - - assert!(!quickack); - assert_eq!(frame.as_ref(), &payload); - assert_eq!( - pool.stats().pooled, - 0, - "buffer must stay checked out while frame payload is alive" - ); - - drop(frame); - assert!( - pool.stats().pooled >= 1, - "buffer must return to pool only after frame drop" - ); -} - -#[tokio::test] -async fn enqueue_c2me_close_unblocks_after_queue_drain() { - let (tx, mut rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[0x41]), - flags: 0, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let close_task = - tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); - - tokio::time::sleep(TokioDuration::from_millis(10)).await; - - let first = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .expect("first queued item must be present"); - assert!(matches!(first, C2MeCommand::Data { .. })); - - close_task - .await - .unwrap() - .expect("close enqueue must succeed after drain"); - - let second = timeout(TokioDuration::from_millis(100), rx.recv()) - .await - .unwrap() - .expect("close command must follow after queue drain"); - assert!(matches!(second, C2MeCommand::Close)); -} - -#[tokio::test] -async fn enqueue_c2me_close_full_then_receiver_drop_fails_cleanly() { - let (tx, rx) = mpsc::channel::(1); - tx.send(C2MeCommand::Data { - payload: make_pooled_payload(&[0x42]), - flags: 0, - }) - .await - .unwrap(); - - let tx2 = tx.clone(); - let close_task = - tokio::spawn(async move { enqueue_c2me_command(&tx2, C2MeCommand::Close).await }); - - tokio::time::sleep(TokioDuration::from_millis(10)).await; - drop(rx); - - let result = timeout(TokioDuration::from_secs(1), close_task) - .await - .expect("close task must finish") - .expect("close task must not panic"); - assert!( - result.is_err(), - "close enqueue must fail cleanly when receiver is dropped under pressure" - ); -} - -#[tokio::test] -async fn process_me_writer_response_ack_obeys_flush_policy() { - let (writer_side, _reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - let immediate = process_me_writer_response( - MeResponse::Ack(0x11223344), - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "user", - None, - 0, - &bytes_me2c, - 77, - true, - false, - ) - .await - .expect("ack response must be processed"); - - assert!(matches!( - immediate, - MeWriterResponseOutcome::Continue { - frames: 1, - bytes: 4, - flush_immediately: true, - } - )); - - let delayed = process_me_writer_response( - MeResponse::Ack(0x55667788), - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "user", - None, - 0, - &bytes_me2c, - 77, - false, - false, - ) - .await - .expect("ack response must be processed"); - - assert!(matches!( - delayed, - MeWriterResponseOutcome::Continue { - frames: 1, - bytes: 4, - flush_immediately: false, - } - )); -} - -#[tokio::test] -async fn process_me_writer_response_data_updates_byte_accounting() { - let (writer_side, _reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - let payload = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9]; - let outcome = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(payload.clone()), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "user", - None, - 0, - &bytes_me2c, - 88, - false, - false, - ) - .await - .expect("data response must be processed"); - - assert!(matches!( - outcome, - MeWriterResponseOutcome::Continue { - frames: 1, - bytes, - flush_immediately: false, - } if bytes == payload.len() - )); - assert_eq!( - bytes_me2c.load(std::sync::atomic::Ordering::Relaxed), - payload.len() as u64, - "ME->C byte accounting must increase by emitted payload size" - ); -} - -#[tokio::test] -async fn process_me_writer_response_data_enforces_live_user_quota() { - let (writer_side, mut reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - stats.add_user_octets_from("quota-user", 10); - - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![1u8, 2, 3, 4]), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "quota-user", - Some(12), - 0, - &bytes_me2c, - 89, - false, - false, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "quota-user"), - "ME->client runtime path must terminate when live user quota is crossed" - ); - - let mut raw = [0u8; 1]; - assert!( - timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw)) - .await - .is_err(), - "quota exhaustion must not write any ciphertext to the client stream" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn process_me_writer_response_concurrent_same_user_quota_does_not_overshoot_limit() { - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - let user = "quota-race-user"; - - let (writer_side_a, _reader_side_a) = duplex(1024); - let (writer_side_b, _reader_side_b) = duplex(1024); - let mut writer_a = make_crypto_writer(writer_side_a); - let mut writer_b = make_crypto_writer(writer_side_b); - let mut frame_buf_a = Vec::new(); - let mut frame_buf_b = Vec::new(); - let rng_a = SecureRandom::new(); - let rng_b = SecureRandom::new(); - - let fut_a = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x11]), - }, - &mut writer_a, - ProtoTag::Intermediate, - &rng_a, - &mut frame_buf_a, - &stats, - user, - Some(1), - 0, - &bytes_me2c, - 91, - false, - false, - ); - let fut_b = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from_static(&[0x22]), - }, - &mut writer_b, - ProtoTag::Intermediate, - &rng_b, - &mut frame_buf_b, - &stats, - user, - Some(1), - 0, - &bytes_me2c, - 92, - false, - false, - ); - - let (result_a, result_b) = tokio::join!(fut_a, fut_b); - - assert!( - matches!(result_a, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user") - || matches!(result_a, Ok(_)), - "concurrent quota test must complete without panicking" - ); - assert!( - matches!(result_b, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-race-user") - || matches!(result_b, Ok(_)), - "concurrent quota test must complete without panicking" - ); - assert!( - stats.get_user_total_octets(user) <= 1, - "same-user concurrent middle-relay responses must not overshoot the configured quota" - ); -} - -#[tokio::test] -async fn process_me_writer_response_data_does_not_forward_partial_payload_when_remaining_quota_is_smaller_than_message() - { - let (writer_side, mut reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - stats.add_user_octets_to("partial-quota-user", 3); - - let result = process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![1u8, 2, 3, 4]), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - &stats, - "partial-quota-user", - Some(4), - 0, - &bytes_me2c, - 90, - false, - false, - ) - .await; - - assert!( - matches!(result, Err(ProxyError::DataQuotaExceeded { user }) if user == "partial-quota-user"), - "ME->client runtime path must reject oversized payloads before writing" - ); - - let mut raw = [0u8; 1]; - assert!( - timeout(TokioDuration::from_millis(100), reader_side.read(&mut raw)) - .await - .is_err(), - "oversized payloads must not leak any partial ciphertext to the client stream" - ); -} - -#[tokio::test] -async fn middle_relay_abort_midflight_releases_route_gauge() { - let stats = Arc::new(Stats::new()); - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::new()); - let rng = Arc::new(SecureRandom::new()); - - let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); - let route_snapshot = route_runtime.snapshot(); - - let (server_side, client_side) = duplex(64 * 1024); - let (server_reader, server_writer) = tokio::io::split(server_side); - let crypto_reader = make_crypto_reader(server_reader); - let crypto_writer = make_crypto_writer(server_writer); - - let success = HandshakeSuccess { - user: "abort-middle-user".to_string(), - dc_idx: 2, - proto_tag: ProtoTag::Intermediate, - dec_key: [0u8; 32], - dec_iv: 0, - enc_key: [0u8; 32], - enc_iv: 0, - peer: "127.0.0.1:50001".parse().unwrap(), - is_tls: false, - }; - - let relay_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool, - stats.clone(), - config, - buffer_pool, - "127.0.0.1:443".parse().unwrap(), - rng, - route_runtime.subscribe(), - route_snapshot, - 0xdecafbad, - )); - - let started = tokio::time::timeout(TokioDuration::from_secs(2), async { - loop { - if stats.get_current_connections_me() == 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await; - assert!( - started.is_ok(), - "middle relay must increment route gauge before abort" - ); - - relay_task.abort(); - let joined = relay_task.await; - assert!( - joined.is_err(), - "aborted middle relay task must return join error" - ); - - tokio::time::sleep(TokioDuration::from_millis(20)).await; - assert_eq!( - stats.get_current_connections_me(), - 0, - "route gauge must be released when middle relay task is aborted mid-flight" - ); - - drop(client_side); -} - -#[tokio::test] -async fn middle_relay_cutover_midflight_releases_route_gauge() { - let stats = Arc::new(Stats::new()); - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::new()); - let rng = Arc::new(SecureRandom::new()); - - let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); - let route_snapshot = route_runtime.snapshot(); - - let (server_side, client_side) = duplex(64 * 1024); - let (server_reader, server_writer) = tokio::io::split(server_side); - let crypto_reader = make_crypto_reader(server_reader); - let crypto_writer = make_crypto_writer(server_writer); - - let success = HandshakeSuccess { - user: "cutover-middle-user".to_string(), - dc_idx: 2, - proto_tag: ProtoTag::Intermediate, - dec_key: [0u8; 32], - dec_iv: 0, - enc_key: [0u8; 32], - enc_iv: 0, - peer: "127.0.0.1:50003".parse().unwrap(), - is_tls: false, - }; - - let relay_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool, - stats.clone(), - config, - buffer_pool, - "127.0.0.1:443".parse().unwrap(), - rng, - route_runtime.subscribe(), - route_snapshot, - 0xfeed_beef, - )); - - tokio::time::timeout(TokioDuration::from_secs(2), async { - loop { - if stats.get_current_connections_me() == 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("middle relay must increment route gauge before cutover"); - - assert!( - route_runtime.set_mode(RelayRouteMode::Direct).is_some(), - "cutover must advance route generation" - ); - - let relay_result = tokio::time::timeout(TokioDuration::from_secs(6), relay_task) - .await - .expect("middle relay must terminate after cutover") - .expect("middle relay task must not panic"); - assert!( - relay_result.is_err(), - "cutover should terminate middle relay session" - ); - assert!( - matches!( - relay_result, - Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG - ), - "client-visible cutover error must stay generic and avoid route-internal metadata" - ); - - assert_eq!( - stats.get_current_connections_me(), - 0, - "route gauge must be released when middle relay exits on cutover" - ); - - drop(client_side); -} - -async fn run_quota_race_attempt( - stats: &Stats, - bytes_me2c: &AtomicU64, - user: &str, - payload: u8, - conn_id: u64, - barrier: Arc, -) -> Result { - let (writer_side, _reader_side) = duplex(1024); - let mut writer = make_crypto_writer(writer_side); - let rng = SecureRandom::new(); - let mut frame_buf = Vec::new(); - - barrier.wait().await; - process_me_writer_response( - MeResponse::Data { - flags: 0, - data: Bytes::from(vec![payload]), - }, - &mut writer, - ProtoTag::Intermediate, - &rng, - &mut frame_buf, - stats, - user, - Some(1), - 0, - bytes_me2c, - conn_id, - false, - false, - ) - .await -} - -#[tokio::test] -async fn abridged_max_extended_length_fails_closed_without_panic_or_partial_read() { - let _guard = desync_dedup_test_lock() - .lock() - .expect("middle relay test lock must be available"); - - let (reader, mut writer) = duplex(256); - let mut crypto_reader = make_crypto_reader(reader); - let buffer_pool = Arc::new(BufferPool::new()); - let stats = Stats::new(); - let forensics = make_forensics_state(); - let mut frame_counter = 0; - - let plaintext = vec![0x7f, 0xff, 0xff, 0xff]; - let encrypted = encrypt_for_reader(&plaintext); - writer.write_all(&encrypted).await.unwrap(); - - let result = read_client_payload( - &mut crypto_reader, - ProtoTag::Abridged, - 4096, - TokioDuration::from_secs(1), - &buffer_pool, - &forensics, - &mut frame_counter, - &stats, - ) - .await; - - assert!( - result.is_err(), - "oversized abridged length must fail closed" - ); - assert_eq!( - frame_counter, 0, - "oversized frame must not be counted as accepted" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn deterministic_quota_race_exactly_one_succeeds_and_one_is_rejected() { - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - let user = "gap-t04-race-user"; - let barrier = Arc::new(Barrier::new(2)); - - let f1 = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x11, 5001, barrier.clone()); - let f2 = run_quota_race_attempt(&stats, &bytes_me2c, user, 0x22, 5002, barrier); - - let (r1, r2) = tokio::join!(f1, f2); - - assert!( - matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "first racer must either finish or fail closed on quota" - ); - assert!( - matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "second racer must either finish or fail closed on quota" - ); - assert!( - matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) - || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })), - "at least one racer must be quota-rejected" - ); - assert_eq!( - stats.get_user_total_octets(user), - 1, - "same-user race must forward/account exactly one payload byte" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_quota_race_bursts_never_allow_double_success_per_round() { - let stats = Stats::new(); - let bytes_me2c = AtomicU64::new(0); - - for round in 0..128u64 { - let user = format!("gap-t04-race-burst-{round}"); - let barrier = Arc::new(Barrier::new(2)); - - let one = run_quota_race_attempt( - &stats, - &bytes_me2c, - &user, - 0x33, - 6000 + round, - barrier.clone(), - ); - let two = run_quota_race_attempt(&stats, &bytes_me2c, &user, 0x44, 7000 + round, barrier); - - let (r1, r2) = tokio::join!(one, two); - assert!( - matches!(r1, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })) - && matches!(r2, Ok(_) | Err(ProxyError::DataQuotaExceeded { .. })), - "round {round}: racers must resolve cleanly without unexpected errors" - ); - assert!( - matches!(r1, Err(ProxyError::DataQuotaExceeded { .. })) - || matches!(r2, Err(ProxyError::DataQuotaExceeded { .. })), - "round {round}: at least one racer must be quota-rejected" - ); - assert_eq!( - stats.get_user_total_octets(&user), - 1, - "round {round}: same-user total octets must remain exactly 1 (single forwarded winner)" - ); - } -} - -#[tokio::test] -async fn middle_relay_cutover_storm_multi_session_keeps_generic_errors_and_releases_gauge() { - let session_count = 6usize; - let stats = Arc::new(Stats::new()); - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::new()); - let rng = Arc::new(SecureRandom::new()); - - let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); - let route_snapshot = route_runtime.snapshot(); - - let mut relay_tasks = Vec::with_capacity(session_count); - let mut client_sides = Vec::with_capacity(session_count); - - for idx in 0..session_count { - let (server_side, client_side) = duplex(64 * 1024); - client_sides.push(client_side); - let (server_reader, server_writer) = tokio::io::split(server_side); - let crypto_reader = make_crypto_reader(server_reader); - let crypto_writer = make_crypto_writer(server_writer); - - let success = HandshakeSuccess { - user: format!("cutover-storm-middle-user-{idx}"), - dc_idx: 2, - proto_tag: ProtoTag::Intermediate, - dec_key: [0u8; 32], - dec_iv: 0, - enc_key: [0u8; 32], - enc_iv: 0, - peer: SocketAddr::new( - std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)), - 52000 + idx as u16, - ), - is_tls: false, - }; - - relay_tasks.push(tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool.clone(), - stats.clone(), - config.clone(), - buffer_pool.clone(), - "127.0.0.1:443".parse().unwrap(), - rng.clone(), - route_runtime.subscribe(), - route_snapshot, - 0xB000_0000 + idx as u64, - ))); - } - - tokio::time::timeout(TokioDuration::from_secs(4), async { - loop { - if stats.get_current_connections_me() == session_count as u64 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("all middle sessions must become active before cutover storm"); - - let route_runtime_flipper = route_runtime.clone(); - let flipper = tokio::spawn(async move { - for step in 0..64u32 { - let mode = if (step & 1) == 0 { - RelayRouteMode::Direct - } else { - RelayRouteMode::Middle - }; - let _ = route_runtime_flipper.set_mode(mode); - tokio::time::sleep(TokioDuration::from_millis(15)).await; - } - }); - - for relay_task in relay_tasks { - let relay_result = tokio::time::timeout(TokioDuration::from_secs(10), relay_task) - .await - .expect("middle relay task must finish under cutover storm") - .expect("middle relay task must not panic"); - - assert!( - matches!( - relay_result, - Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG - ), - "storm-cutover termination must remain generic for all middle sessions" - ); - } - - flipper.abort(); - let _ = flipper.await; - - assert_eq!( - stats.get_current_connections_me(), - 0, - "middle route gauge must return to zero after cutover storm" - ); - - drop(client_sides); -} - -#[tokio::test] -async fn secure_padding_distribution_in_relay_writer() { - timeout(TokioDuration::from_secs(10), async { - let (mut client_side, relay_side) = duplex(512 * 1024); - let key = [0u8; 32]; - let iv = 0u128; - let mut writer = CryptoWriter::new(relay_side, AesCtr::new(&key, iv), 8 * 1024); - let rng = Arc::new(SecureRandom::new()); - let mut frame_buf = Vec::new(); - let mut decryptor = AesCtr::new(&key, iv); - - let mut padding_counts = [0usize; 4]; - let iterations = 180usize; - let payload = vec![0xAAu8; 100]; // 4-byte aligned - - for _ in 0..iterations { - write_client_payload( - &mut writer, - ProtoTag::Secure, - 0, - &payload, - &rng, - &mut frame_buf, - ) - .await - .expect("payload write must succeed"); - writer - .flush() - .await - .expect("writer flush must complete so encrypted frame becomes readable"); - - let mut len_buf = [0u8; 4]; - client_side - .read_exact(&mut len_buf) - .await - .expect("must read encrypted secure length"); - let decrypted_len_bytes = decryptor.decrypt(&len_buf); - let decrypted_len_bytes: [u8; 4] = decrypted_len_bytes - .try_into() - .expect("decrypted length must be 4 bytes"); - let wire_len = (u32::from_le_bytes(decrypted_len_bytes) & 0x7fff_ffff) as usize; - - assert!( - wire_len >= payload.len(), - "wire length must include at least payload bytes" - ); - let padding_len = wire_len - payload.len(); - assert!(padding_len >= 1 && padding_len <= 3); - padding_counts[padding_len] += 1; - - // Drain and decrypt frame bytes so CTR state stays aligned across writes. - let mut trash = vec![0u8; wire_len]; - client_side - .read_exact(&mut trash) - .await - .expect("must read encrypted secure frame body"); - let _ = decryptor.decrypt(&trash); - } - - for p in 1..=3 { - let count = padding_counts[p]; - assert!( - count > iterations / 8, - "padding length {p} is under-represented ({count}/{iterations})" - ); - } - }) - .await - .expect("secure padding distribution test exceeded runtime budget"); -} - -#[tokio::test] -async fn negative_middle_end_connection_lost_during_relay_exits_on_client_eof() { - let (client_reader_side, client_writer_side) = duplex(1024); - let (_relay_reader_side, relay_writer_side) = duplex(1024); - - let key = [0u8; 32]; - let iv = 0u128; - let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); - let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); - - let stats = Arc::new(Stats::new()); - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); - let rng = Arc::new(SecureRandom::new()); - let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle); - - // Create an ME pool. - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - - // ConnRegistry ids are monotonic; reserve one id so we can predict the - // next session conn_id and close it deterministically without relying on - // writer-bound views such as active_conn_ids(). - let (probe_conn_id, probe_rx) = me_pool.registry().register().await; - drop(probe_rx); - me_pool.registry().unregister(probe_conn_id).await; - let target_conn_id = probe_conn_id.wrapping_add(1); - - let success = HandshakeSuccess { - user: "test-user".to_string(), - peer: "127.0.0.1:12345".parse().unwrap(), - dc_idx: 1, - proto_tag: ProtoTag::Intermediate, - enc_key: key, - enc_iv: iv, - dec_key: key, - dec_iv: iv, - is_tls: false, - }; - - let session_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool.clone(), - stats.clone(), - config.clone(), - buffer_pool.clone(), - "127.0.0.1:443".parse().unwrap(), - rng.clone(), - route_runtime.subscribe(), - route_runtime.snapshot(), - 0x1234_5678, - )); - - // Wait until session startup is visible, then unregister the predicted - // conn_id to close the per-session ME response channel. - timeout(TokioDuration::from_millis(500), async { - loop { - if stats.get_current_connections_me() >= 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("ME session must start before channel close simulation"); - - me_pool.registry().unregister(target_conn_id).await; - - drop(client_writer_side); - - let result = timeout(TokioDuration::from_secs(2), session_task) - .await - .expect("Session task must terminate after ME drop and client EOF") - .expect("Session task must not panic"); - - assert!( - result.is_ok(), - "Session should complete cleanly after ME drop when client closes, got: {:?}", - result - ); -} - -#[tokio::test] -async fn adversarial_middle_end_drop_plus_cutover_returns_generic_route_switch() { - let (client_reader_side, _client_writer_side) = duplex(1024); - let (_relay_reader_side, relay_writer_side) = duplex(1024); - - let key = [0u8; 32]; - let iv = 0u128; - let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); - let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); - - let stats = Arc::new(Stats::new()); - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); - let rng = Arc::new(SecureRandom::new()); - let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Middle)); - - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - - // Predict the next conn_id so we can force-drop its ME channel deterministically. - let (probe_conn_id, probe_rx) = me_pool.registry().register().await; - drop(probe_rx); - me_pool.registry().unregister(probe_conn_id).await; - let target_conn_id = probe_conn_id.wrapping_add(1); - - let success = HandshakeSuccess { - user: "test-user-cutover".to_string(), - peer: "127.0.0.1:12345".parse().unwrap(), - dc_idx: 1, - proto_tag: ProtoTag::Intermediate, - enc_key: key, - enc_iv: iv, - dec_key: key, - dec_iv: iv, - is_tls: false, - }; - - let runtime_clone = route_runtime.clone(); - let session_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool.clone(), - stats.clone(), - config, - buffer_pool, - "127.0.0.1:443".parse().unwrap(), - rng, - runtime_clone.subscribe(), - runtime_clone.snapshot(), - 0xC001_CAFE, - )); - - timeout(TokioDuration::from_millis(500), async { - loop { - if stats.get_current_connections_me() >= 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("ME session must start before race trigger"); - - // Race ME channel drop with route cutover and assert generic client-visible outcome. - me_pool.registry().unregister(target_conn_id).await; - assert!( - route_runtime.set_mode(RelayRouteMode::Direct).is_some(), - "cutover must advance generation" - ); - - let relay_result = timeout(TokioDuration::from_secs(6), session_task) - .await - .expect("session must terminate under ME-drop + cutover race") - .expect("session task must not panic"); - - assert!( - matches!( - relay_result, - Err(ProxyError::Proxy(ref msg)) if msg == ROUTE_SWITCH_ERROR_MSG - ), - "race outcome must remain generic and not leak ME internals, got: {:?}", - relay_result - ); -} - -#[tokio::test] -async fn stress_middle_end_drop_with_client_eof_never_hangs_across_burst() { - let stats = Arc::new(Stats::new()); - let me_pool = make_me_pool_for_abort_test(stats.clone()).await; - - for round in 0..32u64 { - let (client_reader_side, client_writer_side) = duplex(1024); - let (_relay_reader_side, relay_writer_side) = duplex(1024); - - let key = [0u8; 32]; - let iv = 0u128; - let crypto_reader = CryptoReader::new(client_reader_side, AesCtr::new(&key, iv)); - let crypto_writer = CryptoWriter::new(relay_writer_side, AesCtr::new(&key, iv), 1024); - - let config = Arc::new(ProxyConfig::default()); - let buffer_pool = Arc::new(BufferPool::with_config(1024, 1)); - let rng = Arc::new(SecureRandom::new()); - let route_runtime = RouteRuntimeController::new(RelayRouteMode::Middle); - - let (probe_conn_id, probe_rx) = me_pool.registry().register().await; - drop(probe_rx); - me_pool.registry().unregister(probe_conn_id).await; - let target_conn_id = probe_conn_id.wrapping_add(1); - - let success = HandshakeSuccess { - user: format!("stress-me-drop-eof-{round}"), - peer: "127.0.0.1:12345".parse().unwrap(), - dc_idx: 1, - proto_tag: ProtoTag::Intermediate, - enc_key: key, - enc_iv: iv, - dec_key: key, - dec_iv: iv, - is_tls: false, - }; - - let session_task = tokio::spawn(handle_via_middle_proxy( - crypto_reader, - crypto_writer, - success, - me_pool.clone(), - stats.clone(), - config, - buffer_pool, - "127.0.0.1:443".parse().unwrap(), - rng, - route_runtime.subscribe(), - route_runtime.snapshot(), - 0xD00D_0000 + round, - )); - - timeout(TokioDuration::from_millis(500), async { - loop { - if stats.get_current_connections_me() >= 1 { - break; - } - tokio::time::sleep(TokioDuration::from_millis(10)).await; - } - }) - .await - .expect("session must start before forced drop in burst round"); - - me_pool.registry().unregister(target_conn_id).await; - drop(client_writer_side); - - let result = timeout(TokioDuration::from_secs(2), session_task) - .await - .expect("burst round session must terminate quickly") - .expect("burst round session must not panic"); - - assert!( - result.is_ok(), - "burst round {round}: expected clean shutdown after ME drop + EOF, got: {:?}", - result - ); - } -} diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs index 34fc454..6b1d511 100644 --- a/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs @@ -217,7 +217,9 @@ async fn adversarial_lockstep_alternating_attack_under_jitter_closes() { } } - writer_task.await.expect("writer jitter task must not panic"); + writer_task + .await + .expect("writer jitter task must not panic"); assert!(closed, "alternating attack must close before EOF"); }); } @@ -247,7 +249,10 @@ async fn integration_mixed_population_attackers_close_benign_survive() { plaintext.push(0x01); plaintext.extend_from_slice(&[n, n, n, n]); } - writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap(); + writer + .write_all(&encrypt_for_reader(&plaintext)) + .await + .unwrap(); drop(writer); let mut closed = false; @@ -279,7 +284,10 @@ async fn integration_mixed_population_attackers_close_benign_survive() { } plaintext.push(0x01); plaintext.extend_from_slice(&payload); - writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap(); + writer + .write_all(&encrypt_for_reader(&plaintext)) + .await + .unwrap(); let got = read_once( &mut crypto_reader, @@ -329,7 +337,10 @@ async fn light_fuzz_parallel_patterns_no_hang_or_panic() { } } - writer.write_all(&encrypt_for_reader(&plaintext)).await.unwrap(); + writer + .write_all(&encrypt_for_reader(&plaintext)) + .await + .unwrap(); drop(writer); for _ in 0..320 { diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs index 853b381..cbbc971 100644 --- a/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs @@ -51,7 +51,9 @@ fn make_enabled_idle_policy() -> RelayClientIdlePolicy { fn append_tiny_frame(plaintext: &mut Vec, proto: ProtoTag) { match proto { ProtoTag::Abridged => plaintext.push(0x00), - ProtoTag::Intermediate | ProtoTag::Secure => plaintext.extend_from_slice(&0u32.to_le_bytes()), + ProtoTag::Intermediate | ProtoTag::Secure => { + plaintext.extend_from_slice(&0u32.to_le_bytes()) + } } } @@ -206,7 +208,11 @@ async fn intermediate_chunked_alternating_attack_closes_before_eof() { let mut plaintext = Vec::with_capacity(8 * 200); for n in 0..180u8 { append_tiny_frame(&mut plaintext, ProtoTag::Intermediate); - append_real_frame(&mut plaintext, ProtoTag::Intermediate, [n, n ^ 1, n ^ 2, n ^ 3]); + append_real_frame( + &mut plaintext, + ProtoTag::Intermediate, + [n, n ^ 1, n ^ 2, n ^ 3], + ); } let encrypted = encrypt_for_reader(&plaintext); @@ -240,7 +246,9 @@ async fn intermediate_chunked_alternating_attack_closes_before_eof() { } } - writer_task.await.expect("intermediate writer task must not panic"); + writer_task + .await + .expect("intermediate writer task must not panic"); assert!(closed, "intermediate alternating attack must fail closed"); } @@ -290,7 +298,9 @@ async fn secure_chunked_alternating_attack_closes_before_eof() { } } - writer_task.await.expect("secure writer task must not panic"); + writer_task + .await + .expect("secure writer task must not panic"); assert!(closed, "secure alternating attack must fail closed"); } diff --git a/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs index dee5dd9..fad87d0 100644 --- a/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs +++ b/src/proxy/tests/middle_relay_tiny_frame_debt_security_tests.rs @@ -2,8 +2,8 @@ use super::*; use crate::crypto::AesCtr; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader}; -use std::sync::atomic::AtomicU64; use std::sync::Arc; +use std::sync::atomic::AtomicU64; use std::time::Instant; use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; @@ -156,7 +156,10 @@ fn alternating_one_to_one_closes_with_bounded_real_frame_count() { } let (closed_at, _, reals) = simulate_tiny_debt_pattern(&pattern, pattern.len()); assert!(closed_at.is_some()); - assert!(reals <= 80, "expected bounded real frames before close, got {reals}"); + assert!( + reals <= 80, + "expected bounded real frames before close, got {reals}" + ); } #[test] @@ -183,7 +186,10 @@ fn alternating_one_to_seven_eventually_closes() { } } let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len()); - assert!(closed_at.is_some(), "1:7 tiny-to-real must eventually close"); + assert!( + closed_at.is_some(), + "1:7 tiny-to-real must eventually close" + ); } #[test] diff --git a/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs b/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs index 765c253..dbf6c4c 100644 --- a/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs +++ b/src/proxy/tests/middle_relay_zero_length_frame_security_tests.rs @@ -2,10 +2,10 @@ use super::*; use crate::crypto::AesCtr; use crate::stats::Stats; use crate::stream::{BufferPool, CryptoReader}; -use std::sync::atomic::AtomicU64; use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; +use std::sync::atomic::AtomicU64; use std::time::Instant; +use tokio::io::{AsyncRead, AsyncWriteExt, duplex}; fn make_crypto_reader(reader: T) -> CryptoReader where diff --git a/src/proxy/tests/quota_lock_registry_cross_mode_adversarial_tests.rs b/src/proxy/tests/quota_lock_registry_cross_mode_adversarial_tests.rs deleted file mode 100644 index fb0cf93..0000000 --- a/src/proxy/tests/quota_lock_registry_cross_mode_adversarial_tests.rs +++ /dev/null @@ -1,108 +0,0 @@ -use super::*; -use std::sync::Arc; -use std::sync::{Mutex, OnceLock}; - -fn cross_mode_lock_test_guard() -> std::sync::MutexGuard<'static, ()> { - static TEST_LOCK: OnceLock> = OnceLock::new(); - TEST_LOCK - .get_or_init(|| Mutex::new(())) - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) -} - -#[test] -fn same_user_returns_same_lock_identity() { - let _guard = cross_mode_lock_test_guard(); - let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); - locks.clear(); - - let a = cross_mode_quota_user_lock("cross-mode-same-user"); - let b = cross_mode_quota_user_lock("cross-mode-same-user"); - - assert!( - Arc::ptr_eq(&a, &b), - "same user must reuse a stable lock identity" - ); -} - -#[test] -fn saturation_overflow_path_returns_stable_striped_lock_without_cache_growth() { - let _guard = cross_mode_lock_test_guard(); - let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); - locks.clear(); - - let prefix = format!("cross-mode-saturated-{}", std::process::id()); - let mut retained = Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX); - for idx in 0..CROSS_MODE_QUOTA_USER_LOCKS_MAX { - retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}"))); - } - - assert_eq!( - locks.len(), - CROSS_MODE_QUOTA_USER_LOCKS_MAX, - "lock cache must be saturated for overflow check" - ); - - let overflow_user = format!("cross-mode-overflow-{}", std::process::id()); - let overflow_a = cross_mode_quota_user_lock(&overflow_user); - let overflow_b = cross_mode_quota_user_lock(&overflow_user); - - assert_eq!( - locks.len(), - CROSS_MODE_QUOTA_USER_LOCKS_MAX, - "overflow path must not grow bounded lock cache" - ); - assert!( - locks.get(&overflow_user).is_none(), - "overflow user must stay on striped fallback while cache is saturated" - ); - assert!( - Arc::ptr_eq(&overflow_a, &overflow_b), - "overflow user must receive a stable striped lock across repeated lookups" - ); - - drop(retained); -} - -#[test] -fn reclaim_drops_stale_entries_but_preserves_active_user_lock_identity() { - let _guard = cross_mode_lock_test_guard(); - let locks = CROSS_MODE_QUOTA_USER_LOCKS.get_or_init(DashMap::new); - locks.clear(); - - let prefix = format!("cross-mode-reclaim-{}", std::process::id()); - let protected_user = format!("{prefix}-protected"); - - let protected_lock = cross_mode_quota_user_lock(&protected_user); - let mut retained = Vec::with_capacity(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1)); - for idx in 0..(CROSS_MODE_QUOTA_USER_LOCKS_MAX.saturating_sub(1)) { - retained.push(cross_mode_quota_user_lock(&format!("{prefix}-{idx}"))); - } - - assert_eq!( - locks.len(), - CROSS_MODE_QUOTA_USER_LOCKS_MAX, - "fixture must saturate lock cache before reclaim path is exercised" - ); - - drop(retained); - - let newcomer_user = format!("{prefix}-newcomer"); - let _newcomer = cross_mode_quota_user_lock(&newcomer_user); - - assert!( - locks.get(&protected_user).is_some(), - "active protected user must remain cache-resident after reclaim" - ); - let locked = locks - .get(&protected_user) - .expect("protected user must remain in map after reclaim"); - assert!( - Arc::ptr_eq(locked.value(), &protected_lock), - "reclaim must not swap active user lock identity" - ); - assert!( - locks.get(&newcomer_user).is_some(), - "newcomer should become cacheable after stale entries are reclaimed" - ); -} diff --git a/src/proxy/tests/relay_adversarial_tests.rs b/src/proxy/tests/relay_adversarial_tests.rs index 14754cd..38e6fc7 100644 --- a/src/proxy/tests/relay_adversarial_tests.rs +++ b/src/proxy/tests/relay_adversarial_tests.rs @@ -78,7 +78,8 @@ async fn relay_hol_blocking_prevention_regression() { async fn relay_quota_mid_session_cutoff() { let stats = Arc::new(Stats::new()); let user = "quota-mid-user"; - let quota = 5000; + let quota = 5000u64; + let c2s_buf_size = 1024usize; let (client_peer, relay_client) = duplex(8192); let (relay_server, server_peer) = duplex(8192); @@ -93,7 +94,7 @@ async fn relay_quota_mid_session_cutoff() { client_writer, server_reader, server_writer, - 1024, + c2s_buf_size, 1024, user, Arc::clone(&stats), @@ -120,9 +121,25 @@ async fn relay_quota_mid_session_cutoff() { other => panic!("Expected DataQuotaExceeded error, got: {:?}", other), } - let mut small_buf = [0u8; 1]; - let n = sp_reader.read(&mut small_buf).await.unwrap(); - assert_eq!(n, 0, "Server must see EOF after quota reached"); + let mut overshoot_bytes = 0usize; + let mut buf = [0u8; 256]; + loop { + match timeout(Duration::from_millis(20), sp_reader.read(&mut buf)).await { + Ok(Ok(0)) => break, + Ok(Ok(n)) => overshoot_bytes = overshoot_bytes.saturating_add(n), + Ok(Err(e)) => panic!("server read must not fail after relay cutoff: {e}"), + Err(_) => break, + } + } + + assert!( + overshoot_bytes <= c2s_buf_size, + "post-write cutoff may leak at most one C->S chunk after boundary, got {overshoot_bytes}" + ); + assert!( + stats.get_user_quota_used(user) <= quota.saturating_add(c2s_buf_size as u64), + "accounted quota must remain bounded by one in-flight chunk overshoot" + ); } #[tokio::test] diff --git a/src/proxy/tests/relay_atomic_quota_invariant_tests.rs b/src/proxy/tests/relay_atomic_quota_invariant_tests.rs new file mode 100644 index 0000000..1bb00a6 --- /dev/null +++ b/src/proxy/tests/relay_atomic_quota_invariant_tests.rs @@ -0,0 +1,243 @@ +use super::*; +use std::collections::VecDeque; +use std::io; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; +use std::task::{Context, Poll}; +use tokio::io::{AsyncWrite, AsyncWriteExt}; +use tokio::time::Instant; + +struct ScriptedWriter { + scripted_writes: Arc>>, + write_calls: Arc, +} + +impl ScriptedWriter { + fn new(script: &[usize], write_calls: Arc) -> Self { + Self { + scripted_writes: Arc::new(Mutex::new(script.iter().copied().collect())), + write_calls, + } + } +} + +impl AsyncWrite for ScriptedWriter { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + this.write_calls.fetch_add(1, Ordering::Relaxed); + let planned = this + .scripted_writes + .lock() + .unwrap_or_else(|poisoned| poisoned.into_inner()) + .pop_front() + .unwrap_or(buf.len()); + Poll::Ready(Ok(planned.min(buf.len()))) + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } +} + +fn make_stats_io_with_script( + user: &str, + quota_limit: u64, + precharged_quota: u64, + script: &[usize], +) -> ( + StatsIo, + Arc, + Arc, + Arc, +) { + let stats = Arc::new(Stats::new()); + if precharged_quota > 0 { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), precharged_quota); + } + + let write_calls = Arc::new(AtomicUsize::new(0)); + let quota_exceeded = Arc::new(AtomicBool::new(false)); + let io = StatsIo::new( + ScriptedWriter::new(script, write_calls.clone()), + Arc::new(SharedCounters::new()), + stats.clone(), + user.to_string(), + Some(quota_limit), + quota_exceeded.clone(), + Instant::now(), + ); + + (io, stats, write_calls, quota_exceeded) +} + +#[tokio::test] +async fn direct_partial_write_charges_only_committed_bytes_without_double_charge() { + let user = "direct-partial-charge-user"; + let (mut io, stats, write_calls, quota_exceeded) = + make_stats_io_with_script(user, 1_048_576, 0, &[8 * 1024, 8 * 1024, 48 * 1024]); + let payload = vec![0xAB; 64 * 1024]; + + let n1 = io + .write(&payload) + .await + .expect("first partial write must succeed"); + let n2 = io + .write(&payload) + .await + .expect("second partial write must succeed"); + let n3 = io.write(&payload).await.expect("tail write must succeed"); + + assert_eq!(n1, 8 * 1024); + assert_eq!(n2, 8 * 1024); + assert_eq!(n3, 48 * 1024); + assert_eq!(write_calls.load(Ordering::Relaxed), 3); + assert_eq!( + stats.get_user_quota_used(user), + (n1 + n2 + n3) as u64, + "quota accounting must follow committed bytes only" + ); + assert_eq!( + stats.get_user_total_octets(user), + (n1 + n2 + n3) as u64, + "telemetry octets should match committed bytes on successful writes" + ); + assert!( + !quota_exceeded.load(Ordering::Acquire), + "quota flag should stay false under large remaining budget" + ); +} + +#[tokio::test] +async fn direct_hybrid_branch_selection_matches_contract() { + let near_limit = 256 * 1024u64; + let near_remaining = 32 * 1024u64; + let (mut near_io, _stats, _calls, _flag) = make_stats_io_with_script( + "direct-near-limit-hard-check-user", + near_limit, + near_limit - near_remaining, + &[4 * 1024], + ); + let near_payload = vec![0x11; 4 * 1024]; + let near_written = near_io + .write(&near_payload) + .await + .expect("near-limit write must succeed"); + assert_eq!(near_written, 4 * 1024); + assert_eq!( + near_io.quota_bytes_since_check, 0, + "near-limit branch must go through immediate hard check" + ); + + let (mut far_small_io, _stats, _calls, _flag) = + make_stats_io_with_script("direct-far-small-amortized-user", 1_048_576, 0, &[4 * 1024]); + let far_small_payload = vec![0x22; 4 * 1024]; + let far_small_written = far_small_io + .write(&far_small_payload) + .await + .expect("small far-from-limit write must succeed"); + assert_eq!(far_small_written, 4 * 1024); + assert_eq!( + far_small_io.quota_bytes_since_check, + 4 * 1024, + "small far-from-limit write must go through amortized path" + ); + + let (mut far_large_io, _stats, _calls, _flag) = make_stats_io_with_script( + "direct-far-large-hard-check-user", + 1_048_576, + 0, + &[32 * 1024], + ); + let far_large_payload = vec![0x33; 32 * 1024]; + let far_large_written = far_large_io + .write(&far_large_payload) + .await + .expect("large write must succeed"); + assert_eq!(far_large_written, 32 * 1024); + assert_eq!( + far_large_io.quota_bytes_since_check, 0, + "large write must force immediate hard check even far from limit" + ); +} + +#[tokio::test] +async fn remaining_before_zero_rejects_without_calling_inner_writer() { + let user = "direct-zero-remaining-user"; + let limit = 8u64; + let (mut io, stats, write_calls, quota_exceeded) = + make_stats_io_with_script(user, limit, limit, &[1]); + + let err = io + .write(&[0x44]) + .await + .expect_err("write must fail when remaining quota is zero"); + + assert!( + is_quota_io_error(&err), + "zero-remaining gate must return typed quota I/O error" + ); + assert_eq!( + write_calls.load(Ordering::Relaxed), + 0, + "inner poll_write must not be called when remaining quota is zero" + ); + assert!( + quota_exceeded.load(Ordering::Acquire), + "zero-remaining gate must set exceeded flag" + ); + assert_eq!(stats.get_user_quota_used(user), limit); +} + +#[tokio::test] +async fn exceeded_flag_blocks_following_poll_before_inner_write() { + let user = "direct-exceeded-visibility-user"; + let (mut io, stats, write_calls, quota_exceeded) = + make_stats_io_with_script(user, 1, 0, &[1, 1]); + + let first = io + .write(&[0x55]) + .await + .expect("first byte should consume remaining quota"); + assert_eq!(first, 1); + assert!( + quota_exceeded.load(Ordering::Acquire), + "hard check should store quota_exceeded after boundary hit" + ); + + let second = io + .write(&[0x66]) + .await + .expect_err("next write must be rejected by early exceeded gate"); + assert!( + is_quota_io_error(&second), + "following write must fail with typed quota error" + ); + assert_eq!( + write_calls.load(Ordering::Relaxed), + 1, + "second write must be cut before touching inner writer" + ); + assert_eq!(stats.get_user_quota_used(user), 1); +} + +#[test] +fn adaptive_interval_clamp_matches_contract() { + assert_eq!(quota_adaptive_interval_bytes(0), 4 * 1024); + assert_eq!(quota_adaptive_interval_bytes(2 * 1024), 4 * 1024); + assert_eq!(quota_adaptive_interval_bytes(32 * 1024), 16 * 1024); + assert_eq!(quota_adaptive_interval_bytes(256 * 1024), 64 * 1024); + + assert!(should_immediate_quota_check(32 * 1024, 4 * 1024)); + assert!(should_immediate_quota_check(1_048_576, 32 * 1024)); + assert!(!should_immediate_quota_check(1_048_576, 4 * 1024)); +} diff --git a/src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs b/src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs deleted file mode 100644 index 9ea921c..0000000 --- a/src/proxy/tests/relay_cross_mode_pipeline_hol_integration_security_tests.rs +++ /dev/null @@ -1,267 +0,0 @@ -use super::relay_bidirectional; -use crate::stats::Stats; -use crate::stream::BufferPool; -use std::sync::Arc; -use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; -use tokio::time::{Duration, timeout}; - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -#[tokio::test] -async fn negative_same_user_pipeline_stalls_while_middle_lock_is_held() { - let _guard = quota_test_guard(); - - let user = format!("relay-pipeline-stall-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold shared cross-mode lock"); - - let stats = Arc::new(Stats::new()); - let (mut client_peer, relay_client) = duplex(1024); - let (relay_server, mut server_peer) = duplex(1024); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_user = user.clone(); - let relay_stats = Arc::clone(&stats); - let relay_task = tokio::spawn(async move { - relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 256, - 256, - &relay_user, - relay_stats, - Some(1024), - Arc::new(BufferPool::new()), - ) - .await - }); - - server_peer - .write_all(&[0xA1]) - .await - .expect("server write should enqueue while relay is stalled"); - - let mut one = [0u8; 1]; - let blocked_read = timeout(Duration::from_millis(40), client_peer.read_exact(&mut one)).await; - assert!( - blocked_read.is_err(), - "same-user relay must remain blocked while cross-mode lock is held" - ); - - drop(held_guard); - - timeout(Duration::from_millis(400), client_peer.read_exact(&mut one)) - .await - .expect("blocked relay must resume after cross-mode lock release") - .expect("resumed relay must deliver queued byte"); - assert_eq!(one, [0xA1]); - - drop(client_peer); - drop(server_peer); - - let relay_result = timeout(Duration::from_secs(1), relay_task) - .await - .expect("relay task must complete") - .expect("relay task must not panic"); - assert!(relay_result.is_ok()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_other_user_pipeline_progresses_while_blocked_user_is_stalled() { - let _guard = quota_test_guard(); - - let blocked_user = format!("relay-pipeline-blocked-{}", std::process::id()); - let free_user = format!("relay-pipeline-free-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user); - let held_guard = held - .try_lock() - .expect("test must hold blocked user's shared cross-mode lock"); - - let stats_blocked = Arc::new(Stats::new()); - let stats_free = Arc::new(Stats::new()); - - let (mut blocked_client, blocked_relay_client) = duplex(1024); - let (blocked_relay_server, mut blocked_server) = duplex(1024); - let (blocked_client_reader, blocked_client_writer) = tokio::io::split(blocked_relay_client); - let (blocked_server_reader, blocked_server_writer) = tokio::io::split(blocked_relay_server); - - let (mut free_client, free_relay_client) = duplex(1024); - let (free_relay_server, mut free_server) = duplex(1024); - let (free_client_reader, free_client_writer) = tokio::io::split(free_relay_client); - let (free_server_reader, free_server_writer) = tokio::io::split(free_relay_server); - - let blocked_task = { - let user = blocked_user.clone(); - let stats = Arc::clone(&stats_blocked); - tokio::spawn(async move { - relay_bidirectional( - blocked_client_reader, - blocked_client_writer, - blocked_server_reader, - blocked_server_writer, - 256, - 256, - &user, - stats, - Some(1024), - Arc::new(BufferPool::new()), - ) - .await - }) - }; - - let free_task = { - let user = free_user.clone(); - let stats = Arc::clone(&stats_free); - tokio::spawn(async move { - relay_bidirectional( - free_client_reader, - free_client_writer, - free_server_reader, - free_server_writer, - 256, - 256, - &user, - stats, - Some(1024), - Arc::new(BufferPool::new()), - ) - .await - }) - }; - - blocked_server - .write_all(&[0xB1]) - .await - .expect("blocked user server write should queue"); - free_server - .write_all(&[0xC1]) - .await - .expect("free user server write should queue"); - - let mut blocked_buf = [0u8; 1]; - let mut free_buf = [0u8; 1]; - - let blocked_stalled = timeout( - Duration::from_millis(40), - blocked_client.read_exact(&mut blocked_buf), - ) - .await; - assert!( - blocked_stalled.is_err(), - "blocked user must remain stalled while its lock is held" - ); - - timeout(Duration::from_millis(250), free_client.read_exact(&mut free_buf)) - .await - .expect("free user must make progress while other user is blocked") - .expect("free user read must succeed"); - assert_eq!(free_buf, [0xC1]); - - drop(held_guard); - - timeout(Duration::from_millis(400), blocked_client.read_exact(&mut blocked_buf)) - .await - .expect("blocked user must resume after release") - .expect("blocked user resumed read must succeed"); - assert_eq!(blocked_buf, [0xB1]); - - drop(blocked_client); - drop(blocked_server); - drop(free_client); - drop(free_server); - - assert!( - timeout(Duration::from_secs(1), blocked_task) - .await - .expect("blocked relay task must complete") - .expect("blocked relay task must not panic") - .is_ok() - ); - assert!( - timeout(Duration::from_secs(1), free_task) - .await - .expect("free relay task must complete") - .expect("free relay task must not panic") - .is_ok() - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_jittered_hold_release_cycles_preserve_pipeline_liveness() { - let _guard = quota_test_guard(); - - let mut seed = 0x5EED_C0DE_2026_0323u64; - for round in 0..24u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let hold_ms = 2 + (seed % 10); - let user = format!("relay-pipeline-fuzz-{}-{round}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold lock during fuzz round"); - - let stats = Arc::new(Stats::new()); - let (mut client_peer, relay_client) = duplex(1024); - let (relay_server, mut server_peer) = duplex(1024); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_user = user.clone(); - let relay_stats = Arc::clone(&stats); - let relay_task = tokio::spawn(async move { - relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 256, - 256, - &relay_user, - relay_stats, - Some(1024), - Arc::new(BufferPool::new()), - ) - .await - }); - - server_peer - .write_all(&[0xD1]) - .await - .expect("server write should queue in fuzz round"); - - let mut one = [0u8; 1]; - let stalled = timeout(Duration::from_millis(30), client_peer.read_exact(&mut one)).await; - assert!(stalled.is_err(), "held phase must stall same-user relay"); - - tokio::time::sleep(Duration::from_millis(hold_ms)).await; - drop(held_guard); - - timeout(Duration::from_millis(400), client_peer.read_exact(&mut one)) - .await - .expect("released phase must resume same-user relay") - .expect("released phase read must succeed"); - assert_eq!(one, [0xD1]); - - drop(client_peer); - drop(server_peer); - - assert!( - timeout(Duration::from_secs(1), relay_task) - .await - .expect("fuzz relay task must complete") - .expect("fuzz relay task must not panic") - .is_ok() - ); - } -} \ No newline at end of file diff --git a/src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs b/src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs deleted file mode 100644 index c967861..0000000 --- a/src/proxy/tests/relay_cross_mode_pipeline_latency_benchmark_security_tests.rs +++ /dev/null @@ -1,213 +0,0 @@ -use super::relay_bidirectional; -use crate::stats::Stats; -use crate::stream::BufferPool; -use std::sync::{Arc, Mutex}; -use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; -use tokio::sync::{Barrier, watch}; -use tokio::time::{Duration, Instant, timeout}; - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn percentile_index(len: usize, percentile: usize) -> usize { - ((len * percentile) / 100).min(len.saturating_sub(1)) -} - -#[tokio::test] -async fn micro_benchmark_pipeline_release_to_delivery_latency_stays_bounded() { - let _guard = quota_test_guard(); - - let rounds = 64usize; - let user = format!("relay-pipeline-latency-single-{}", std::process::id()); - let mut samples_ms = Vec::with_capacity(rounds); - - for round in 0..rounds { - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold shared cross-mode lock before round"); - - let stats = Arc::new(Stats::new()); - let (mut client_peer, relay_client) = duplex(1024); - let (relay_server, mut server_peer) = duplex(1024); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_user = user.clone(); - let relay_stats = Arc::clone(&stats); - let relay_task = tokio::spawn(async move { - relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 256, - 256, - &relay_user, - relay_stats, - Some(2048), - Arc::new(BufferPool::new()), - ) - .await - }); - - server_peer - .write_all(&[(round as u8) ^ 0xA5]) - .await - .expect("server write should queue before release"); - - let release_at = Instant::now(); - drop(held_guard); - - let mut one = [0u8; 1]; - timeout(Duration::from_millis(450), client_peer.read_exact(&mut one)) - .await - .expect("client must receive queued byte after release") - .expect("queued byte read must succeed"); - samples_ms.push(release_at.elapsed().as_millis() as u64); - - drop(client_peer); - drop(server_peer); - - let relay_result = timeout(Duration::from_secs(1), relay_task) - .await - .expect("relay task must complete") - .expect("relay task must not panic"); - assert!(relay_result.is_ok()); - } - - samples_ms.sort_unstable(); - let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)]; - let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)]; - - assert!( - p50_ms <= 45, - "single-flow release latency p50 must stay bounded; p50_ms={p50_ms}, samples={samples_ms:?}" - ); - assert!( - p95_ms <= 130, - "single-flow release latency p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_128_waiter_pipeline_release_latency_p95_stays_bounded() { - let _guard = quota_test_guard(); - - let waiters = 128usize; - let user = format!("relay-pipeline-latency-fanout-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold shared lock before fanout release benchmark"); - - let ready_barrier = Arc::new(Barrier::new(waiters + 1)); - let release_at = Arc::new(Mutex::new(None::)); - let (release_tx, release_rx) = watch::channel(false); - let mut tasks = Vec::with_capacity(waiters); - - for idx in 0..waiters { - let user = user.clone(); - let barrier = Arc::clone(&ready_barrier); - let release_at = Arc::clone(&release_at); - let mut release_rx = release_rx.clone(); - - tasks.push(tokio::spawn(async move { - let stats = Arc::new(Stats::new()); - let (mut client_peer, relay_client) = duplex(512); - let (relay_server, mut server_peer) = duplex(512); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_user = user; - let relay_stats = Arc::clone(&stats); - let relay_task = tokio::spawn(async move { - relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 256, - 256, - &relay_user, - relay_stats, - Some(2048), - Arc::new(BufferPool::new()), - ) - .await - }); - - server_peer - .write_all(&[(idx as u8) ^ 0x5A]) - .await - .expect("fanout server write should queue before release"); - - barrier.wait().await; - release_rx - .changed() - .await - .expect("release signal should remain available"); - - let started = { - let guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner()); - guard.expect("release timestamp must be populated before signal") - }; - - let mut one = [0u8; 1]; - timeout(Duration::from_millis(900), client_peer.read_exact(&mut one)) - .await - .expect("fanout waiter must receive queued byte after release") - .expect("fanout waiter read must succeed"); - - drop(client_peer); - drop(server_peer); - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("fanout relay task must complete") - .expect("fanout relay task must not panic"); - assert!(relay_result.is_ok()); - - started.elapsed().as_millis() as u64 - })); - } - - ready_barrier.wait().await; - { - let mut guard = release_at.lock().unwrap_or_else(|poison| poison.into_inner()); - *guard = Some(Instant::now()); - } - drop(held_guard); - release_tx - .send(true) - .expect("release broadcast must succeed"); - - let mut samples_ms = Vec::with_capacity(waiters); - timeout(Duration::from_secs(8), async { - for task in tasks { - let elapsed = task.await.expect("fanout waiter must not panic"); - samples_ms.push(elapsed); - } - }) - .await - .expect("fanout benchmark must complete in bounded time"); - - samples_ms.sort_unstable(); - let p50_ms = samples_ms[percentile_index(samples_ms.len(), 50)]; - let p95_ms = samples_ms[percentile_index(samples_ms.len(), 95)]; - let max_ms = *samples_ms.last().unwrap_or(&0); - - assert!( - p50_ms <= 120, - "fanout release latency p50 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}" - ); - assert!( - p95_ms <= 260, - "fanout release latency p95 must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}" - ); - assert!( - max_ms <= 700, - "fanout release latency max must stay bounded; p50_ms={p50_ms}, p95_ms={p95_ms}, max_ms={max_ms}" - ); -} \ No newline at end of file diff --git a/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs b/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs deleted file mode 100644 index adbdb22..0000000 --- a/src/proxy/tests/relay_cross_mode_quota_fairness_tdd_tests.rs +++ /dev/null @@ -1,604 +0,0 @@ -use super::*; -use crate::stats::Stats; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Poll, Waker}; -use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf}; -use tokio::sync::Barrier; -use tokio::time::{Duration, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn build_context() -> (Arc, Context<'static>) { - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); - (wake_counter, Context::from_waker(leaked_waker)) -} - -#[tokio::test] -async fn positive_cross_mode_uncontended_writer_progresses() { - let _guard = quota_test_guard(); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - "cross-mode-tdd-uncontended".to_string(), - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let result = io.write_all(&[0x11, 0x22]).await; - assert!(result.is_ok(), "uncontended writer must progress"); -} - -#[tokio::test] -async fn adversarial_held_cross_mode_lock_blocks_writer_even_if_local_lock_free() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-tdd-held-{}", std::process::id()); - let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let _held_guard = held - .try_lock() - .expect("test must hold cross-mode lock before polling writer"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]); - assert!(poll.is_pending(), "writer must not bypass held cross-mode lock"); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_parallel_waiters_resume_after_cross_mode_release() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-tdd-resume-{}", std::process::id()); - let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let held_guard = held - .try_lock() - .expect("test must hold cross-mode lock before launching waiters"); - - let stats = Arc::new(Stats::new()); - let mut waiters = Vec::new(); - for _ in 0..16 { - let stats = Arc::clone(&stats); - let user = user.clone(); - waiters.push(tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - stats, - user, - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - io.write_all(&[0x7F]).await - })); - } - - tokio::time::sleep(Duration::from_millis(5)).await; - drop(held_guard); - - timeout(Duration::from_secs(1), async { - for waiter in waiters { - let result = waiter.await.expect("waiter task must not panic"); - assert!(result.is_ok(), "waiter must complete after cross-mode release"); - } - }) - .await - .expect("all waiters must complete in bounded time"); -} - -#[tokio::test] -async fn adversarial_cross_mode_contention_wake_budget_stays_bounded() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-tdd-wakes-{}", std::process::id()); - let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let _held_guard = held - .try_lock() - .expect("test must hold cross-mode lock before polling"); - - let stats = Arc::new(Stats::new()); - let mut ios = Vec::new(); - let mut counters = Vec::new(); - for _ in 0..20 { - ios.push(StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(2048), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - )); - } - - for io in &mut ios { - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - let poll = Pin::new(io).poll_write(&mut cx, &[0x33]); - assert!(poll.is_pending()); - counters.push(wake_counter); - } - - tokio::time::sleep(Duration::from_millis(25)).await; - let total_wakes: usize = counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - - assert!( - total_wakes <= 20 * 4, - "cross-mode contention should not create wake storms; wakes={total_wakes}" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn light_fuzz_cross_mode_release_timing_preserves_read_write_liveness() { - let _guard = quota_test_guard(); - - let mut seed = 0xC0DE_BAAD_2026_0322u64; - for round in 0..16u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let sleep_ms = 2 + (seed as u64 % 8); - let user = format!("cross-mode-tdd-fuzz-{}-{round}", std::process::id()); - let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let held_guard = held - .try_lock() - .expect("test must hold cross-mode lock in fuzz round"); - - let stats = Arc::new(Stats::new()); - let user_reader = user.clone(); - let reader_task = tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user_reader, - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - let mut one = [0u8; 1]; - io.read(&mut one).await - }); - - let user_writer = user.clone(); - let writer_task = tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user_writer, - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - io.write_all(&[0x44]).await - }); - - tokio::time::sleep(Duration::from_millis(sleep_ms)).await; - drop(held_guard); - - let read_done = timeout(Duration::from_millis(350), reader_task) - .await - .expect("reader task must complete after release") - .expect("reader task must not panic"); - assert!(read_done.is_ok()); - - let write_done = timeout(Duration::from_millis(350), writer_task) - .await - .expect("writer task must complete after release") - .expect("writer task must not panic"); - assert!(write_done.is_ok()); - } -} - -#[tokio::test] -async fn integration_middle_lock_blocks_relay_reader_for_same_user() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-middle-reader-block-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let _held_guard = held - .try_lock() - .expect("test must hold middle-relay shared lock"); - - let mut io = StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - let mut one = [0u8; 1]; - let mut buf = ReadBuf::new(&mut one); - let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!(poll.is_pending()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn integration_middle_lock_release_unblocks_relay_reader() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-middle-reader-release-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold middle-relay shared lock"); - - let task = tokio::spawn({ - let user = user.clone(); - async move { - let mut io = StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - let mut one = [0u8; 1]; - io.read(&mut one).await - } - }); - - tokio::time::sleep(Duration::from_millis(5)).await; - drop(held_guard); - - let done = timeout(Duration::from_millis(300), task) - .await - .expect("reader task must complete after release") - .expect("reader task must not panic"); - assert!(done.is_ok()); -} - -#[tokio::test] -async fn business_different_user_middle_lock_does_not_block_relay_writer() { - let _guard = quota_test_guard(); - - let held_user = format!("cross-mode-middle-held-{}", std::process::id()); - let active_user = format!("cross-mode-middle-active-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&held_user); - let _held_guard = held - .try_lock() - .expect("test must hold middle-relay lock for other user"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - active_user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x61]); - assert!(matches!(poll, Poll::Ready(Ok(1)))); -} - -#[tokio::test] -async fn edge_quota_none_bypasses_cross_mode_lock_even_when_held() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-none-limit-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let _held_guard = held - .try_lock() - .expect("test must hold lock while quota is disabled"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - None, - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x62, 0x63]); - assert!(matches!(poll, Poll::Ready(Ok(2)))); -} - -#[tokio::test] -async fn edge_quota_exceeded_flag_short_circuits_before_lock_path() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-pre-exceeded-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let _held_guard = held - .try_lock() - .expect("test must hold shared lock before poll"); - - let quota_exceeded = Arc::new(AtomicBool::new(true)); - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::clone("a_exceeded), - tokio::time::Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x64]); - assert!(matches!(poll, Poll::Ready(Err(ref e)) if is_quota_io_error(e))); -} - -#[tokio::test] -async fn adversarial_repoll_while_middle_lock_held_keeps_pending_without_usage_leak() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-repoll-held-{}", std::process::id()); - let stats = Arc::new(Stats::new()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let _held_guard = held - .try_lock() - .expect("test must hold lock for repoll sequence"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - for _ in 0..8 { - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x65]); - assert!(poll.is_pending()); - } - - assert_eq!(stats.get_user_total_octets(&user), 0); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_same_user_mixed_read_write_waiters_resume_after_release() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-mixed-resume-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold lock before spawning mixed waiters"); - - let mut tasks = Vec::new(); - for i in 0..12usize { - let user = user.clone(); - tasks.push(tokio::spawn(async move { - if i % 2 == 0 { - let mut io = StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - let mut b = [0u8; 1]; - io.read(&mut b).await.map(|_| ()) - } else { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - io.write_all(&[0x66]).await - } - })); - } - - tokio::time::sleep(Duration::from_millis(8)).await; - drop(held_guard); - - timeout(Duration::from_secs(1), async { - for task in tasks { - let result = task.await.expect("mixed waiter task must not panic"); - assert!(result.is_ok()); - } - }) - .await - .expect("all mixed waiters must finish after release"); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_one_user_blocked_other_user_progresses_under_middle_lock() { - let _guard = quota_test_guard(); - - let blocked_user = format!("cross-mode-blocked-{}", std::process::id()); - let free_user = format!("cross-mode-free-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&blocked_user); - let held_guard = held - .try_lock() - .expect("test must hold blocked user lock"); - - let blocked_task = tokio::spawn({ - let blocked_user = blocked_user.clone(); - async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - blocked_user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - io.write_all(&[0x77]).await - } - }); - - let free_task = tokio::spawn({ - let free_user = free_user.clone(); - async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - free_user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - io.write_all(&[0x78]).await - } - }); - - let free_done = timeout(Duration::from_millis(250), free_task) - .await - .expect("free user must not be blocked") - .expect("free user task must not panic"); - assert!(free_done.is_ok()); - - drop(held_guard); - let blocked_done = timeout(Duration::from_secs(1), blocked_task) - .await - .expect("blocked user must resume after release") - .expect("blocked user task must not panic"); - assert!(blocked_done.is_ok()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_middle_lock_release_allows_high_waiter_fanout_completion() { - let _guard = quota_test_guard(); - - let user = format!("cross-mode-fanout-{}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold lock before fanout"); - - let waiters = 48usize; - let gate = Arc::new(Barrier::new(waiters + 1)); - let mut tasks = Vec::new(); - for _ in 0..waiters { - let user = user.clone(); - let gate = Arc::clone(&gate); - tasks.push(tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - gate.wait().await; - io.write_all(&[0x79]).await - })); - } - - gate.wait().await; - tokio::time::sleep(Duration::from_millis(10)).await; - drop(held_guard); - - timeout(Duration::from_secs(2), async { - for task in tasks { - let result = task.await.expect("fanout task must not panic"); - assert!(result.is_ok()); - } - }) - .await - .expect("fanout waiters must complete after release"); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn light_fuzz_middle_lock_hold_release_cycles_preserve_same_user_liveness() { - let _guard = quota_test_guard(); - - let mut seed = 0xA11C_EE55_2026_0323u64; - for round in 0..20u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let hold_ms = 2 + (seed % 10); - let user = format!("cross-mode-middle-fuzz-{}-{round}", std::process::id()); - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(&user); - let held_guard = held - .try_lock() - .expect("test must hold lock in fuzz round"); - - let writer = tokio::spawn({ - let user = user.clone(); - async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - io.write_all(&[0x7A]).await - } - }); - - tokio::time::sleep(Duration::from_millis(hold_ms)).await; - drop(held_guard); - - let done = timeout(Duration::from_millis(400), writer) - .await - .expect("writer must complete after lock release") - .expect("writer task must not panic"); - assert!(done.is_ok()); - } -} diff --git a/src/proxy/tests/relay_cross_mode_quota_lock_security_tests.rs b/src/proxy/tests/relay_cross_mode_quota_lock_security_tests.rs deleted file mode 100644 index 5ea806a..0000000 --- a/src/proxy/tests/relay_cross_mode_quota_lock_security_tests.rs +++ /dev/null @@ -1,81 +0,0 @@ -use super::*; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::Waker; -use std::task::{Context, Poll}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn build_context() -> (Arc, Context<'static>) { - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); - (wake_counter, Context::from_waker(leaked_waker)) -} - -#[tokio::test] -async fn adversarial_middle_held_cross_mode_lock_blocks_relay_writer() { - let _guard = quota_user_lock_test_scope(); - - let user = "cross-mode-lock-shared-user"; - let held = crate::proxy::middle_relay::cross_mode_quota_user_lock_for_tests(user); - let _held_guard = held - .try_lock() - .expect("test must hold shared cross-mode lock before relay poll"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(crate::stats::Stats::new()), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42, 0x43]); - - assert!( - matches!(poll, Poll::Pending), - "relay writer must not bypass cross-mode lock held by middle-relay path" - ); -} - -#[tokio::test] -async fn business_cross_mode_lock_uncontended_allows_relay_writer_progress() { - let _guard = quota_user_lock_test_scope(); - - let user = "cross-mode-lock-progress-user"; - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(crate::stats::Stats::new()), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51, 0x52]); - - assert!( - matches!(poll, Poll::Ready(Ok(2))), - "relay writer should progress when shared cross-mode lock is uncontended" - ); -} diff --git a/src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs b/src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs deleted file mode 100644 index 9ac4621..0000000 --- a/src/proxy/tests/relay_dual_lock_alternating_contention_security_tests.rs +++ /dev/null @@ -1,340 +0,0 @@ -use super::*; -use crate::stats::Stats; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; -use tokio::io::AsyncWriteExt; -use tokio::time::{Duration, Instant, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -#[tokio::test] -async fn positive_uncontended_dual_lock_writer_has_zero_retry_attempt() { - let _guard = quota_test_guard(); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - format!("dual-lock-alt-positive-{}", std::process::id()), - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let write = io.write_all(&[0xAA, 0xBB]).await; - assert!(write.is_ok(), "uncontended write must complete"); - assert_eq!( - io.quota_write_retry_attempt, 0, - "uncontended write must not advance retry backoff" - ); -} - -#[tokio::test] -async fn adversarial_alternating_local_and_cross_mode_contention_preserves_backoff_growth() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-alt-adversarial-{}", std::process::id()); - let local_lock = quota_user_lock(&user); - let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - - let mut local_guard = Some( - local_lock - .try_lock() - .expect("test must hold local quota lock initially"), - ); - let mut cross_guard = None; - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); - assert!(first.is_pending(), "held local lock must block first poll"); - - let mut observed_wakes = 0usize; - for idx in 0..18usize { - tokio::time::sleep(Duration::from_millis(6)).await; - - if idx % 2 == 0 { - drop(local_guard.take()); - cross_guard = Some( - cross_mode_lock - .try_lock() - .expect("cross-mode lock should be acquirable while local lock released"), - ); - } else { - drop(cross_guard.take()); - local_guard = Some( - local_lock - .try_lock() - .expect("local lock should be acquirable while cross lock released"), - ); - } - - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - if wakes > observed_wakes { - observed_wakes = wakes; - let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x12]); - assert!( - pending.is_pending(), - "alternating contention must keep write pending while one lock is held" - ); - } - } - - assert!( - io.quota_write_retry_attempt >= 2, - "alternating contention must still ramp retry backoff; got {}", - io.quota_write_retry_attempt - ); - assert!( - wake_counter.wakes.load(Ordering::Relaxed) <= 32, - "alternating contention must stay wake-rate-limited" - ); - - drop(local_guard); - drop(cross_guard); - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x13]); - assert!(ready.is_ready(), "writer must resume after both locks released"); -} - -#[tokio::test] -async fn edge_retry_scheduler_resets_after_alternating_contention_clears() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-alt-edge-reset-{}", std::process::id()); - let local_lock = quota_user_lock(&user); - let local_guard = local_lock - .try_lock() - .expect("test must hold local lock for edge scenario"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let first = Pin::new(&mut io).poll_write(&mut cx, &[0x21]); - assert!(first.is_pending()); - tokio::time::sleep(Duration::from_millis(15)).await; - if wake_counter.wakes.load(Ordering::Relaxed) > 0 { - let next = Pin::new(&mut io).poll_write(&mut cx, &[0x22]); - assert!(next.is_pending()); - } - - drop(local_guard); - - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x23]); - assert!(ready.is_ready()); - assert_eq!( - io.quota_write_retry_attempt, 0, - "successful dual-lock acquisition must reset retry scheduler" - ); - assert!(!io.quota_write_wake_scheduled); - assert!(io.quota_write_retry_sleep.is_none()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_cross_mode_waiters_remain_live_under_alternating_contention_then_resume() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-alt-integration-{}", std::process::id()); - let local_lock = quota_user_lock(&user); - let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - - let mut waiters = Vec::new(); - for _ in 0..16usize { - let user = user.clone(); - waiters.push(tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - timeout(Duration::from_secs(2), io.write_all(&[0x31])).await - })); - } - - let mut local_guard = Some( - local_lock - .try_lock() - .expect("integration toggle must acquire local lock first"), - ); - let mut cross_guard = None; - - for idx in 0..24usize { - tokio::time::sleep(Duration::from_millis(4)).await; - if idx % 2 == 0 { - drop(local_guard.take()); - cross_guard = cross_mode_lock.try_lock().ok(); - } else { - drop(cross_guard.take()); - local_guard = local_lock.try_lock().ok(); - } - } - - drop(local_guard); - drop(cross_guard); - - for waiter in waiters { - let done = waiter.await.expect("waiter task must not panic"); - assert!( - done.is_ok(), - "waiter must finish once alternating contention window ends" - ); - assert!(done.expect("waiter timeout must not fire").is_ok()); - } -} - -#[tokio::test] -async fn light_fuzz_alternating_contention_matrix_preserves_lock_gating() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-alt-fuzz-{}", std::process::id()); - let local_lock = quota_user_lock(&user); - let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let mut seed = 0xD00D_BAAD_F00D_2026u64; - - for _round in 0..64u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let hold_mode = (seed % 3) as u8; - let local_guard = if hold_mode == 0 { - Some( - local_lock - .try_lock() - .expect("fuzz local lock should be acquirable"), - ) - } else { - None - }; - let cross_guard = if hold_mode == 1 { - Some( - cross_mode_lock - .try_lock() - .expect("fuzz cross lock should be acquirable"), - ) - } else { - None - }; - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user.clone(), - Some(1024), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let write = timeout(Duration::from_millis(35), io.write_all(&[0x51])).await; - if hold_mode == 2 { - assert!(write.is_ok(), "unheld fuzz round must make progress"); - assert!(write.expect("unheld round timeout").is_ok()); - } else { - assert!( - write.is_err(), - "held-lock fuzz round must remain pending inside bounded window" - ); - } - - drop(local_guard); - drop(cross_guard); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_fanout_alternating_contention_recovers_without_hanging() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-alt-stress-{}", std::process::id()); - let local_lock = quota_user_lock(&user); - let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - - let mut waiters = Vec::new(); - for _ in 0..48usize { - let user = user.clone(); - waiters.push(tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(4096), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - timeout(Duration::from_secs(3), io.write_all(&[0xA0, 0xA1])).await - })); - } - - let mut local_guard = Some( - local_lock - .try_lock() - .expect("stress toggle must acquire local lock first"), - ); - let mut cross_guard = None; - for idx in 0..40usize { - tokio::time::sleep(Duration::from_millis(3)).await; - if idx % 2 == 0 { - drop(local_guard.take()); - cross_guard = cross_mode_lock.try_lock().ok(); - } else { - drop(cross_guard.take()); - local_guard = local_lock.try_lock().ok(); - } - } - - drop(local_guard); - drop(cross_guard); - - for waiter in waiters { - let done = waiter.await.expect("stress waiter task must not panic"); - assert!(done.is_ok(), "stress waiter timed out under alternating contention"); - assert!(done.expect("stress waiter timeout should not fire").is_ok()); - } -} diff --git a/src/proxy/tests/relay_dual_lock_backoff_regression_security_tests.rs b/src/proxy/tests/relay_dual_lock_backoff_regression_security_tests.rs deleted file mode 100644 index ce26941..0000000 --- a/src/proxy/tests/relay_dual_lock_backoff_regression_security_tests.rs +++ /dev/null @@ -1,74 +0,0 @@ -use super::*; -use crate::stats::Stats; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; -use tokio::time::{Duration, Instant}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -#[tokio::test] -async fn adversarial_cross_mode_only_contention_backoff_attempt_must_ramp() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-backoff-{}", std::process::id()); - let cross_mode_lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let held_cross_mode_guard = cross_mode_lock - .try_lock() - .expect("test must hold cross-mode lock before polling"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]); - assert!(first.is_pending(), "held cross-mode lock must block writer"); - - let started = Instant::now(); - let mut last_wakes = 0usize; - while started.elapsed() < Duration::from_millis(120) { - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - if wakes > last_wakes { - last_wakes = wakes; - let next = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]); - assert!(next.is_pending(), "writer must remain blocked while lock is held"); - } - tokio::time::sleep(Duration::from_millis(1)).await; - } - - assert!( - io.quota_write_retry_attempt >= 2, - "retry attempt must ramp under sustained second-lock contention; got {}", - io.quota_write_retry_attempt - ); - - drop(held_cross_mode_guard); -} diff --git a/src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs b/src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs deleted file mode 100644 index 513d92b..0000000 --- a/src/proxy/tests/relay_dual_lock_contention_matrix_security_tests.rs +++ /dev/null @@ -1,325 +0,0 @@ -use super::*; -use crate::stats::Stats; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; -use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf}; -use tokio::time::{Duration, Instant, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn build_context() -> (Arc, Context<'static>) { - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); - (wake_counter, Context::from_waker(leaked_waker)) -} - -#[tokio::test] -async fn positive_uncontended_dual_locks_writer_completes_without_retry_state() { - let _guard = quota_test_guard(); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - format!("dual-lock-positive-{}", std::process::id()), - Some(4096), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x01, 0x02, 0x03]); - assert!(poll.is_ready()); - assert_eq!(io.quota_write_retry_attempt, 0); - assert!(!io.quota_write_wake_scheduled); - assert!(io.quota_write_retry_sleep.is_none()); -} - -#[tokio::test] -async fn negative_local_lock_contention_read_retry_attempt_ramps() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-local-contention-{}", std::process::id()); - let held = quota_user_lock(&user); - let held_guard = held - .try_lock() - .expect("test must hold local quota lock before polling"); - - let mut io = StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let (wake_counter, mut cx) = build_context(); - let mut one = [0u8; 1]; - let mut buf = ReadBuf::new(&mut one); - let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!(first.is_pending()); - - let started = Instant::now(); - let mut observed = 0usize; - while started.elapsed() < Duration::from_millis(120) { - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - if wakes > observed { - observed = wakes; - let mut step_buf = ReadBuf::new(&mut one); - let next = Pin::new(&mut io).poll_read(&mut cx, &mut step_buf); - assert!(next.is_pending()); - } - tokio::time::sleep(Duration::from_millis(1)).await; - } - - assert!( - io.quota_read_retry_attempt >= 2, - "retry attempt must ramp under sustained local-lock contention; got {}", - io.quota_read_retry_attempt - ); - - drop(held_guard); -} - -#[tokio::test] -async fn edge_cross_mode_contention_release_resets_retry_scheduler_on_success() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-reset-{}", std::process::id()); - let cross_mode = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let held_guard = cross_mode - .try_lock() - .expect("test must hold cross-mode lock before polling"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let (wake_counter, mut cx) = build_context(); - let first = Pin::new(&mut io).poll_write(&mut cx, &[0x10]); - assert!(first.is_pending()); - - tokio::time::sleep(Duration::from_millis(20)).await; - if wake_counter.wakes.load(Ordering::Relaxed) > 0 { - let next = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); - assert!(next.is_pending()); - } - - drop(held_guard); - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x12]); - assert!(ready.is_ready()); - assert_eq!(io.quota_write_retry_attempt, 0); - assert!(!io.quota_write_wake_scheduled); - assert!(io.quota_write_retry_sleep.is_none()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_cross_mode_hold_blocks_many_waiters_without_usage_leak() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-adversarial-{}", std::process::id()); - let stats = Arc::new(Stats::new()); - let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let held_guard = held - .try_lock() - .expect("test must hold cross-mode lock before launching waiters"); - - let mut tasks = Vec::new(); - for _ in 0..24usize { - let stats = Arc::clone(&stats); - let user = user.clone(); - tasks.push(tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - stats, - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - timeout(Duration::from_millis(40), io.write_all(&[0x33])).await - })); - } - - for task in tasks { - let timed = task.await.expect("waiter task must not panic"); - assert!(timed.is_err(), "held cross-mode lock must keep waiter pending"); - } - - assert_eq!(stats.get_user_total_octets(&user), 0); - drop(held_guard); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_waiters_resume_after_cross_mode_release() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-integration-{}", std::process::id()); - let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let held_guard = held - .try_lock() - .expect("test must hold cross-mode lock before starting waiter"); - - let task = tokio::spawn({ - let user = user.clone(); - async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - io.write_all(&[0x44]).await - } - }); - - tokio::time::sleep(Duration::from_millis(10)).await; - drop(held_guard); - - let done = timeout(Duration::from_secs(1), task) - .await - .expect("waiter task must complete after release") - .expect("waiter task must not panic"); - assert!(done.is_ok()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn light_fuzz_randomized_lock_holds_preserve_liveness_and_quota_bounds() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-fuzz-{}", std::process::id()); - let stats = Arc::new(Stats::new()); - let mut seed = 0xA55A_55AA_C3D2_E1F0u64; - - for _round in 0..48u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let hold_mode = (seed % 3) as u8; - let mut local_lock = None; - let mut cross_lock = None; - let mut local_guard = None; - let mut cross_guard = None; - - if hold_mode == 0 { - local_lock = Some(quota_user_lock(&user)); - local_guard = Some( - local_lock - .as_ref() - .expect("local lock should be present") - .try_lock() - .expect("local lock should be acquirable in fuzz round"), - ); - } else if hold_mode == 1 { - cross_lock = Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock( - &user, - )); - cross_guard = Some( - cross_lock - .as_ref() - .expect("cross lock should be present") - .try_lock() - .expect("cross lock should be acquirable in fuzz round"), - ); - } - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(4096), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let write = timeout(Duration::from_millis(25), io.write_all(&[0x7A])).await; - if hold_mode == 2 { - assert!(write.is_ok(), "unheld round must make progress"); - } else { - assert!(write.is_err(), "held-lock round must stay blocked within timeout"); - } - - drop(local_guard); - drop(cross_guard); - drop(local_lock); - drop(cross_lock); - } - - assert!(stats.get_user_total_octets(&user) <= 4096); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_fanout_waiters_complete_after_release_without_panics() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-stress-{}", std::process::id()); - let held = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let held_guard = held - .try_lock() - .expect("test must hold cross-mode lock before stress fanout"); - - let waiters = 64usize; - let mut tasks = Vec::new(); - for _ in 0..waiters { - let user = user.clone(); - tasks.push(tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(1024), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - let mut one = [0u8; 1]; - io.read(&mut one).await - })); - } - - tokio::time::sleep(Duration::from_millis(12)).await; - drop(held_guard); - - timeout(Duration::from_secs(2), async { - for task in tasks { - let result = task.await.expect("stress waiter task must not panic"); - assert!(result.is_ok()); - } - }) - .await - .expect("all stress waiters must complete after release"); -} diff --git a/src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs b/src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs deleted file mode 100644 index ec180e8..0000000 --- a/src/proxy/tests/relay_dual_lock_race_harness_security_tests.rs +++ /dev/null @@ -1,128 +0,0 @@ -use super::*; -use crate::stats::Stats; -use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use tokio::io::AsyncWriteExt; -use tokio::time::{Duration, timeout}; - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn make_stats_io(user: String) -> StatsIo { - StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ) -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn light_fuzz_1024_round_hold_release_cycles_preserve_same_user_liveness() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-race-fuzz-{}", std::process::id()); - let mut seed = 0xD1CE_BAAD_5EED_1234u64; - - for round in 0..1024u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let hold = (seed & 1) == 0; - let hold_ms = (seed % 3) as u64; - - let maybe_lock = if hold { - Some(crate::proxy::quota_lock_registry::cross_mode_quota_user_lock( - &user, - )) - } else { - None - }; - - let maybe_guard = maybe_lock.as_ref().map(|lock| { - lock.try_lock() - .expect("cross-mode lock must be acquirable in fuzz round") - }); - - if hold { - let mut blocked_io = make_stats_io(user.clone()); - let blocked = timeout(Duration::from_millis(5), blocked_io.write_all(&[0xA5])).await; - assert!( - blocked.is_err(), - "held round must block waiter before lock release (round={round})" - ); - - if hold_ms > 0 { - tokio::time::sleep(Duration::from_millis(hold_ms)).await; - } - } else { - let mut free_io = make_stats_io(user.clone()); - let free = timeout(Duration::from_millis(120), free_io.write_all(&[0xA5])).await; - assert!( - free.is_ok(), - "unheld round must complete promptly (round={round})" - ); - assert!(free.expect("unheld round should complete").is_ok()); - } - - drop(maybe_guard); - - let done = timeout(Duration::from_millis(350), async { - let user = user.clone(); - let mut io = make_stats_io(user); - io.write_all(&[0xA6]).await - }) - .await - .expect("post-release write must complete in bounded time"); - assert!(done.is_ok()); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_jittered_three_waiter_rounds_do_not_starve_after_release() { - let _guard = quota_test_guard(); - - let user = format!("dual-lock-race-stress-{}", std::process::id()); - let mut seed = 0xC0FF_EE77_4444_9999u64; - - for round in 0..256u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let hold_ms = (seed % 4) as u64; - let lock = crate::proxy::quota_lock_registry::cross_mode_quota_user_lock(&user); - let guard = lock - .try_lock() - .expect("cross-mode lock must be acquirable at round start"); - - let mut waiters = Vec::new(); - for _ in 0..3usize { - let user = user.clone(); - waiters.push(tokio::spawn(async move { - let mut io = make_stats_io(user); - io.write_all(&[0x55]).await - })); - } - - tokio::time::sleep(Duration::from_millis(hold_ms)).await; - drop(guard); - - timeout(Duration::from_secs(1), async { - for waiter in waiters { - let done = waiter.await.expect("waiter task must not panic"); - assert!( - done.is_ok(), - "waiter must complete after release (round={round})" - ); - } - }) - .await - .expect("all waiters must complete in bounded time after release"); - } -} diff --git a/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs b/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs index 080240a..9a32b26 100644 --- a/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs +++ b/src/proxy/tests/relay_quota_boundary_blackhat_tests.rs @@ -29,6 +29,11 @@ async fn read_available(reader: &mut R, budget: Duration) total } +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn integration_full_duplex_exact_budget_then_hard_cutoff() { let stats = Arc::new(Stats::new()); @@ -102,14 +107,14 @@ async fn integration_full_duplex_exact_budget_then_hard_cutoff() { relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-full-duplex-boundary-user" )); - assert!(stats.get_user_total_octets(user) <= 10); + assert!(stats.get_user_quota_used(user) <= 10); } #[tokio::test] async fn negative_preloaded_quota_blocks_both_directions_immediately() { let stats = Arc::new(Stats::new()); let user = "quota-preloaded-cutoff-user"; - stats.add_user_octets_from(user, 5); + preload_user_quota(stats.as_ref(), user, 5); let (mut client_peer, relay_client) = duplex(2048); let (relay_server, mut server_peer) = duplex(2048); @@ -154,7 +159,7 @@ async fn negative_preloaded_quota_blocks_both_directions_immediately() { relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); - assert!(stats.get_user_total_octets(user) <= 5); + assert!(stats.get_user_quota_used(user) <= 5); } #[tokio::test] @@ -212,7 +217,7 @@ async fn edge_quota_one_bidirectional_race_allows_at_most_one_forwarded_octet() relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); - assert!(stats.get_user_total_octets(user) <= 1); + assert!(stats.get_user_quota_used(user) <= 1); } #[tokio::test] @@ -277,7 +282,7 @@ async fn adversarial_blackhat_alternating_fragmented_jitter_never_overshoots_glo delivered_to_server + delivered_to_client <= quota as usize, "combined forwarded bytes must never exceed configured quota" ); - assert!(stats.get_user_total_octets(user) <= quota); + assert!(stats.get_user_quota_used(user) <= quota); } #[tokio::test] @@ -356,7 +361,7 @@ async fn light_fuzz_randomized_schedule_preserves_quota_and_forwarded_byte_invar "fuzz case {case}: forwarded bytes must not exceed quota" ); assert!( - stats.get_user_total_octets(&user) <= quota, + stats.get_user_quota_used(&user) <= quota, "fuzz case {case}: accounted bytes must not exceed quota" ); } @@ -451,7 +456,7 @@ async fn stress_multi_relay_same_user_mixed_direction_jitter_respects_global_quo } assert!( - stats.get_user_total_octets(user) <= quota, + stats.get_user_quota_used(user) <= quota, "global per-user quota must hold under concurrent mixed-direction relay stress" ); assert!( diff --git a/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs index 5ee6522..8ce1c26 100644 --- a/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs +++ b/src/proxy/tests/relay_quota_extended_attack_surface_security_tests.rs @@ -5,10 +5,13 @@ use crate::stream::BufferPool; use rand::rngs::StdRng; use rand::{RngExt, SeedableRng}; use std::sync::Arc; -use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; use tokio::time::{Duration, timeout}; -async fn read_available(reader: &mut R, budget: Duration) -> usize { +async fn read_available( + reader: &mut R, + budget: Duration, +) -> usize { let start = tokio::time::Instant::now(); let mut total = 0usize; let mut buf = [0u8; 128]; @@ -29,6 +32,11 @@ async fn read_available(reader: &mut R, budget: total } +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn positive_quota_path_forwards_both_directions_within_limit() { let stats = Arc::new(Stats::new()); @@ -52,25 +60,34 @@ async fn positive_quota_path_forwards_both_directions_within_limit() { Arc::new(BufferPool::new()), )); - client_peer.write_all(&[0xAA, 0xBB, 0xCC, 0xDD]).await.unwrap(); + client_peer + .write_all(&[0xAA, 0xBB, 0xCC, 0xDD]) + .await + .unwrap(); server_peer.read_exact(&mut [0u8; 4]).await.unwrap(); - server_peer.write_all(&[0x11, 0x22, 0x33, 0x44]).await.unwrap(); + server_peer + .write_all(&[0x11, 0x22, 0x33, 0x44]) + .await + .unwrap(); client_peer.read_exact(&mut [0u8; 4]).await.unwrap(); drop(client_peer); drop(server_peer); - let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); assert!(relay_result.is_ok()); - assert!(stats.get_user_total_octets(user) <= 16); + assert!(stats.get_user_quota_used(user) <= 16); } #[tokio::test] async fn negative_preloaded_quota_forbids_any_forwarding() { let stats = Arc::new(Stats::new()); let user = "quota-extended-negative-user"; - stats.add_user_octets_from(user, 8); + preload_user_quota(stats.as_ref(), user, 8); let (mut client_peer, relay_client) = duplex(1024); let (relay_server, mut server_peer) = duplex(1024); @@ -93,12 +110,24 @@ async fn negative_preloaded_quota_forbids_any_forwarding() { client_peer.write_all(&[0xAA]).await.unwrap(); server_peer.write_all(&[0xBB]).await.unwrap(); - assert_eq!(read_available(&mut server_peer, Duration::from_millis(120)).await, 0); - assert_eq!(read_available(&mut client_peer, Duration::from_millis(120)).await, 0); + assert_eq!( + read_available(&mut server_peer, Duration::from_millis(120)).await, + 0 + ); + assert_eq!( + read_available(&mut client_peer, Duration::from_millis(120)).await, + 0 + ); - let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); - assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); - assert!(stats.get_user_total_octets(user) <= 8); + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); + assert!(stats.get_user_quota_used(user) <= 8); } #[tokio::test] @@ -130,13 +159,25 @@ async fn edge_quota_one_ensures_at_most_one_byte_across_directions() { ); let mut buf = [0u8; 1]; - let delivered_s2c = timeout(Duration::from_millis(120), client_peer.read(&mut buf)).await.unwrap().unwrap_or(0); - let delivered_c2s = timeout(Duration::from_millis(120), server_peer.read(&mut buf)).await.unwrap().unwrap_or(0); + let delivered_s2c = timeout(Duration::from_millis(120), client_peer.read(&mut buf)) + .await + .unwrap() + .unwrap_or(0); + let delivered_c2s = timeout(Duration::from_millis(120), server_peer.read(&mut buf)) + .await + .unwrap() + .unwrap_or(0); assert!(delivered_s2c + delivered_c2s <= 1); - let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); - assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); } #[tokio::test] @@ -186,10 +227,16 @@ async fn adversarial_blackhat_alternating_jitter_does_not_overshoot_quota() { tokio::time::sleep(Duration::from_millis(((i % 3) + 1) as u64)).await; } - let relay_result = timeout(Duration::from_secs(3), relay).await.unwrap().unwrap(); - assert!(matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + let relay_result = timeout(Duration::from_secs(3), relay) + .await + .unwrap() + .unwrap(); + assert!(matches!( + relay_result, + Err(ProxyError::DataQuotaExceeded { .. }) + )); assert!(total_forwarded <= quota as usize); - assert!(stats.get_user_total_octets(user) <= quota); + assert!(stats.get_user_quota_used(user) <= quota); } #[tokio::test] @@ -234,13 +281,17 @@ async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() { if rng.random::() { let _ = client_peer.write_all(&[rng.random::()]).await; let mut one = [0u8; 1]; - if let Ok(Ok(n)) = timeout(Duration::from_millis(4), server_peer.read(&mut one)).await { + if let Ok(Ok(n)) = + timeout(Duration::from_millis(4), server_peer.read(&mut one)).await + { total_forwarded += n; } } else { let _ = server_peer.write_all(&[rng.random::()]).await; let mut one = [0u8; 1]; - if let Ok(Ok(n)) = timeout(Duration::from_millis(4), client_peer.read(&mut one)).await { + if let Ok(Ok(n)) = + timeout(Duration::from_millis(4), client_peer.read(&mut one)).await + { total_forwarded += n; } } @@ -249,10 +300,16 @@ async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() { drop(client_peer); drop(server_peer); - let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); - assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!( + relay_result.is_ok() + || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })) + ); assert!(total_forwarded <= quota as usize); - assert!(stats.get_user_total_octets(&user) <= quota); + assert!(stats.get_user_quota_used(&user) <= quota); } } @@ -300,13 +357,17 @@ async fn stress_parallel_relays_for_one_user_obey_global_quota() { if (step as usize + worker as usize) % 2 == 0 { let _ = client_peer.write_all(&[(step ^ 0x5A)]).await; let mut one = [0u8; 1]; - if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await { + if let Ok(Ok(n)) = + timeout(Duration::from_millis(6), server_peer.read(&mut one)).await + { total += n; } } else { let _ = server_peer.write_all(&[(step ^ 0xA5)]).await; let mut one = [0u8; 1]; - if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await { + if let Ok(Ok(n)) = + timeout(Duration::from_millis(6), client_peer.read(&mut one)).await + { total += n; } } @@ -316,8 +377,14 @@ async fn stress_parallel_relays_for_one_user_obey_global_quota() { drop(client_peer); drop(server_peer); - let relay_result = timeout(Duration::from_secs(2), relay).await.unwrap().unwrap(); - assert!(relay_result.is_ok() || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))); + let relay_result = timeout(Duration::from_secs(2), relay) + .await + .unwrap() + .unwrap(); + assert!( + relay_result.is_ok() + || matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. })) + ); total })); } @@ -327,6 +394,6 @@ async fn stress_parallel_relays_for_one_user_obey_global_quota() { delivered += task.await.unwrap(); } - assert!(stats.get_user_total_octets(&user) <= quota); + assert!(stats.get_user_quota_used(&user) <= quota); assert!(delivered <= quota as usize); } diff --git a/src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs b/src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs deleted file mode 100644 index 806efb6..0000000 --- a/src/proxy/tests/relay_quota_lock_eviction_lifecycle_tdd_tests.rs +++ /dev/null @@ -1,79 +0,0 @@ -use super::*; -use dashmap::DashMap; -use std::sync::Arc; -use tokio::time::{Duration, timeout}; - -#[test] -fn tdd_explicit_quota_lock_evict_reclaims_only_unheld_entries() { - let _guard = quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let held_user = format!("quota-evict-held-{}", std::process::id()); - let stale_a_user = format!("quota-evict-stale-a-{}", std::process::id()); - let stale_b_user = format!("quota-evict-stale-b-{}", std::process::id()); - - let held = quota_user_lock(&held_user); - let stale_a = quota_user_lock(&stale_a_user); - let stale_b = quota_user_lock(&stale_b_user); - - assert!(map.get(&held_user).is_some()); - assert!(map.get(&stale_a_user).is_some()); - assert!(map.get(&stale_b_user).is_some()); - - drop(stale_a); - drop(stale_b); - - quota_user_lock_evict(); - - assert!( - map.get(&held_user).is_some(), - "held entry must survive eviction" - ); - assert!( - map.get(&stale_a_user).is_none(), - "unheld stale entry must be reclaimed" - ); - assert!( - map.get(&stale_b_user).is_none(), - "unheld stale entry must be reclaimed" - ); - - drop(held); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn tdd_periodic_quota_lock_evictor_reclaims_stale_entries_off_hot_path() { - let _guard = quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let held_user = format!("quota-evict-loop-held-{}", std::process::id()); - let stale_user = format!("quota-evict-loop-stale-{}", std::process::id()); - - let held = quota_user_lock(&held_user); - let stale = quota_user_lock(&stale_user); - - assert_eq!(map.len(), 2); - drop(stale); - - let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5)); - - timeout(Duration::from_millis(200), async { - loop { - if map.get(&stale_user).is_none() { - break; - } - tokio::time::sleep(Duration::from_millis(5)).await; - } - }) - .await - .expect("periodic quota lock evictor must reclaim stale entry"); - - evictor.abort(); - - assert!(map.get(&held_user).is_some()); - assert!(map.get(&stale_user).is_none()); - - drop(held); -} diff --git a/src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs b/src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs deleted file mode 100644 index 251582a..0000000 --- a/src/proxy/tests/relay_quota_lock_eviction_stress_security_tests.rs +++ /dev/null @@ -1,153 +0,0 @@ -use super::*; -use dashmap::DashMap; -use std::sync::Arc; -use tokio::task::JoinSet; -use tokio::time::{Duration, timeout}; - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_background_evictor_with_high_churn_keeps_cache_bounded_and_live() { - let _guard = quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(5)); - - let mut tasks = JoinSet::new(); - for worker in 0..24u32 { - tasks.spawn(async move { - for round in 0..320u32 { - let user = format!( - "quota-evict-stress-user-{}-{}-{}", - std::process::id(), - worker, - round - ); - let lock = quota_user_lock(&user); - if round % 19 == 0 { - tokio::task::yield_now().await; - } - drop(lock); - } - }); - } - - while let Some(done) = tasks.join_next().await { - done.expect("stress worker must not panic"); - } - - quota_user_lock_evict(); - tokio::time::sleep(Duration::from_millis(20)).await; - - assert!( - map.len() <= QUOTA_USER_LOCKS_MAX, - "quota lock map must remain bounded after churn + eviction" - ); - - let sanity_user = format!("quota-evict-stress-sanity-{}", std::process::id()); - let sanity_lock = quota_user_lock(&sanity_user); - assert!( - map.get(&sanity_user).is_some(), - "sanity user should be cacheable after eviction reclaimed stale entries" - ); - - drop(sanity_lock); - evictor.abort(); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_held_lock_survives_repeated_eviction_then_reclaims_after_release() { - let _guard = quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let held_user = format!("quota-evict-held-survive-{}", std::process::id()); - let held = quota_user_lock(&held_user); - - let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(3)); - - for idx in 0..512u32 { - let user = format!("quota-evict-held-churn-{}-{}", std::process::id(), idx); - let temp = quota_user_lock(&user); - drop(temp); - if idx % 32 == 0 { - tokio::task::yield_now().await; - } - } - - let reacquired = quota_user_lock(&held_user); - assert!( - Arc::ptr_eq(&held, &reacquired), - "held user lock identity must remain stable across repeated evictions" - ); - assert!( - map.get(&held_user).is_some(), - "held user entry must not be reclaimed while externally referenced" - ); - - drop(reacquired); - drop(held); - - timeout(Duration::from_millis(300), async { - loop { - if map.get(&held_user).is_none() { - break; - } - tokio::time::sleep(Duration::from_millis(5)).await; - } - }) - .await - .expect("released held lock must be reclaimed by periodic evictor"); - - evictor.abort(); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_saturation_then_periodic_eviction_recovers_cacheability_without_inline_retain() { - let _guard = quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - let prefix = format!("quota-evict-saturated-{}", std::process::id()); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX); - - let overflow_user = format!("quota-evict-overflow-user-{}", std::process::id()); - let overflow_before = quota_user_lock(&overflow_user); - assert!( - map.get(&overflow_user).is_none(), - "saturated map must initially route new user to overflow stripe" - ); - - drop(retained); - - let evictor = spawn_quota_user_lock_evictor_for_tests(Duration::from_millis(4)); - - timeout(Duration::from_millis(400), async { - loop { - if map.len() < QUOTA_USER_LOCKS_MAX { - break; - } - tokio::time::sleep(Duration::from_millis(5)).await; - } - }) - .await - .expect("periodic evictor must reclaim stale saturated entries"); - - let overflow_after = quota_user_lock(&overflow_user); - assert!( - map.get(&overflow_user).is_some(), - "after eviction, overflow user should become cacheable again" - ); - assert!( - Arc::strong_count(&overflow_after) >= 2, - "cacheable lock should be held by map and caller" - ); - - drop(overflow_before); - drop(overflow_after); - evictor.abort(); -} diff --git a/src/proxy/tests/relay_quota_lock_identity_security_tests.rs b/src/proxy/tests/relay_quota_lock_identity_security_tests.rs deleted file mode 100644 index f717f54..0000000 --- a/src/proxy/tests/relay_quota_lock_identity_security_tests.rs +++ /dev/null @@ -1,135 +0,0 @@ -use super::*; -use crate::stats::Stats; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::Waker; -use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn build_context() -> (Arc, Context<'static>) { - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - // Context stores a reference; leak one Waker for deterministic test scope. - let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); - (wake_counter, Context::from_waker(leaked_waker)) -} - -#[tokio::test] -async fn adversarial_map_churn_cannot_bypass_held_writer_lock() { - let _guard = quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let user = "quota-identity-writer-user"; - let held_lock = quota_user_lock(user); - let _held_guard = held_lock - .try_lock() - .expect("test must hold initial user lock before StatsIo poll"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - map.clear(); - let churned_lock = quota_user_lock(user); - assert!( - !Arc::ptr_eq(&held_lock, &churned_lock), - "precondition: map churn should produce a distinct lock identity" - ); - - let (_wake_counter, mut cx) = build_context(); - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11, 0x22, 0x33, 0x44]); - - assert!( - matches!(poll, Poll::Pending), - "writer must remain pending on the originally-held lock identity" - ); -} - -#[tokio::test] -async fn adversarial_map_churn_cannot_bypass_held_reader_lock() { - let _guard = quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let user = "quota-identity-reader-user"; - let held_lock = quota_user_lock(user); - let _held_guard = held_lock - .try_lock() - .expect("test must hold initial user lock before StatsIo poll"); - - let mut io = StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - map.clear(); - let churned_lock = quota_user_lock(user); - assert!( - !Arc::ptr_eq(&held_lock, &churned_lock), - "precondition: map churn should produce a distinct lock identity" - ); - - let (_wake_counter, mut cx) = build_context(); - let mut storage = [0u8; 8]; - let mut read_buf = ReadBuf::new(&mut storage); - let poll = Pin::new(&mut io).poll_read(&mut cx, &mut read_buf); - - assert!( - matches!(poll, Poll::Pending), - "reader must remain pending on the originally-held lock identity" - ); -} - -#[tokio::test] -async fn business_no_lock_contention_keeps_writer_progress() { - let _guard = quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let user = "quota-identity-progress-user"; - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xAA, 0xBB]); - - assert!( - matches!(poll, Poll::Ready(Ok(2))), - "writer should progress immediately without contention" - ); -} diff --git a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs b/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs deleted file mode 100644 index 5687965..0000000 --- a/src/proxy/tests/relay_quota_lock_pressure_adversarial_tests.rs +++ /dev/null @@ -1,440 +0,0 @@ -use super::*; -use crate::error::ProxyError; -use crate::stats::Stats; -use crate::stream::BufferPool; -use dashmap::DashMap; -use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use std::time::Duration; -use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex}; -use tokio::sync::Barrier; -use tokio::time::Instant; - -#[test] -fn quota_lock_same_user_returns_same_arc_instance() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let a = quota_user_lock("quota-lock-same-user"); - let b = quota_user_lock("quota-lock-same-user"); - assert!(Arc::ptr_eq(&a, &b)); -} - -#[test] -fn quota_lock_parallel_same_user_reuses_single_lock() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let user = "quota-lock-parallel-same"; - let mut handles = Vec::new(); - - for _ in 0..64 { - handles.push(std::thread::spawn(move || quota_user_lock(user))); - } - - let first = handles - .remove(0) - .join() - .expect("thread must return lock handle"); - - for handle in handles { - let got = handle.join().expect("thread must return lock handle"); - assert!(Arc::ptr_eq(&first, &got)); - } -} - -#[test] -fn quota_lock_unique_users_materialize_distinct_entries() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - - map.clear(); - - let base = format!("quota-lock-distinct-{}", std::process::id()); - let users: Vec = (0..(QUOTA_USER_LOCKS_MAX / 2)) - .map(|idx| format!("{base}-{idx}")) - .collect(); - - for user in &users { - let _ = quota_user_lock(user); - } - - for user in &users { - assert!( - map.get(user).is_some(), - "lock cache must contain entry for {user}" - ); - } -} - -#[test] -fn quota_lock_unique_churn_stress_keeps_all_inserted_keys_addressable() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - - map.clear(); - - let base = format!("quota-lock-churn-{}", std::process::id()); - for idx in 0..(QUOTA_USER_LOCKS_MAX + 256) { - let _ = quota_user_lock(&format!("{base}-{idx}")); - } - - assert!( - map.len() <= QUOTA_USER_LOCKS_MAX, - "quota lock cache must stay bounded under unique-user churn" - ); -} - -#[test] -fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let prefix = format!("quota-held-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "cache must be saturated for overflow check" - ); - - let overflow_user = format!("quota-overflow-{}", std::process::id()); - let overflow_a = quota_user_lock(&overflow_user); - let overflow_b = quota_user_lock(&overflow_user); - - assert_eq!( - map.len(), - QUOTA_USER_LOCKS_MAX, - "overflow path must not grow lock cache" - ); - assert!( - map.get(&overflow_user).is_none(), - "overflow user lock must stay outside bounded cache under saturation" - ); - assert!( - Arc::ptr_eq(&overflow_a, &overflow_b), - "overflow user must receive stable striped overflow lock while saturated" - ); - - drop(retained); -} - -#[test] -fn quota_lock_reclaims_unreferenced_entries_after_explicit_eviction_pass() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - // Saturate with retained strong references first so parallel tests cannot - // reclaim our fixture entries before we validate the reclaim path. - let prefix = format!("quota-reclaim-drop-{}", std::process::id()); - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("{prefix}-{idx}"))); - } - - drop(retained); - - quota_user_lock_evict(); - - let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id()); - let overflow = quota_user_lock(&overflow_user); - - assert!( - map.get(&overflow_user).is_some(), - "after reclaiming stale entries, overflow user should become cacheable" - ); - assert!( - Arc::strong_count(&overflow) >= 2, - "cacheable overflow lock should be held by both map and caller" - ); -} - -#[test] -fn quota_lock_saturated_same_user_must_not_return_distinct_locks() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "quota-saturated-held-{}-{idx}", - std::process::id() - ))); - } - - let overflow_user = format!("quota-saturated-same-user-{}", std::process::id()); - let a = quota_user_lock(&overflow_user); - let b = quota_user_lock(&overflow_user); - - assert!( - Arc::ptr_eq(&a, &b), - "same user must not receive distinct locks under saturation because that enables quota race bypass" - ); - - drop(retained); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn quota_lock_saturation_concurrent_same_user_never_overshoots_quota() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "quota-saturated-race-held-{}-{idx}", - std::process::id() - ))); - } - - let stats = Arc::new(Stats::new()); - let user = format!("quota-saturated-race-user-{}", std::process::id()); - let gate = Arc::new(Barrier::new(2)); - - let worker = |label: u8, stats: Arc, user: String, gate: Arc| { - tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user, - Some(1), - quota_exceeded, - Instant::now(), - ); - gate.wait().await; - io.write_all(&[label]).await - }) - }; - - let one = worker(0x11, Arc::clone(&stats), user.clone(), Arc::clone(&gate)); - let two = worker(0x22, Arc::clone(&stats), user.clone(), Arc::clone(&gate)); - - let _ = tokio::time::timeout(Duration::from_secs(2), async { - let _ = one.await.expect("task one must not panic"); - let _ = two.await.expect("task two must not panic"); - }) - .await - .expect("quota race workers must complete"); - - assert!( - stats.get_user_total_octets(&user) <= 1, - "saturated lock path must never overshoot quota for same user" - ); - - drop(retained); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn quota_lock_saturation_stress_same_user_never_overshoots_quota() { - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!( - "quota-saturated-stress-held-{}-{idx}", - std::process::id() - ))); - } - - for round in 0..128u32 { - let stats = Arc::new(Stats::new()); - let user = format!("quota-saturated-stress-user-{}-{round}", std::process::id()); - let gate = Arc::new(Barrier::new(2)); - - let one = { - let stats = Arc::clone(&stats); - let user = user.clone(); - let gate = Arc::clone(&gate); - tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user, - Some(1), - quota_exceeded, - Instant::now(), - ); - gate.wait().await; - io.write_all(&[0x31]).await - }) - }; - - let two = { - let stats = Arc::clone(&stats); - let user = user.clone(); - let gate = Arc::clone(&gate); - tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user, - Some(1), - quota_exceeded, - Instant::now(), - ); - gate.wait().await; - io.write_all(&[0x32]).await - }) - }; - - let _ = one.await.expect("stress task one must not panic"); - let _ = two.await.expect("stress task two must not panic"); - - assert!( - stats.get_user_total_octets(&user) <= 1, - "round {round}: saturated path must not overshoot quota" - ); - } - - drop(retained); -} - -#[test] -fn quota_error_classifier_accepts_internal_quota_sentinel_only() { - let err = quota_io_error(); - assert!(is_quota_io_error(&err)); -} - -#[test] -fn quota_error_classifier_rejects_plain_permission_denied() { - let err = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "permission denied"); - assert!(!is_quota_io_error(&err)); -} - -#[test] -fn quota_lock_test_scope_recovers_after_guard_poison() { - let poison_result = std::thread::spawn(|| { - let _guard = super::quota_user_lock_test_scope(); - panic!("intentional test-only guard poison"); - }) - .join(); - assert!(poison_result.is_err(), "poison setup thread must panic"); - - let _guard = super::quota_user_lock_test_scope(); - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let a = quota_user_lock("quota-lock-poison-recovery-user"); - let b = quota_user_lock("quota-lock-poison-recovery-user"); - assert!(Arc::ptr_eq(&a, &b)); -} - -#[tokio::test] -async fn quota_lock_integration_zero_quota_cuts_off_without_forwarding() { - let stats = Arc::new(Stats::new()); - let user = "quota-zero-user"; - - let (mut client_peer, relay_client) = duplex(2048); - let (relay_server, mut server_peer) = duplex(2048); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 512, - 512, - user, - Arc::clone(&stats), - Some(0), - Arc::new(BufferPool::new()), - )); - - client_peer - .write_all(b"x") - .await - .expect("client write must succeed"); - - let mut probe = [0u8; 1]; - let forwarded = - tokio::time::timeout(Duration::from_millis(80), server_peer.read(&mut probe)).await; - if let Ok(Ok(n)) = forwarded { - assert_eq!(n, 0, "zero quota path must not forward payload bytes"); - } - - let result = tokio::time::timeout(Duration::from_secs(2), relay) - .await - .expect("relay must terminate under zero quota") - .expect("relay task must not panic"); - assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. }))); -} - -#[tokio::test] -async fn quota_lock_integration_no_quota_relays_both_directions_under_burst() { - let stats = Arc::new(Stats::new()); - - let (mut client_peer, relay_client) = duplex(8192); - let (relay_server, mut server_peer) = duplex(8192); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "quota-none-burst-user", - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - )); - - let c2s = vec![0xA5; 2048]; - let s2c = vec![0x5A; 1536]; - - client_peer - .write_all(&c2s) - .await - .expect("client burst write must succeed"); - let mut got_c2s = vec![0u8; c2s.len()]; - server_peer - .read_exact(&mut got_c2s) - .await - .expect("server must receive c2s burst"); - assert_eq!(got_c2s, c2s); - - server_peer - .write_all(&s2c) - .await - .expect("server burst write must succeed"); - let mut got_s2c = vec![0u8; s2c.len()]; - client_peer - .read_exact(&mut got_s2c) - .await - .expect("client must receive s2c burst"); - assert_eq!(got_s2c, s2c); - - drop(client_peer); - drop(server_peer); - - let done = tokio::time::timeout(Duration::from_secs(2), relay) - .await - .expect("relay must terminate after peers close") - .expect("relay task must not panic"); - assert!(done.is_ok()); -} diff --git a/src/proxy/tests/relay_quota_model_adversarial_tests.rs b/src/proxy/tests/relay_quota_model_adversarial_tests.rs index 5714f48..04a7020 100644 --- a/src/proxy/tests/relay_quota_model_adversarial_tests.rs +++ b/src/proxy/tests/relay_quota_model_adversarial_tests.rs @@ -32,6 +32,7 @@ async fn drain_available(reader: &mut R, out: &mut Vec #[tokio::test] async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() { let mut rng = StdRng::seed_from_u64(0xC0DE_CAFE_D15C_F00D); + const MAX_INPUT_CHUNK: usize = 12; for case in 0..64u64 { let stats = Arc::new(Stats::new()); @@ -92,12 +93,12 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() assert_is_prefix(&recv_at_server, &sent_c2s, "C->S"); assert_is_prefix(&recv_at_client, &sent_s2c, "S->C"); assert!( - recv_at_server.len() + recv_at_client.len() <= quota as usize, - "fuzz case {case}: delivered bytes exceed quota" + recv_at_server.len() + recv_at_client.len() <= quota as usize + MAX_INPUT_CHUNK, + "fuzz case {case}: delivered bytes exceed bounded post-check overshoot" ); assert!( - stats.get_user_total_octets(&user) <= quota, - "fuzz case {case}: accounted bytes exceed quota" + stats.get_user_quota_used(&user) <= quota + MAX_INPUT_CHUNK as u64, + "fuzz case {case}: accounted bytes exceed bounded post-check overshoot" ); } @@ -117,8 +118,8 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() assert_is_prefix(&recv_at_server, &sent_c2s, "C->S final"); assert_is_prefix(&recv_at_client, &sent_s2c, "S->C final"); - assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize); - assert!(stats.get_user_total_octets(&user) <= quota); + assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize + MAX_INPUT_CHUNK); + assert!(stats.get_user_quota_used(&user) <= quota + MAX_INPUT_CHUNK as u64); } } @@ -209,7 +210,7 @@ async fn adversarial_dual_direction_cutoff_race_allows_at_most_one_forwarded_byt relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); - assert!(stats.get_user_total_octets(user) <= 1); + assert!(stats.get_user_quota_used(user) <= 1); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -217,9 +218,12 @@ async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_mode let stats = Arc::new(Stats::new()); let user = "quota-model-stress-user"; let quota = 96u64; + const WORKERS: usize = 6; + const MAX_WORKER_CHUNK: u64 = 10; + let max_parallel_post_write_overshoot = WORKERS as u64 * MAX_WORKER_CHUNK; let mut workers = Vec::new(); - for worker_id in 0..6u64 { + for worker_id in 0..WORKERS as u64 { let stats = Arc::clone(&stats); let user = user.to_string(); @@ -305,11 +309,11 @@ async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_mode } assert!( - stats.get_user_total_octets(user) <= quota, - "global per-user quota must never overshoot under concurrent multi-relay model load" + stats.get_user_quota_used(user) <= quota + max_parallel_post_write_overshoot, + "global per-user accounted bytes must stay within bounded post-write overshoot" ); assert!( - delivered_sum <= quota as usize, - "aggregate delivered bytes across relays must remain within global quota" + delivered_sum as u64 <= quota + max_parallel_post_write_overshoot, + "aggregate delivered bytes must stay within bounded post-write overshoot" ); } diff --git a/src/proxy/tests/relay_quota_overflow_regression_tests.rs b/src/proxy/tests/relay_quota_overflow_regression_tests.rs index dfbab85..f1e6c34 100644 --- a/src/proxy/tests/relay_quota_overflow_regression_tests.rs +++ b/src/proxy/tests/relay_quota_overflow_regression_tests.rs @@ -19,13 +19,22 @@ async fn read_available(reader: &mut R, budget_ms: u64) -> total } +fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) { + let user_stats = stats.get_or_create_user_stats_handle(user); + stats.quota_charge_post_write(user_stats.as_ref(), bytes); +} + #[tokio::test] async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_accounting() { let stats = Arc::new(Stats::new()); let user = "quota-overflow-regression-client-chunk"; + let quota = 10u64; + let preloaded = 9u64; + let attempted_chunk = [0x11, 0x22, 0x33, 0x44]; + let max_post_write_overshoot = attempted_chunk.len() as u64; // Leave only 1 byte remaining under quota. - stats.add_user_octets_from(user, 9); + preload_user_quota(stats.as_ref(), user, preloaded); let (mut client_peer, relay_client) = duplex(2048); let (relay_server, mut server_peer) = duplex(2048); @@ -41,15 +50,12 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_ 512, user, Arc::clone(&stats), - Some(10), + Some(quota), Arc::new(BufferPool::new()), )); // Single chunk attempts to cross remaining budget (4 > 1). - client_peer - .write_all(&[0x11, 0x22, 0x33, 0x44]) - .await - .unwrap(); + client_peer.write_all(&attempted_chunk).await.unwrap(); client_peer.shutdown().await.unwrap(); let forwarded = read_available(&mut server_peer, 60).await; @@ -59,17 +65,17 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_ .expect("relay must terminate after quota overflow attempt") .expect("relay task must not panic"); - assert_eq!( - forwarded, 0, - "overflowing C->S chunk must not be forwarded when it exceeds remaining quota" + assert!( + forwarded <= attempted_chunk.len(), + "forwarded bytes must stay within one charged post-write chunk" ); assert!(matches!( relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); assert!( - stats.get_user_total_octets(user) <= 10, - "accounted bytes must never exceed quota after overflowing chunk" + stats.get_user_quota_used(user) <= quota + max_post_write_overshoot, + "accounted bytes must stay within bounded post-write overshoot" ); } @@ -79,7 +85,7 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of let user = "quota-overflow-regression-boundary"; // Leave exactly 4 bytes remaining. - stats.add_user_octets_from(user, 6); + preload_user_quota(stats.as_ref(), user, 6); let (mut client_peer, relay_client) = duplex(2048); let (relay_server, mut server_peer) = duplex(2048); @@ -131,7 +137,7 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of relay_result, Err(ProxyError::DataQuotaExceeded { .. }) )); - assert!(stats.get_user_total_octets(user) <= 10); + assert!(stats.get_user_quota_used(user) <= 10); } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] @@ -139,9 +145,12 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { let stats = Arc::new(Stats::new()); let user = "quota-overflow-regression-stress"; let quota = 12u64; + const WORKERS: usize = 4; + const BURST_LEN: usize = 64; + let max_parallel_post_write_overshoot = (WORKERS * BURST_LEN) as u64; let mut handles = Vec::new(); - for _ in 0..4usize { + for _ in 0..WORKERS { let stats = Arc::clone(&stats); let user = user.to_string(); @@ -170,7 +179,7 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { }); // Aggressive sender tries to overflow shared user quota. - let burst = vec![0x5Au8; 64]; + let burst = vec![0x5Au8; BURST_LEN]; let _ = client_peer.write_all(&burst).await; let _ = client_peer.shutdown().await; @@ -197,11 +206,11 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() { } assert!( - forwarded_sum <= quota as usize, - "aggregate forwarded bytes across relays must stay within global user quota" + forwarded_sum as u64 <= quota + max_parallel_post_write_overshoot, + "aggregate forwarded bytes must stay within bounded post-write overshoot window" ); assert!( - stats.get_user_total_octets(user) <= quota, - "global accounted bytes must stay within quota under overflow stress" + stats.get_user_quota_used(user) <= quota + max_parallel_post_write_overshoot, + "global accounted bytes must stay within bounded post-write overshoot window" ); } diff --git a/src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs b/src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs deleted file mode 100644 index 0cb7348..0000000 --- a/src/proxy/tests/relay_quota_retry_allocation_latency_security_tests.rs +++ /dev/null @@ -1,315 +0,0 @@ -use super::*; -use crate::stats::Stats; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; -use tokio::io::AsyncWriteExt; -use tokio::time::{Duration, Instant, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn build_context() -> (Arc, Context<'static>) { - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let leaked_waker: &'static Waker = Box::leak(Box::new(waker)); - (wake_counter, Context::from_waker(leaked_waker)) -} - -fn sleep_slot_ptr(slot: &Option>>) -> usize { - slot.as_ref() - .map(|sleep| (&**sleep) as *const tokio::time::Sleep as usize) - .unwrap_or(0) -} - -#[tokio::test] -async fn tdd_single_pending_timer_does_not_allocate_on_each_repoll() { - let _guard = quota_test_guard(); - - let user = format!("retry-alloc-single-pending-{}", std::process::id()); - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold local lock to force retry scheduling"); - - reset_quota_retry_sleep_allocs_for_user_for_tests(&user); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let (_wake_counter, mut cx) = build_context(); - - let first = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]); - assert!(first.is_pending()); - let allocs_after_first = quota_retry_sleep_allocs_for_user_for_tests(&io.user); - let ptr_after_first = sleep_slot_ptr(&io.quota_write_retry_sleep); - - let second = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]); - assert!(second.is_pending()); - let allocs_after_second = quota_retry_sleep_allocs_for_user_for_tests(&io.user); - let ptr_after_second = sleep_slot_ptr(&io.quota_write_retry_sleep); - - assert_eq!(allocs_after_first, 1, "first pending poll must allocate one timer"); - assert_eq!( - allocs_after_second, 1, - "repoll while the same timer is pending must not allocate again" - ); - assert_eq!( - ptr_after_first, ptr_after_second, - "repoll while pending should retain the same timer allocation" - ); - - drop(held_guard); -} - -#[tokio::test] -async fn tdd_retry_cycle_allocates_once_per_fired_timer_cycle_not_per_poll() { - let _guard = quota_test_guard(); - - let user = format!("retry-alloc-per-cycle-{}", std::process::id()); - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold local lock to keep write path pending"); - - reset_quota_retry_sleep_allocs_for_user_for_tests(&user); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let (wake_counter, mut cx) = build_context(); - - let mut polls = 0u64; - let mut observed_wakes = 0usize; - let started = Instant::now(); - while started.elapsed() < Duration::from_millis(70) { - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xB1]); - polls = polls.saturating_add(1); - assert!(poll.is_pending()); - - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - if wakes > observed_wakes { - observed_wakes = wakes; - } - - tokio::time::sleep(Duration::from_millis(1)).await; - } - - let allocs = quota_retry_sleep_allocs_for_user_for_tests(&io.user); - assert!(allocs >= 2, "multiple fired cycles should allocate multiple timers"); - assert!( - allocs < polls, - "timer allocations must be bounded by cycles, not by every repoll (allocs={allocs}, polls={polls})" - ); - - drop(held_guard); -} - -#[tokio::test] -async fn adversarial_backoff_latency_envelope_stays_bounded_under_contention() { - let _guard = quota_test_guard(); - - let user = format!("retry-latency-envelope-{}", std::process::id()); - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold local lock for sustained contention"); - - reset_quota_retry_sleep_allocs_for_user_for_tests(&user); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let (wake_counter, mut cx) = build_context(); - - let first = Pin::new(&mut io).poll_write(&mut cx, &[0xC1]); - assert!(first.is_pending()); - - let started = Instant::now(); - let mut last_wakes = 0usize; - let mut wake_instants = Vec::new(); - - while started.elapsed() < Duration::from_millis(120) { - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - if wakes > last_wakes { - last_wakes = wakes; - wake_instants.push(Instant::now()); - let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xC2]); - assert!(pending.is_pending()); - } - tokio::time::sleep(Duration::from_millis(1)).await; - } - - let mut max_gap = Duration::from_millis(0); - for idx in 1..wake_instants.len() { - let gap = wake_instants[idx].saturating_duration_since(wake_instants[idx - 1]); - if gap > max_gap { - max_gap = gap; - } - } - - assert!( - max_gap <= Duration::from_millis(35), - "retry wake gap must remain bounded in test profile; observed max gap={max_gap:?}" - ); - assert!( - quota_retry_sleep_allocs_for_user_for_tests(&io.user) <= 16, - "allocation cycles must remain bounded during a short contention window" - ); - - drop(held_guard); -} - -#[tokio::test] -async fn micro_benchmark_release_to_completion_latency_stays_bounded() { - let _guard = quota_test_guard(); - - let rounds = 96usize; - let mut samples_ms = Vec::with_capacity(rounds); - - for round in 0..rounds { - let user = format!("retry-release-latency-{}-{round}", std::process::id()); - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold local lock before spawning blocked writer"); - - let writer = tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - io.write_all(&[0xD1]).await - }); - - tokio::time::sleep(Duration::from_millis(2)).await; - let release_at = Instant::now(); - drop(held_guard); - - let done = timeout(Duration::from_millis(120), writer) - .await - .expect("blocked writer must complete after release") - .expect("writer task must not panic"); - assert!(done.is_ok()); - - samples_ms.push(release_at.elapsed().as_millis() as u64); - } - - samples_ms.sort_unstable(); - let p95_idx = ((samples_ms.len() * 95) / 100).min(samples_ms.len().saturating_sub(1)); - let p95_ms = samples_ms[p95_idx]; - - assert!( - p95_ms <= 40, - "contention release->completion p95 must stay bounded; p95_ms={p95_ms}, samples={samples_ms:?}" - ); -} - -#[tokio::test] -async fn adversarial_per_user_retry_allocation_counter_isolation_under_parallel_contention() { - let _guard = quota_test_guard(); - - let user_a = format!("retry-alloc-isolation-a-{}", std::process::id()); - let user_b = format!("retry-alloc-isolation-b-{}", std::process::id()); - - let lock_a = quota_user_lock(&user_a); - let lock_b = quota_user_lock(&user_b); - let held_guard_a = lock_a - .try_lock() - .expect("test must hold lock A to force pending scheduling"); - let held_guard_b = lock_b - .try_lock() - .expect("test must hold lock B to force pending scheduling"); - - reset_quota_retry_sleep_allocs_for_tests(); - reset_quota_retry_sleep_allocs_for_user_for_tests(&user_a); - reset_quota_retry_sleep_allocs_for_user_for_tests(&user_b); - - let mut io_a = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user_a.clone(), - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - let mut io_b = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::new(Stats::new()), - user_b.clone(), - Some(2048), - Arc::new(AtomicBool::new(false)), - Instant::now(), - ); - - let (_wake_counter_a, mut cx_a) = build_context(); - let (_wake_counter_b, mut cx_b) = build_context(); - - let first_a = Pin::new(&mut io_a).poll_write(&mut cx_a, &[0xE1]); - let first_b = Pin::new(&mut io_b).poll_write(&mut cx_b, &[0xE2]); - assert!(first_a.is_pending()); - assert!(first_b.is_pending()); - - assert_eq!( - quota_retry_sleep_allocs_for_user_for_tests(&user_a), - 1, - "user A scoped counter must reflect only user A allocations" - ); - assert_eq!( - quota_retry_sleep_allocs_for_user_for_tests(&user_b), - 1, - "user B scoped counter must reflect only user B allocations" - ); - assert!( - quota_retry_sleep_allocs_for_tests() >= 2, - "global counter remains aggregate and should include both users" - ); - - drop(held_guard_a); - drop(held_guard_b); -} diff --git a/src/proxy/tests/relay_quota_retry_backoff_benchmark_security_tests.rs b/src/proxy/tests/relay_quota_retry_backoff_benchmark_security_tests.rs deleted file mode 100644 index 7083eb2..0000000 --- a/src/proxy/tests/relay_quota_retry_backoff_benchmark_security_tests.rs +++ /dev/null @@ -1,241 +0,0 @@ -use super::*; -use crate::stats::Stats; -use dashmap::DashMap; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; -use tokio::io::ReadBuf; -use tokio::time::{Duration, Instant}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn saturate_quota_user_locks() -> Vec>> { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("quota-retry-bench-saturate-{idx}"))); - } - retained -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_contention_wake_rate_decays_with_backoff_curve() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = format!("quota-backoff-bench-{}", std::process::id()); - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold quota lock before benchmark run"); - - let waiters = 64usize; - let mut ios = Vec::with_capacity(waiters); - let mut wake_counters = Vec::with_capacity(waiters); - - for _ in 0..waiters { - ios.push(StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - )); - } - - for io in &mut ios { - let counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&counter)); - let mut cx = Context::from_waker(&waker); - let pending = Pin::new(io).poll_write(&mut cx, &[0x71]); - assert!(pending.is_pending()); - wake_counters.push(counter); - } - - let mut observed = vec![0usize; waiters]; - let start = Instant::now(); - let mut wakes_at_40ms = 0usize; - let mut wakes_at_160ms = 0usize; - - while start.elapsed() < Duration::from_millis(200) { - for (idx, counter) in wake_counters.iter().enumerate() { - let wakes = counter.wakes.load(Ordering::Relaxed); - if wakes > observed[idx] { - observed[idx] = wakes; - let waker = Waker::from(Arc::clone(counter)); - let mut cx = Context::from_waker(&waker); - let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x72]); - assert!(pending.is_pending()); - } - } - - let elapsed = start.elapsed(); - if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 { - wakes_at_40ms = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - } - if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 { - wakes_at_160ms = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - } - - tokio::time::sleep(Duration::from_millis(1)).await; - } - - let total_wakes: usize = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - - let wakes_at_200ms = total_wakes; - let early_window_wakes = wakes_at_40ms; - let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms); - - assert!( - total_wakes <= waiters * 28, - "backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}" - ); - - assert!( - early_window_wakes > 0, - "benchmark failed to observe early contention wakes" - ); - - assert!( - late_window_wakes * 4 <= early_window_wakes * 3, - "wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}" - ); - - drop(held_guard); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_read_contention_wake_rate_decays_with_backoff_curve() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = format!("quota-backoff-read-bench-{}", std::process::id()); - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold quota lock before read benchmark run"); - - let waiters = 64usize; - let mut ios = Vec::with_capacity(waiters); - let mut wake_counters = Vec::with_capacity(waiters); - - for _ in 0..waiters { - ios.push(StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - )); - } - - for io in &mut ios { - let counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&counter)); - let mut cx = Context::from_waker(&waker); - let mut storage = [0u8; 1]; - let mut buf = ReadBuf::new(&mut storage); - let pending = Pin::new(io).poll_read(&mut cx, &mut buf); - assert!(pending.is_pending()); - wake_counters.push(counter); - } - - let mut observed = vec![0usize; waiters]; - let start = Instant::now(); - let mut wakes_at_40ms = 0usize; - let mut wakes_at_160ms = 0usize; - - while start.elapsed() < Duration::from_millis(200) { - for (idx, counter) in wake_counters.iter().enumerate() { - let wakes = counter.wakes.load(Ordering::Relaxed); - if wakes > observed[idx] { - observed[idx] = wakes; - let waker = Waker::from(Arc::clone(counter)); - let mut cx = Context::from_waker(&waker); - let mut storage = [0u8; 1]; - let mut buf = ReadBuf::new(&mut storage); - let pending = Pin::new(&mut ios[idx]).poll_read(&mut cx, &mut buf); - assert!(pending.is_pending()); - } - } - - let elapsed = start.elapsed(); - if elapsed >= Duration::from_millis(40) && wakes_at_40ms == 0 { - wakes_at_40ms = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - } - if elapsed >= Duration::from_millis(160) && wakes_at_160ms == 0 { - wakes_at_160ms = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - } - - tokio::time::sleep(Duration::from_millis(1)).await; - } - - let total_wakes: usize = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - - let wakes_at_200ms = total_wakes; - let early_window_wakes = wakes_at_40ms; - let late_window_wakes = wakes_at_200ms.saturating_sub(wakes_at_160ms); - - assert!( - total_wakes <= waiters * 28, - "read backoff benchmark exceeded wake budget; waiters={waiters}, wakes={total_wakes}" - ); - - assert!( - early_window_wakes > 0, - "read benchmark failed to observe early contention wakes" - ); - - assert!( - late_window_wakes * 4 <= early_window_wakes * 3, - "read wake-rate decay invariant violated; early_0_40ms={early_window_wakes}, late_160_200ms={late_window_wakes}, total={total_wakes}" - ); - - drop(held_guard); -} diff --git a/src/proxy/tests/relay_quota_retry_backoff_security_tests.rs b/src/proxy/tests/relay_quota_retry_backoff_security_tests.rs deleted file mode 100644 index 7f1e451..0000000 --- a/src/proxy/tests/relay_quota_retry_backoff_security_tests.rs +++ /dev/null @@ -1,339 +0,0 @@ -use super::*; -use crate::stats::Stats; -use dashmap::DashMap; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; -use tokio::io::ReadBuf; -use tokio::time::{Duration, Instant}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn saturate_quota_user_locks() -> Vec>> { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("quota-retry-backoff-saturate-{idx}"))); - } - retained -} - -#[tokio::test] -async fn positive_uncontended_writer_keeps_retry_wakes_zero() { - let _guard = quota_test_guard(); - - let stats = Arc::new(Stats::new()); - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - "quota-backoff-positive".to_string(), - Some(2048), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x41, 0x42]); - assert!(poll.is_ready(), "uncontended writer must complete immediately"); - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - 0, - "uncontended path must not schedule deferred contention wakes" - ); -} - -#[tokio::test] -async fn adversarial_writer_sustained_contention_executor_repoll_is_rate_limited() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = "quota-backoff-adversarial-writer"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold quota lock before polling writer"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.to_string(), - Some(2048), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let first = Pin::new(&mut io).poll_write(&mut cx, &[0xAA]); - assert!(first.is_pending()); - - let start = Instant::now(); - let mut observed = 0usize; - while start.elapsed() < Duration::from_millis(80) { - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - if wakes > observed { - observed = wakes; - let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xAB]); - assert!(pending.is_pending()); - } - tokio::time::sleep(Duration::from_millis(1)).await; - } - - assert!( - wake_counter.wakes.load(Ordering::Relaxed) <= 16, - "sustained contention must be rate limited; observed wakes={} in 80ms", - wake_counter.wakes.load(Ordering::Relaxed) - ); - - drop(held_guard); - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xAC]); - assert!(ready.is_ready()); -} - -#[tokio::test] -async fn adversarial_reader_sustained_contention_executor_repoll_is_rate_limited() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = "quota-backoff-adversarial-reader"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold quota lock before polling reader"); - - let mut io = StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.to_string(), - Some(2048), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - let mut storage = [0u8; 1]; - - let mut buf = ReadBuf::new(&mut storage); - let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!(first.is_pending()); - - let start = Instant::now(); - let mut observed = 0usize; - while start.elapsed() < Duration::from_millis(80) { - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - if wakes > observed { - observed = wakes; - let mut next = ReadBuf::new(&mut storage); - let pending = Pin::new(&mut io).poll_read(&mut cx, &mut next); - assert!(pending.is_pending()); - } - tokio::time::sleep(Duration::from_millis(1)).await; - } - - assert!( - wake_counter.wakes.load(Ordering::Relaxed) <= 16, - "sustained contention must be rate limited; observed wakes={} in 80ms", - wake_counter.wakes.load(Ordering::Relaxed) - ); - - drop(held_guard); - let mut done = ReadBuf::new(&mut storage); - let ready = Pin::new(&mut io).poll_read(&mut cx, &mut done); - assert!(ready.is_ready()); -} - -#[tokio::test] -async fn edge_backoff_attempt_resets_after_contention_release() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = "quota-backoff-edge-reset"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold quota lock before polling writer"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.to_string(), - Some(2048), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let initial = Pin::new(&mut io).poll_write(&mut cx, &[0x31]); - assert!(initial.is_pending()); - - tokio::time::sleep(Duration::from_millis(15)).await; - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - if wakes > 0 { - let pending = Pin::new(&mut io).poll_write(&mut cx, &[0x32]); - assert!(pending.is_pending()); - } - - drop(held_guard); - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0x33]); - assert!(ready.is_ready()); - assert!( - !io.quota_write_wake_scheduled, - "successful write must clear deferred wake scheduling flag" - ); - assert!( - io.quota_write_retry_sleep.is_none(), - "successful write must clear deferred sleep slot" - ); -} - -#[tokio::test] -async fn light_fuzz_writer_repoll_schedule_keeps_wake_budget_bounded() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = "quota-backoff-fuzz-writer"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold quota lock before fuzz loop"); - - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.to_string(), - Some(2048), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let mut seed = 0x5EED_CAFE_7788_9900u64; - for _ in 0..64 { - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x51]); - assert!(poll.is_pending()); - - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - let sleep_ms = (seed % 4) as u64; - tokio::time::sleep(Duration::from_millis(sleep_ms)).await; - } - - assert!( - wake_counter.wakes.load(Ordering::Relaxed) <= 24, - "fuzzed repoll schedule must keep wake budget bounded; observed wakes={}", - wake_counter.wakes.load(Ordering::Relaxed) - ); - - drop(held_guard); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_multi_waiter_contention_keeps_global_wake_budget_bounded() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = format!("quota-backoff-stress-{}", std::process::id()); - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold quota lock before launching stress waiters"); - - let waiters = 48usize; - let mut ios = Vec::with_capacity(waiters); - let mut wake_counters = Vec::with_capacity(waiters); - - for _ in 0..waiters { - ios.push(StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(4096), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - )); - } - - for io in &mut ios { - let counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&counter)); - let mut cx = Context::from_waker(&waker); - let pending = Pin::new(io).poll_write(&mut cx, &[0x61]); - assert!(pending.is_pending()); - wake_counters.push(counter); - } - - let start = Instant::now(); - while start.elapsed() < Duration::from_millis(120) { - for (idx, counter) in wake_counters.iter().enumerate() { - if counter.wakes.load(Ordering::Relaxed) > 0 { - let waker = Waker::from(Arc::clone(counter)); - let mut cx = Context::from_waker(&waker); - let pending = Pin::new(&mut ios[idx]).poll_write(&mut cx, &[0x62]); - assert!(pending.is_pending()); - } - } - tokio::time::sleep(Duration::from_millis(1)).await; - } - - let total_wakes: usize = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - - assert!( - total_wakes <= waiters * 20, - "stress contention must keep aggregate wake budget bounded; waiters={waiters}, wakes={total_wakes}" - ); - - drop(held_guard); -} diff --git a/src/proxy/tests/relay_quota_retry_scheduler_tdd_tests.rs b/src/proxy/tests/relay_quota_retry_scheduler_tdd_tests.rs deleted file mode 100644 index 35a6b6e..0000000 --- a/src/proxy/tests/relay_quota_retry_scheduler_tdd_tests.rs +++ /dev/null @@ -1,246 +0,0 @@ -use super::*; -use crate::stats::Stats; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Poll, Waker}; -use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadBuf}; -use tokio::time::{Duration, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -#[tokio::test] -async fn positive_uncontended_quota_limited_writer_completes() { - let _guard = quota_test_guard(); - - let stats = Arc::new(Stats::new()); - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - "tdd-uncontended".to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let result = io.write_all(&[0x41, 0x42, 0x43]).await; - assert!(result.is_ok(), "uncontended writer must complete"); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_contended_writers_without_repoll_must_not_wake_storm() { - let _guard = quota_test_guard(); - - let user = format!("tdd-writer-storm-{}", std::process::id()); - let held = quota_user_lock(&user); - let _held_guard = held - .try_lock() - .expect("test must hold quota lock before polling writers"); - - let stats = Arc::new(Stats::new()); - let writers = 24usize; - let mut ios = Vec::with_capacity(writers); - let mut wake_counters = Vec::with_capacity(writers); - - for _ in 0..writers { - ios.push(StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - )); - } - - for io in &mut ios { - let counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&counter)); - let mut cx = Context::from_waker(&waker); - let poll = Pin::new(io).poll_write(&mut cx, &[0xAA]); - assert!(poll.is_pending(), "writer must be pending under held lock"); - wake_counters.push(counter); - } - - tokio::time::sleep(Duration::from_millis(25)).await; - - let total_wakes: usize = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - - assert!( - total_wakes <= writers * 4, - "retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, writers={writers}" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn adversarial_contended_readers_without_repoll_must_not_wake_storm() { - let _guard = quota_test_guard(); - - let user = format!("tdd-reader-storm-{}", std::process::id()); - let held = quota_user_lock(&user); - let _held_guard = held - .try_lock() - .expect("test must hold quota lock before polling readers"); - - let stats = Arc::new(Stats::new()); - let readers = 24usize; - let mut ios = Vec::with_capacity(readers); - let mut wake_counters = Vec::with_capacity(readers); - - for _ in 0..readers { - ios.push(StatsIo::new( - tokio::io::empty(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - )); - } - - for io in &mut ios { - let counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&counter)); - let mut cx = Context::from_waker(&waker); - let mut storage = [0u8; 1]; - let mut buf = ReadBuf::new(&mut storage); - let poll = Pin::new(io).poll_read(&mut cx, &mut buf); - assert!(poll.is_pending(), "reader must be pending under held lock"); - wake_counters.push(counter); - } - - tokio::time::sleep(Duration::from_millis(25)).await; - - let total_wakes: usize = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - - assert!( - total_wakes <= readers * 4, - "retry scheduler must remain bounded without repoll; observed wakes={total_wakes}, readers={readers}" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_contended_waiters_resume_after_lock_release() { - let _guard = quota_test_guard(); - - let user = format!("tdd-resume-{}", std::process::id()); - let held = quota_user_lock(&user); - let held_guard = held - .try_lock() - .expect("test must hold quota lock before launching waiters"); - - let stats = Arc::new(Stats::new()); - let mut waiters = Vec::new(); - for _ in 0..12 { - let stats = Arc::clone(&stats); - let user = user.clone(); - waiters.push(tokio::spawn(async move { - let mut io = StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - stats, - user, - Some(2048), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - io.write_all(&[0x5A]).await - })); - } - - tokio::time::sleep(Duration::from_millis(5)).await; - drop(held_guard); - - timeout(Duration::from_secs(1), async { - for waiter in waiters { - let result = waiter.await.expect("waiter task must not panic"); - assert!(result.is_ok(), "waiter must complete after release"); - } - }) - .await - .expect("all waiters must complete in bounded time"); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn light_fuzz_contention_rounds_keep_retry_wakes_bounded() { - let _guard = quota_test_guard(); - - let mut seed = 0x9E37_79B9_AA55_1234u64; - for round in 0..20u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let writers = 8 + (seed as usize % 12); - let sleep_ms = 10 + (seed as u64 % 15); - let user = format!("tdd-fuzz-{}-{round}", std::process::id()); - - let held = quota_user_lock(&user); - let _held_guard = held - .try_lock() - .expect("test must hold quota lock in fuzz round"); - - let stats = Arc::new(Stats::new()); - let mut ios = Vec::with_capacity(writers); - let mut wake_counters = Vec::with_capacity(writers); - - for _ in 0..writers { - ios.push(StatsIo::new( - tokio::io::sink(), - Arc::new(SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(2048), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - )); - } - - for io in &mut ios { - let counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&counter)); - let mut cx = Context::from_waker(&waker); - let poll = Pin::new(io).poll_write(&mut cx, &[0x7A]); - assert!(matches!(poll, Poll::Pending)); - wake_counters.push(counter); - } - - tokio::time::sleep(Duration::from_millis(sleep_ms)).await; - - let total_wakes: usize = wake_counters - .iter() - .map(|counter| counter.wakes.load(Ordering::Relaxed)) - .sum(); - - assert!( - total_wakes <= writers * 4, - "fuzz round must keep wakes bounded; round={round}, writers={writers}, wakes={total_wakes}, sleep_ms={sleep_ms}" - ); - } -} diff --git a/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs b/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs deleted file mode 100644 index 9f68258..0000000 --- a/src/proxy/tests/relay_quota_wake_liveness_regression_tests.rs +++ /dev/null @@ -1,294 +0,0 @@ -use super::*; -use crate::stats::Stats; -use dashmap::DashMap; -use std::sync::Arc; -use std::sync::atomic::AtomicBool; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::sync::Barrier; -use tokio::time::{Duration, timeout}; - -fn saturate_lock_cache() -> Vec>> { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("quota-liveness-saturated-{idx}"))); - } - retained -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -#[tokio::test] -async fn positive_writer_progresses_after_contention_release_without_external_wake() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let user = "quota-liveness-writer-positive"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock before write"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let writer = tokio::spawn(async move { io.write_all(&[0x11]).await }); - - // Let the initial deferred wake fire while contention is still active. - tokio::time::sleep(Duration::from_millis(4)).await; - - drop(held_guard); - - let completed = timeout(Duration::from_millis(250), writer) - .await - .expect("writer must be re-polled and complete after lock release") - .expect("writer task must not panic"); - assert!(completed.is_ok(), "writer must complete after lock release"); -} - -#[tokio::test] -async fn edge_reader_progresses_after_contention_release_without_external_wake() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let user = "quota-liveness-reader-edge"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock before read"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::empty(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let reader = tokio::spawn(async move { - let mut one = [0u8; 1]; - io.read(&mut one).await - }); - - tokio::time::sleep(Duration::from_millis(4)).await; - drop(held_guard); - - let completed = timeout(Duration::from_millis(250), reader) - .await - .expect("reader must be re-polled and complete after lock release") - .expect("reader task must not panic"); - assert!(completed.is_ok(), "reader must complete after lock release"); -} - -#[tokio::test] -async fn adversarial_early_deferred_wake_consumption_does_not_deadlock_writer() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let user = "quota-liveness-adversarial"; - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock before adversarial write"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let writer = tokio::spawn(async move { io.write_all(&[0x22]).await }); - - // Force multiple scheduler rounds while lock remains held so the first - // deferred wake has already been consumed under contention. - for _ in 0..32 { - tokio::task::yield_now().await; - } - - drop(held_guard); - - let completed = timeout(Duration::from_millis(300), writer) - .await - .expect("writer must not stay parked forever after release") - .expect("writer task must not panic"); - assert!(completed.is_ok()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn integration_parallel_waiters_resume_after_single_release_event() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let user = format!("quota-liveness-integration-{}", std::process::id()); - let stats = Arc::new(Stats::new()); - let barrier = Arc::new(Barrier::new(13)); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock before launching waiters"); - - let mut waiters = Vec::new(); - for _ in 0..12 { - let stats = Arc::clone(&stats); - let user = user.clone(); - let barrier = Arc::clone(&barrier); - waiters.push(tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - stats, - user, - Some(4096), - quota_exceeded, - tokio::time::Instant::now(), - ); - barrier.wait().await; - io.write_all(&[0x33]).await - })); - } - - barrier.wait().await; - tokio::time::sleep(Duration::from_millis(4)).await; - drop(held_guard); - - timeout(Duration::from_secs(1), async { - for waiter in waiters { - let outcome = waiter.await.expect("waiter must not panic"); - assert!( - outcome.is_ok(), - "waiter must resume and complete after release" - ); - } - }) - .await - .expect("all waiters must complete in bounded time"); -} - -#[tokio::test] -async fn light_fuzz_release_timing_matrix_preserves_liveness() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let stats = Arc::new(Stats::new()); - - let mut seed = 0xD1CE_F00D_0123_4567u64; - for round in 0..64u32 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let delay_ms = 1 + (seed & 0x7) as u64; - let user = format!("quota-liveness-fuzz-{}-{round}", std::process::id()); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold user quota lock in fuzz round"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user, - Some(2048), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let writer = tokio::spawn(async move { io.write_all(&[0x44]).await }); - - tokio::time::sleep(Duration::from_millis(delay_ms)).await; - drop(held_guard); - - let done = timeout(Duration::from_millis(300), writer) - .await - .expect("fuzz round writer must complete") - .expect("fuzz writer task must not panic"); - assert!( - done.is_ok(), - "fuzz round writer must not stall after release" - ); - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn stress_repeated_contention_cycles_remain_live() { - let _guard = quota_test_guard(); - - let _retained = saturate_lock_cache(); - let stats = Arc::new(Stats::new()); - - for cycle in 0..40u32 { - let user = format!("quota-liveness-stress-{}-{cycle}", std::process::id()); - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold lock before stress cycle"); - - let mut tasks = Vec::new(); - for _ in 0..6 { - let stats = Arc::clone(&stats); - let user = user.clone(); - tasks.push(tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - stats, - user, - Some(2048), - quota_exceeded, - tokio::time::Instant::now(), - ); - io.write_all(&[0x55]).await - })); - } - - tokio::task::yield_now().await; - drop(held_guard); - - timeout(Duration::from_millis(700), async { - for task in tasks { - let outcome = task.await.expect("stress task must not panic"); - assert!(outcome.is_ok(), "stress writer must complete"); - } - }) - .await - .expect("stress cycle must finish in bounded time"); - } -} diff --git a/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs b/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs deleted file mode 100644 index fa4878a..0000000 --- a/src/proxy/tests/relay_quota_waker_storm_adversarial_tests.rs +++ /dev/null @@ -1,310 +0,0 @@ -use super::*; -use crate::stats::Stats; -use dashmap::DashMap; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::task::{Context, Waker}; -use tokio::io::{AsyncWriteExt, ReadBuf}; -use tokio::time::{Duration, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -fn quota_test_guard() -> impl Drop { - super::quota_user_lock_test_scope() -} - -fn saturate_quota_user_locks() -> Vec>> { - let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new); - map.clear(); - - let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX); - for idx in 0..QUOTA_USER_LOCKS_MAX { - retained.push(quota_user_lock(&format!("quota-waker-saturate-{idx}"))); - } - retained -} - -#[tokio::test] -async fn positive_contended_writer_emits_deferred_wake_for_liveness() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let stats = Arc::new(Stats::new()); - let user = "quota-waker-positive-user"; - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before polling writer"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]); - assert!(pending.is_pending()); - - timeout(Duration::from_millis(100), async { - loop { - if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { - break; - } - tokio::task::yield_now().await; - } - }) - .await - .expect("contended writer must receive deferred wake"); - - drop(held_guard); - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]); - assert!( - ready.is_ready(), - "writer must progress after contention release" - ); -} - -#[tokio::test] -async fn adversarial_blackhat_writer_contention_does_not_create_waker_storm() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let stats = Arc::new(Stats::new()); - let user = "quota-waker-blackhat-writer"; - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before polling writer"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - for _ in 0..512 { - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xBE]); - assert!( - poll.is_pending(), - "writer must stay pending while lock is held" - ); - tokio::task::yield_now().await; - } - - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - assert!( - wakes <= 128, - "pending writer retries must not trigger wake storm; observed wakes={wakes}" - ); - - drop(held_guard); - let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xEF]); - assert!(ready.is_ready()); -} - -#[tokio::test] -async fn edge_read_path_contention_keeps_wake_budget_bounded() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let stats = Arc::new(Stats::new()); - let user = "quota-waker-read-edge"; - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before polling reader"); - - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::empty(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - let mut storage = [0u8; 1]; - - for _ in 0..512 { - let mut buf = ReadBuf::new(&mut storage); - let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!(poll.is_pending()); - tokio::task::yield_now().await; - } - - let wakes = wake_counter.wakes.load(Ordering::Relaxed); - assert!( - wakes <= 128, - "pending reader retries must not trigger wake storm; observed wakes={wakes}" - ); - - drop(held_guard); - let mut buf = ReadBuf::new(&mut storage); - let ready = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!(ready.is_ready()); -} - -#[tokio::test] -async fn light_fuzz_mixed_poll_schedule_under_contention_stays_bounded() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let stats = Arc::new(Stats::new()); - let user = "quota-waker-fuzz-user"; - - let lock = quota_user_lock(user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before fuzz polling"); - - let counters_w = Arc::new(SharedCounters::new()); - let mut writer_io = StatsIo::new( - tokio::io::sink(), - counters_w, - Arc::clone(&stats), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let counters_r = Arc::new(SharedCounters::new()); - let mut reader_io = StatsIo::new( - tokio::io::empty(), - counters_r, - Arc::clone(&stats), - user.to_string(), - Some(1024), - Arc::new(AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - let mut seed = 0xBADC_0FFE_EE11_2211u64; - let mut storage = [0u8; 1]; - - for _ in 0..1024 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - if (seed & 1) == 0 { - let poll = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x44]); - assert!(poll.is_pending()); - } else { - let mut buf = ReadBuf::new(&mut storage); - let poll = Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf); - assert!(poll.is_pending()); - } - tokio::task::yield_now().await; - } - - assert!( - wake_counter.wakes.load(Ordering::Relaxed) <= 192, - "mixed contention fuzz must keep deferred wake count tightly bounded" - ); - - drop(held_guard); - let ready_w = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x55]); - assert!(ready_w.is_ready()); - - let mut buf = ReadBuf::new(&mut storage); - let ready_r = Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf); - assert!(ready_r.is_ready()); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -#[ignore = "red-team detector: reveals possible starvation if deferred wake fires before contention release"] -async fn stress_many_contended_writers_complete_after_release() { - let _guard = quota_test_guard(); - - let _retained = saturate_quota_user_locks(); - let user = "quota-waker-stress-user".to_string(); - let stats = Arc::new(Stats::new()); - - let lock = quota_user_lock(&user); - let held_guard = lock - .try_lock() - .expect("test must hold overflow lock before launching contended tasks"); - - let mut tasks = Vec::new(); - for _ in 0..32 { - let stats = Arc::clone(&stats); - let user = user.clone(); - tasks.push(tokio::spawn(async move { - let counters = Arc::new(SharedCounters::new()); - let quota_exceeded = Arc::new(AtomicBool::new(false)); - let mut io = StatsIo::new( - tokio::io::sink(), - counters, - stats, - user, - Some(2048), - quota_exceeded, - tokio::time::Instant::now(), - ); - - io.write_all(&[0xAA]).await - })); - } - - for _ in 0..8 { - tokio::task::yield_now().await; - } - - drop(held_guard); - - timeout(Duration::from_secs(2), async { - for task in tasks { - let result = task.await.expect("stress task must not panic"); - assert!(result.is_ok(), "task must complete after lock release"); - } - }) - .await - .expect("all contended writer tasks must finish in bounded time after release"); -} diff --git a/src/proxy/tests/relay_security_tests.rs b/src/proxy/tests/relay_security_tests.rs deleted file mode 100644 index 7375192..0000000 --- a/src/proxy/tests/relay_security_tests.rs +++ /dev/null @@ -1,1284 +0,0 @@ -use super::relay_bidirectional; -use crate::error::ProxyError; -use crate::stats::Stats; -use crate::stream::BufferPool; -use std::future::poll_fn; -use std::io; -use std::pin::Pin; -use std::sync::Arc; -use std::sync::Mutex; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::task::Waker; -use std::task::{Context, Poll}; -use tokio::io::{AsyncRead, ReadBuf}; -use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex}; -use tokio::time::{Duration, timeout}; - -#[derive(Default)] -struct WakeCounter { - wakes: AtomicUsize, -} - -impl std::task::Wake for WakeCounter { - fn wake(self: Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } - - fn wake_by_ref(self: &Arc) { - self.wakes.fetch_add(1, Ordering::Relaxed); - } -} - -#[tokio::test] -async fn quota_lock_contention_does_not_self_wake_pending_writer() { - let _guard = super::quota_user_lock_test_scope(); - let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); - map.clear(); - - let stats = Arc::new(Stats::new()); - let user = "quota-lock-contention-user"; - - let lock = super::quota_user_lock(user); - let _held_lock = lock - .try_lock() - .expect("test must hold the per-user quota lock before polling writer"); - - let counters = Arc::new(super::SharedCounters::new()); - let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let mut io = super::StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let poll = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); - assert!( - poll.is_pending(), - "writer must remain pending while lock is contended" - ); - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - 0, - "contended quota lock must not self-wake immediately and spin the executor" - ); -} - -#[tokio::test] -async fn quota_lock_contention_writer_schedules_single_deferred_wake_until_lock_acquired() { - let _guard = super::quota_user_lock_test_scope(); - let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); - map.clear(); - - let stats = Arc::new(Stats::new()); - let user = "quota-lock-writer-liveness-user"; - - let lock = super::quota_user_lock(user); - let held_lock = lock - .try_lock() - .expect("test must hold the per-user quota lock before polling writer"); - - let counters = Arc::new(super::SharedCounters::new()); - let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let mut io = super::StatsIo::new( - tokio::io::sink(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - - let first = Pin::new(&mut io).poll_write(&mut cx, &[0x11]); - assert!( - first.is_pending(), - "writer must remain pending while lock is contended" - ); - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - 0, - "deferred wake must not fire synchronously" - ); - - timeout(Duration::from_millis(50), async { - loop { - if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { - break; - } - tokio::task::yield_now().await; - } - }) - .await - .expect("contended writer must schedule a deferred wake in bounded time"); - let wakes_after_first_yield = wake_counter.wakes.load(Ordering::Relaxed); - assert!( - wakes_after_first_yield >= 1, - "contended writer must schedule at least one deferred wake for liveness" - ); - - let second = Pin::new(&mut io).poll_write(&mut cx, &[0x22]); - assert!( - second.is_pending(), - "writer remains pending while lock is still held" - ); - - for _ in 0..8 { - tokio::task::yield_now().await; - } - let wakes_after_second_window = wake_counter.wakes.load(Ordering::Relaxed); - assert!( - wakes_after_second_window <= wakes_after_first_yield.saturating_add(2), - "writer contention should keep retry wakes bounded before lock acquisition: before={wakes_after_first_yield}, after={wakes_after_second_window}" - ); - - drop(held_lock); - let released = Pin::new(&mut io).poll_write(&mut cx, &[0x33]); - assert!( - released.is_ready(), - "writer must make progress once quota lock is released" - ); -} - -#[tokio::test] -async fn quota_lock_contention_read_path_schedules_deferred_wake_for_liveness() { - let _guard = super::quota_user_lock_test_scope(); - let map = super::QUOTA_USER_LOCKS.get_or_init(dashmap::DashMap::new); - map.clear(); - - let stats = Arc::new(Stats::new()); - let user = "quota-lock-read-liveness-user"; - - let lock = super::quota_user_lock(user); - let held_lock = lock - .try_lock() - .expect("test must hold the per-user quota lock before polling reader"); - - let counters = Arc::new(super::SharedCounters::new()); - let quota_exceeded = Arc::new(std::sync::atomic::AtomicBool::new(false)); - let mut io = super::StatsIo::new( - tokio::io::empty(), - counters, - Arc::clone(&stats), - user.to_string(), - Some(1024), - quota_exceeded, - tokio::time::Instant::now(), - ); - - let wake_counter = Arc::new(WakeCounter::default()); - let waker = Waker::from(Arc::clone(&wake_counter)); - let mut cx = Context::from_waker(&waker); - let mut storage = [0u8; 1]; - let mut buf = ReadBuf::new(&mut storage); - - let first = Pin::new(&mut io).poll_read(&mut cx, &mut buf); - assert!( - first.is_pending(), - "reader must remain pending while lock is contended" - ); - assert_eq!( - wake_counter.wakes.load(Ordering::Relaxed), - 0, - "read contention wake must not fire synchronously" - ); - - timeout(Duration::from_millis(50), async { - loop { - if wake_counter.wakes.load(Ordering::Relaxed) >= 1 { - break; - } - tokio::task::yield_now().await; - } - }) - .await - .expect("read contention must schedule a deferred wake in bounded time"); - - drop(held_lock); - let mut buf_after_release = ReadBuf::new(&mut storage); - let released = Pin::new(&mut io).poll_read(&mut cx, &mut buf_after_release); - assert!( - released.is_ready(), - "reader must make progress once quota lock is released" - ); -} - -#[tokio::test] -async fn relay_bidirectional_enforces_live_user_quota() { - let stats = Arc::new(Stats::new()); - let user = "quota-user"; - stats.add_user_octets_from(user, 6); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - user, - Arc::clone(&stats), - Some(8), - Arc::new(BufferPool::new()), - )); - - client_peer - .write_all(&[0x10, 0x20, 0x30, 0x40]) - .await - .expect("client write must succeed"); - - let mut forwarded = [0u8; 4]; - let _ = timeout( - Duration::from_millis(200), - server_peer.read_exact(&mut forwarded), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-user"), - "relay must surface a typed quota error once live quota is exceeded" - ); -} - -#[tokio::test] -async fn relay_bidirectional_does_not_forward_server_bytes_after_quota_is_exhausted() { - let stats = Arc::new(Stats::new()); - let quota_user = "quota-exhausted-user"; - stats.add_user_octets_from(quota_user, 1); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - - server_peer - .write_all(&[0xde, 0xad, 0xbe, 0xef]) - .await - .expect("server write must succeed"); - - let mut observed = [0u8; 4]; - let forwarded = timeout( - Duration::from_millis(200), - client_peer.read_exact(&mut observed), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n == observed.len()), - "no full server payload should be forwarded once quota is already exhausted" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must still terminate with a typed quota error" - ); -} - -#[tokio::test] -async fn relay_bidirectional_does_not_leak_partial_server_payload_when_remaining_quota_is_smaller_than_write() - { - let stats = Arc::new(Stats::new()); - let quota_user = "partial-leak-user"; - stats.add_user_octets_from(quota_user, 3); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(4), - Arc::new(BufferPool::new()), - )); - - server_peer - .write_all(&[0x11, 0x22, 0x33, 0x44]) - .await - .expect("server write must succeed"); - - let mut observed = [0u8; 8]; - let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n > 0), - "quota exhaustion must not leak any partial server payload when remaining quota is smaller than the write" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must still terminate with a typed quota error" - ); -} - -#[tokio::test] -async fn relay_bidirectional_zero_quota_remains_fail_closed_for_server_payloads_under_stress() { - let stats = Arc::new(Stats::new()); - let quota_user = "zero-quota-user"; - - for payload_len in [1usize, 16, 512, 4096] { - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(0), - Arc::new(BufferPool::new()), - )); - - let payload = vec![0x7f; payload_len]; - let _ = server_peer.write_all(&payload).await; - - let mut observed = vec![0u8; payload_len]; - let forwarded = timeout(Duration::from_millis(200), client_peer.read(&mut observed)).await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under zero-quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n > 0), - "zero quota must not forward any server bytes for payload_len={payload_len}" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "zero quota must terminate with the typed quota error for payload_len={payload_len}" - ); - } -} - -#[tokio::test] -async fn relay_bidirectional_allows_exact_server_payload_at_quota_boundary() { - let stats = Arc::new(Stats::new()); - let quota_user = "exact-boundary-user"; - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(4), - Arc::new(BufferPool::new()), - )); - - server_peer - .write_all(&[0x91, 0x92, 0x93, 0x94]) - .await - .expect("server write must succeed at exact quota boundary"); - - let mut observed = [0u8; 4]; - client_peer - .read_exact(&mut observed) - .await - .expect("client must receive the full payload at the exact quota boundary"); - assert_eq!(observed, [0x91, 0x92, 0x93, 0x94]); - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish after exact boundary delivery") - .expect("relay task must not panic"); - - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must close with a typed quota error after reaching the exact boundary" - ); -} - -#[tokio::test] -async fn relay_bidirectional_does_not_forward_client_bytes_after_quota_is_exhausted() { - let stats = Arc::new(Stats::new()); - let quota_user = "client-exhausted-user"; - stats.add_user_octets_from(quota_user, 1); - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - - client_peer - .write_all(&[0x51, 0x52, 0x53, 0x54]) - .await - .expect("client write must succeed even when quota is already exhausted"); - - let mut observed = [0u8; 4]; - let forwarded = timeout( - Duration::from_millis(200), - server_peer.read_exact(&mut observed), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n == observed.len()), - "client payload must not be fully forwarded once quota is already exhausted" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must still terminate with a typed quota error" - ); -} - -#[tokio::test] -async fn relay_bidirectional_server_bytes_remain_blocked_even_under_multiple_payload_sizes() { - let stats = Arc::new(Stats::new()); - let quota_user = "quota-fuzz-user"; - stats.add_user_octets_from(quota_user, 2); - - for payload_len in [1usize, 32, 1024, 8192] { - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - quota_user, - Arc::clone(&stats), - Some(2), - Arc::new(BufferPool::new()), - )); - - let payload = vec![0xaa; payload_len]; - let _ = server_peer.write_all(&payload).await; - - let mut observed = vec![0u8; payload_len]; - let forwarded = timeout( - Duration::from_millis(200), - client_peer.read_exact(&mut observed), - ) - .await; - - let relay_result = timeout(Duration::from_secs(2), relay_task) - .await - .expect("relay task must finish under quota cutoff") - .expect("relay task must not panic"); - - assert!( - !matches!(forwarded, Ok(Ok(n)) if n == payload_len), - "quota exhaustion must block full server-to-client forwarding for payload_len={payload_len}" - ); - assert!( - matches!(relay_result, Err(ProxyError::DataQuotaExceeded { ref user }) if user == quota_user), - "relay must keep returning the typed quota error for payload_len={payload_len}" - ); - } -} - -#[tokio::test] -async fn relay_bidirectional_terminates_on_activity_timeout() { - tokio::time::pause(); - let stats = Arc::new(Stats::new()); - let user = "timeout-user"; - - let (client_peer, relay_client) = duplex(4096); - let (relay_server, server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - user, - Arc::clone(&stats), - None, // No quota - Arc::new(BufferPool::new()), - )); - - // Wait past the activity timeout threshold (1800 seconds) + buffer - tokio::time::sleep(Duration::from_secs(1805)).await; - - // Resume time to process timeouts - tokio::time::resume(); - - let relay_result = timeout(Duration::from_secs(1), relay_task) - .await - .expect("relay task must finish inside bounded timeout due to inactivity cutoff") - .expect("relay task must not panic"); - - assert!( - relay_result.is_ok(), - "relay should complete successfully on scheduled inactivity timeout" - ); - - // Verify client/server sockets are closed - drop(client_peer); - drop(server_peer); -} - -#[tokio::test] -async fn relay_bidirectional_watchdog_resists_premature_execution() { - tokio::time::pause(); - let stats = Arc::new(Stats::new()); - let user = "activity-user"; - - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, server_peer) = duplex(4096); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let mut relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - user, - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - )); - - // Advance by half the timeout - tokio::time::sleep(Duration::from_secs(900)).await; - - // Provide activity - client_peer - .write_all(&[0xaa, 0xbb]) - .await - .expect("client write must succeed"); - client_peer.flush().await.unwrap(); - - // Advance by another half (total time since start is 1800, but since last activity is 900) - tokio::time::sleep(Duration::from_secs(900)).await; - - tokio::time::resume(); - - // Re-evaluating the task, it should NOT have timed out and still be pending - let relay_result = timeout(Duration::from_millis(100), &mut relay_task).await; - assert!( - relay_result.is_err(), - "Relay must not exit prematurely as long as activity was received before timeout" - ); - - // Explicitly drop sockets to cleanly shut down relay loop - drop(client_peer); - drop(server_peer); - - let completion = timeout(Duration::from_secs(1), relay_task) - .await - .expect("relay task must complete securely after client disconnection") - .expect("relay task must not panic"); - assert!(completion.is_ok(), "relay exits clean"); -} - -#[tokio::test] -async fn relay_bidirectional_half_closure_terminates_cleanly() { - let stats = Arc::new(Stats::new()); - let (client_peer, relay_client) = duplex(4096); - let (relay_server, server_peer) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "half-close", - stats, - None, - Arc::new(BufferPool::new()), - )); - - // Half closure: drop the client completely but leave the server active. - drop(client_peer); - - // Check that we don't immediately crash. Bidirectional relay stays open for the server -> client flush. - // Eventually dropping the server cleanly closes the task. - drop(server_peer); - timeout(Duration::from_secs(1), relay_task) - .await - .unwrap() - .unwrap() - .unwrap(); -} - -#[tokio::test] -async fn relay_bidirectional_zero_length_noise_fuzzing() { - let stats = Arc::new(Stats::new()); - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, mut server_peer) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "fuzz", - stats, - None, - Arc::new(BufferPool::new()), - )); - - // Flood with zero-length payloads (edge cases in stream framing logic sometimes loop) - for _ in 0..100 { - client_peer.write_all(&[]).await.unwrap(); - } - client_peer.write_all(&[1, 2, 3]).await.unwrap(); - client_peer.flush().await.unwrap(); - - let mut buf = [0u8; 3]; - server_peer.read_exact(&mut buf).await.unwrap(); - assert_eq!(&buf, &[1, 2, 3]); - - drop(client_peer); - drop(server_peer); - timeout(Duration::from_secs(1), relay_task) - .await - .unwrap() - .unwrap() - .unwrap(); -} - -#[tokio::test] -async fn relay_bidirectional_asymmetric_backpressure() { - let stats = Arc::new(Stats::new()); - // Give the client stream an extremely narrow throughput limit explicitly - let (client_peer, relay_client) = duplex(1024); - let (relay_server, mut server_peer) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "slowloris", - stats, - None, - Arc::new(BufferPool::new()), - )); - - let payload = vec![0xba; 65536]; // 64k payload - - // Server attempts to shove 64KB into a relay whose client pipe only holds 1KB! - let write_res = - tokio::time::timeout(Duration::from_millis(50), server_peer.write_all(&payload)).await; - - assert!( - write_res.is_err(), - "Relay backpressure MUST halt the server writer from unbounded buffering when client stream is full!" - ); - - drop(client_peer); - drop(server_peer); - - let completion = timeout(Duration::from_secs(1), relay_task) - .await - .unwrap() - .unwrap(); - assert!( - completion.is_ok() || completion.is_err(), - "Task must unwind reliably (either Ok or BrokenPipe Err) when dropped despite active backpressure locks" - ); -} - -use rand::{RngExt, SeedableRng, rngs::StdRng}; - -#[tokio::test] -async fn relay_bidirectional_light_fuzzing_temporal_jitter() { - tokio::time::pause(); - let stats = Arc::new(Stats::new()); - let (mut client_peer, relay_client) = duplex(4096); - let (relay_server, server_peer) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - - let mut relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 1024, - 1024, - "fuzz-user", - stats, - None, - Arc::new(BufferPool::new()), - )); - - let mut rng = StdRng::seed_from_u64(0xDEADBEEF); - - for _ in 0..10 { - // Vary timing significantly up to 1600 seconds (limit is 1800s) - let jitter = rng.random_range(100..1600); - tokio::time::sleep(Duration::from_secs(jitter)).await; - - client_peer.write_all(&[0x11]).await.unwrap(); - client_peer.flush().await.unwrap(); - - // Ensure task has not died - let res = timeout(Duration::from_millis(10), &mut relay_task).await; - assert!( - res.is_err(), - "Relay must remain open indefinitely under light temporal fuzzing with active jitter pulses" - ); - } - - drop(client_peer); - drop(server_peer); - timeout(Duration::from_secs(1), relay_task) - .await - .unwrap() - .unwrap() - .unwrap(); -} - -struct FaultyReader { - error_once: Option, -} - -struct TwoPartyGate { - arrivals: AtomicUsize, - total_bytes: AtomicUsize, - wakers: Mutex>, -} - -impl TwoPartyGate { - fn new() -> Self { - Self { - arrivals: AtomicUsize::new(0), - total_bytes: AtomicUsize::new(0), - wakers: Mutex::new(Vec::new()), - } - } - - fn arrive_or_park(&self, cx: &mut Context<'_>) -> bool { - if self.arrivals.load(Ordering::Relaxed) >= 2 { - return true; - } - - let prev = self.arrivals.fetch_add(1, Ordering::AcqRel); - if prev + 1 >= 2 { - let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner()); - for waker in wakers.drain(..) { - waker.wake(); - } - true - } else { - let mut wakers = self.wakers.lock().unwrap_or_else(|p| p.into_inner()); - wakers.push(cx.waker().clone()); - false - } - } - - fn total_bytes(&self) -> usize { - self.total_bytes.load(Ordering::Relaxed) - } -} - -struct GateWriter { - gate: Arc, - entered: bool, -} - -impl GateWriter { - fn new(gate: Arc) -> Self { - Self { - gate, - entered: false, - } - } -} - -impl AsyncWrite for GateWriter { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if !self.entered { - self.entered = true; - } - - if !self.gate.arrive_or_park(cx) { - return Poll::Pending; - } - - self.gate - .total_bytes - .fetch_add(buf.len(), Ordering::Relaxed); - Poll::Ready(Ok(buf.len())) - } - - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } -} - -struct GateReader { - gate: Arc, - entered: bool, - emitted: bool, -} - -impl GateReader { - fn new(gate: Arc) -> Self { - Self { - gate, - entered: false, - emitted: false, - } - } -} - -impl AsyncRead for GateReader { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - if self.emitted { - return Poll::Ready(Ok(())); - } - - if !self.entered { - self.entered = true; - } - - if !self.gate.arrive_or_park(cx) { - return Poll::Pending; - } - - buf.put_slice(&[0x42]); - self.gate.total_bytes.fetch_add(1, Ordering::Relaxed); - self.emitted = true; - Poll::Ready(Ok(())) - } -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn adversarial_concurrent_quota_write_race_does_not_overshoot_limit() { - let stats = Arc::new(Stats::new()); - let gate = Arc::new(TwoPartyGate::new()); - let user = "concurrent-quota-write".to_string(); - - let writer_a = super::StatsIo::new( - GateWriter::new(Arc::clone(&gate)), - Arc::new(super::SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let writer_b = super::StatsIo::new( - GateWriter::new(Arc::clone(&gate)), - Arc::new(super::SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let task_a = tokio::spawn(async move { - let mut w = writer_a; - AsyncWriteExt::write_all(&mut w, &[0x01]).await - }); - let task_b = tokio::spawn(async move { - let mut w = writer_b; - AsyncWriteExt::write_all(&mut w, &[0x02]).await - }); - - let (res_a, res_b) = tokio::join!(task_a, task_b); - let _ = res_a.expect("task a must join"); - let _ = res_b.expect("task b must join"); - - assert!( - gate.total_bytes() <= 1, - "concurrent same-user writes must not forward more than one byte under quota=1" - ); - assert!( - stats.get_user_total_octets(&user) <= 1, - "concurrent same-user writes must not account over limit" - ); -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn adversarial_concurrent_quota_read_race_does_not_overshoot_limit() { - let stats = Arc::new(Stats::new()); - let gate = Arc::new(TwoPartyGate::new()); - let user = "concurrent-quota-read".to_string(); - - let reader_a = super::StatsIo::new( - GateReader::new(Arc::clone(&gate)), - Arc::new(super::SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let reader_b = super::StatsIo::new( - GateReader::new(Arc::clone(&gate)), - Arc::new(super::SharedCounters::new()), - Arc::clone(&stats), - user.clone(), - Some(1), - Arc::new(std::sync::atomic::AtomicBool::new(false)), - tokio::time::Instant::now(), - ); - - let task_a = tokio::spawn(async move { - let mut r = reader_a; - let mut one = [0u8; 1]; - AsyncReadExt::read_exact(&mut r, &mut one).await - }); - let task_b = tokio::spawn(async move { - let mut r = reader_b; - let mut one = [0u8; 1]; - AsyncReadExt::read_exact(&mut r, &mut one).await - }); - - let (res_a, res_b) = tokio::join!(task_a, task_b); - let _ = res_a.expect("task a must join"); - let _ = res_b.expect("task b must join"); - - assert!( - gate.total_bytes() <= 1, - "concurrent same-user reads must not consume more than one byte under quota=1" - ); - assert!( - stats.get_user_total_octets(&user) <= 1, - "concurrent same-user reads must not account over limit" - ); -} - -#[tokio::test] -async fn stress_same_user_quota_parallel_relays_never_exceed_limit() { - let stats = Arc::new(Stats::new()); - let user = "parallel-quota-user"; - - for _ in 0..128 { - let (mut client_peer_a, relay_client_a) = duplex(256); - let (relay_server_a, mut server_peer_a) = duplex(256); - let (mut client_peer_b, relay_client_b) = duplex(256); - let (relay_server_b, mut server_peer_b) = duplex(256); - - let (client_reader_a, client_writer_a) = tokio::io::split(relay_client_a); - let (server_reader_a, server_writer_a) = tokio::io::split(relay_server_a); - let (client_reader_b, client_writer_b) = tokio::io::split(relay_client_b); - let (server_reader_b, server_writer_b) = tokio::io::split(relay_server_b); - - let relay_a = tokio::spawn(relay_bidirectional( - client_reader_a, - client_writer_a, - server_reader_a, - server_writer_a, - 64, - 64, - user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - - let relay_b = tokio::spawn(relay_bidirectional( - client_reader_b, - client_writer_b, - server_reader_b, - server_writer_b, - 64, - 64, - user, - Arc::clone(&stats), - Some(1), - Arc::new(BufferPool::new()), - )); - - let _ = tokio::join!( - client_peer_a.write_all(&[0x01]), - server_peer_a.write_all(&[0x02]), - client_peer_b.write_all(&[0x03]), - server_peer_b.write_all(&[0x04]), - ); - - let _ = timeout( - Duration::from_millis(50), - poll_fn(|cx| { - let mut one = [0u8; 1]; - let _ = Pin::new(&mut client_peer_a).poll_read(cx, &mut ReadBuf::new(&mut one)); - Poll::Ready(()) - }), - ) - .await; - - drop(client_peer_a); - drop(server_peer_a); - drop(client_peer_b); - drop(server_peer_b); - - let _ = timeout(Duration::from_secs(1), relay_a).await; - let _ = timeout(Duration::from_secs(1), relay_b).await; - - assert!( - stats.get_user_total_octets(user) <= 1, - "parallel relays must not exceed configured quota" - ); - } -} - -impl FaultyReader { - fn permission_denied_with_message(message: impl Into) -> Self { - Self { - error_once: Some(io::Error::new( - io::ErrorKind::PermissionDenied, - message.into(), - )), - } - } -} - -impl AsyncRead for FaultyReader { - fn poll_read( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - _buf: &mut ReadBuf<'_>, - ) -> Poll> { - if let Some(err) = self.error_once.take() { - return Poll::Ready(Err(err)); - } - Poll::Ready(Ok(())) - } -} - -#[tokio::test] -async fn relay_bidirectional_does_not_misclassify_transport_permission_denied_as_quota() { - let stats = Arc::new(Stats::new()); - let (client_peer, relay_client) = duplex(4096); - let (client_reader, client_writer) = tokio::io::split(relay_client); - - let relay_result = relay_bidirectional( - client_reader, - client_writer, - FaultyReader::permission_denied_with_message("user data quota exceeded"), - tokio::io::sink(), - 1024, - 1024, - "non-quota-permission-denied", - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - ) - .await; - - drop(client_peer); - - assert!( - matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied), - "non-quota transport PermissionDenied errors must remain IO errors" - ); -} - -#[tokio::test] -async fn relay_bidirectional_light_fuzz_permission_denied_messages_remain_io_errors() { - let mut rng = StdRng::seed_from_u64(0xA11CE0B5); - - for i in 0..128u64 { - let stats = Arc::new(Stats::new()); - let (client_peer, relay_client) = duplex(1024); - let (client_reader, client_writer) = tokio::io::split(relay_client); - - let random_len = rng.random_range(1..=48); - let mut msg = String::with_capacity(random_len); - for _ in 0..random_len { - let ch = (b'a' + (rng.random::() % 26)) as char; - msg.push(ch); - } - // Include the legacy quota string in a subset of fuzz cases to validate - // collision resistance against message-based classification. - if i % 7 == 0 { - msg = "user data quota exceeded".to_string(); - } - - let relay_result = relay_bidirectional( - client_reader, - client_writer, - FaultyReader::permission_denied_with_message(msg), - tokio::io::sink(), - 1024, - 1024, - "fuzz-perm-denied", - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - ) - .await; - - drop(client_peer); - - assert!( - matches!(relay_result, Err(ProxyError::Io(ref err)) if err.kind() == io::ErrorKind::PermissionDenied), - "transport PermissionDenied case must stay typed as IO regardless of message content" - ); - } -} - -#[tokio::test] -async fn relay_half_close_keeps_reverse_direction_progressing() { - let stats = Arc::new(Stats::new()); - let user = "half-close-user"; - - let (client_peer, relay_client) = duplex(1024); - let (relay_server, server_peer) = duplex(1024); - - let (client_reader, client_writer) = tokio::io::split(relay_client); - let (server_reader, server_writer) = tokio::io::split(relay_server); - let (mut cp_reader, mut cp_writer) = tokio::io::split(client_peer); - let (mut sp_reader, mut sp_writer) = tokio::io::split(server_peer); - - let relay_task = tokio::spawn(relay_bidirectional( - client_reader, - client_writer, - server_reader, - server_writer, - 8192, - 8192, - user, - Arc::clone(&stats), - None, - Arc::new(BufferPool::new()), - )); - - sp_writer - .write_all(&[0x10, 0x20, 0x30, 0x40]) - .await - .unwrap(); - sp_writer.shutdown().await.unwrap(); - - let mut inbound = [0u8; 4]; - cp_reader.read_exact(&mut inbound).await.unwrap(); - assert_eq!(inbound, [0x10, 0x20, 0x30, 0x40]); - - cp_writer - .write_all(&[0xaa, 0xbb, 0xcc, 0xdd]) - .await - .unwrap(); - let mut outbound = [0u8; 4]; - sp_reader.read_exact(&mut outbound).await.unwrap(); - assert_eq!(outbound, [0xaa, 0xbb, 0xcc, 0xdd]); - - relay_task.abort(); - let joined = relay_task.await; - assert!(joined.is_err(), "aborted relay task must return join error"); -} diff --git a/src/stats/mod.rs b/src/stats/mod.rs index dc455a1..ff15d4f 100644 --- a/src/stats/mod.rs +++ b/src/stats/mod.rs @@ -238,10 +238,12 @@ pub struct Stats { me_inline_recovery_total: AtomicU64, ip_reservation_rollback_tcp_limit_total: AtomicU64, ip_reservation_rollback_quota_limit_total: AtomicU64, + quota_write_fail_bytes_total: AtomicU64, + quota_write_fail_events_total: AtomicU64, telemetry_core_enabled: AtomicBool, telemetry_user_enabled: AtomicBool, telemetry_me_level: AtomicU8, - user_stats: DashMap, + user_stats: DashMap>, user_stats_last_cleanup_epoch_secs: AtomicU64, start_time: parking_lot::RwLock>, } @@ -254,9 +256,51 @@ pub struct UserStats { pub octets_to_client: AtomicU64, pub msgs_from_client: AtomicU64, pub msgs_to_client: AtomicU64, + /// Total bytes charged against per-user quota admission. + /// + /// This counter is the single source of truth for quota enforcement and + /// intentionally tracks attempted traffic, not guaranteed delivery. + pub quota_used: AtomicU64, pub last_seen_epoch_secs: AtomicU64, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum QuotaReserveError { + LimitExceeded, + Contended, +} + +impl UserStats { + #[inline] + pub fn quota_used(&self) -> u64 { + self.quota_used.load(Ordering::Relaxed) + } + + /// Attempts one CAS reservation step against the quota counter. + /// + /// Callers control retry/yield policy. This primitive intentionally does + /// not block or sleep so both sync poll paths and async paths can wrap it + /// with their own contention strategy. + #[inline] + pub fn quota_try_reserve(&self, bytes: u64, limit: u64) -> Result { + let current = self.quota_used.load(Ordering::Relaxed); + if bytes > limit.saturating_sub(current) { + return Err(QuotaReserveError::LimitExceeded); + } + + let next = current.saturating_add(bytes); + match self.quota_used.compare_exchange_weak( + current, + next, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => Ok(next), + Err(_) => Err(QuotaReserveError::Contended), + } + } +} + impl Stats { pub fn new() -> Self { let stats = Self::default(); @@ -316,6 +360,74 @@ impl Stats { .store(Self::now_epoch_secs(), Ordering::Relaxed); } + pub(crate) fn get_or_create_user_stats_handle(&self, user: &str) -> Arc { + self.maybe_cleanup_user_stats(); + if let Some(existing) = self.user_stats.get(user) { + let handle = Arc::clone(existing.value()); + Self::touch_user_stats(handle.as_ref()); + return handle; + } + + let entry = self.user_stats.entry(user.to_string()).or_default(); + if entry.last_seen_epoch_secs.load(Ordering::Relaxed) == 0 { + Self::touch_user_stats(entry.value().as_ref()); + } + Arc::clone(entry.value()) + } + + #[inline] + pub(crate) fn add_user_octets_from_handle(&self, user_stats: &UserStats, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats + .octets_from_client + .fetch_add(bytes, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn add_user_octets_to_handle(&self, user_stats: &UserStats, bytes: u64) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats + .octets_to_client + .fetch_add(bytes, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn increment_user_msgs_from_handle(&self, user_stats: &UserStats) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); + } + + #[inline] + pub(crate) fn increment_user_msgs_to_handle(&self, user_stats: &UserStats) { + if !self.telemetry_user_enabled() { + return; + } + Self::touch_user_stats(user_stats); + user_stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); + } + + /// Charges already committed bytes in a post-I/O path. + /// + /// This helper is intentionally separate from `quota_try_reserve` to avoid + /// mixing reserve and post-charge on a single I/O event. + #[inline] + pub(crate) fn quota_charge_post_write(&self, user_stats: &UserStats, bytes: u64) -> u64 { + Self::touch_user_stats(user_stats); + user_stats + .quota_used + .fetch_add(bytes, Ordering::Relaxed) + .saturating_add(bytes) + } + fn maybe_cleanup_user_stats(&self) { const USER_STATS_CLEANUP_INTERVAL_SECS: u64 = 60; const USER_STATS_IDLE_TTL_SECS: u64 = 24 * 60 * 60; @@ -704,7 +816,8 @@ impl Stats { } pub fn increment_me_d2c_data_frames_total(&self) { if self.telemetry_me_allows_normal() { - self.me_d2c_data_frames_total.fetch_add(1, Ordering::Relaxed); + self.me_d2c_data_frames_total + .fetch_add(1, Ordering::Relaxed); } } pub fn increment_me_d2c_ack_frames_total(&self) { @@ -1114,6 +1227,18 @@ impl Stats { .fetch_add(1, Ordering::Relaxed); } } + pub fn add_quota_write_fail_bytes_total(&self, bytes: u64) { + if self.telemetry_core_enabled() { + self.quota_write_fail_bytes_total + .fetch_add(bytes, Ordering::Relaxed); + } + } + pub fn increment_quota_write_fail_events_total(&self) { + if self.telemetry_core_enabled() { + self.quota_write_fail_events_total + .fetch_add(1, Ordering::Relaxed); + } + } pub fn increment_me_endpoint_quarantine_total(&self) { if self.telemetry_me_allows_normal() { self.me_endpoint_quarantine_total @@ -1588,7 +1713,8 @@ impl Stats { self.me_d2c_batch_bytes_bucket_1k_4k.load(Ordering::Relaxed) } pub fn get_me_d2c_batch_bytes_bucket_4k_16k(&self) -> u64 { - self.me_d2c_batch_bytes_bucket_4k_16k.load(Ordering::Relaxed) + self.me_d2c_batch_bytes_bucket_4k_16k + .load(Ordering::Relaxed) } pub fn get_me_d2c_batch_bytes_bucket_16k_64k(&self) -> u64 { self.me_d2c_batch_bytes_bucket_16k_64k @@ -1764,19 +1890,19 @@ impl Stats { self.ip_reservation_rollback_quota_limit_total .load(Ordering::Relaxed) } + pub fn get_quota_write_fail_bytes_total(&self) -> u64 { + self.quota_write_fail_bytes_total.load(Ordering::Relaxed) + } + pub fn get_quota_write_fail_events_total(&self) -> u64 { + self.quota_write_fail_events_total.load(Ordering::Relaxed) + } pub fn increment_user_connects(&self, user: &str) { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.connects.fetch_add(1, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); + let stats = self.get_or_create_user_stats_handle(user); + Self::touch_user_stats(stats.as_ref()); stats.connects.fetch_add(1, Ordering::Relaxed); } @@ -1784,14 +1910,8 @@ impl Stats { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.curr_connects.fetch_add(1, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); + let stats = self.get_or_create_user_stats_handle(user); + Self::touch_user_stats(stats.as_ref()); stats.curr_connects.fetch_add(1, Ordering::Relaxed); } @@ -1800,9 +1920,8 @@ impl Stats { return true; } - self.maybe_cleanup_user_stats(); - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); + let stats = self.get_or_create_user_stats_handle(user); + Self::touch_user_stats(stats.as_ref()); let counter = &stats.curr_connects; let mut current = counter.load(Ordering::Relaxed); @@ -1827,7 +1946,7 @@ impl Stats { pub fn decrement_user_curr_connects(&self, user: &str) { self.maybe_cleanup_user_stats(); if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); + Self::touch_user_stats(stats.value().as_ref()); let counter = &stats.curr_connects; let mut current = counter.load(Ordering::Relaxed); loop { @@ -1858,86 +1977,32 @@ impl Stats { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); - stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed); + let stats = self.get_or_create_user_stats_handle(user); + self.add_user_octets_from_handle(stats.as_ref(), bytes); } pub fn add_user_octets_to(&self, user: &str, bytes: u64) { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); - stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed); - } - - pub fn sub_user_octets_to(&self, user: &str, bytes: u64) { - if !self.telemetry_user_enabled() { - return; - } - self.maybe_cleanup_user_stats(); - let Some(stats) = self.user_stats.get(user) else { - return; - }; - - Self::touch_user_stats(stats.value()); - let counter = &stats.octets_to_client; - let mut current = counter.load(Ordering::Relaxed); - loop { - let next = current.saturating_sub(bytes); - match counter.compare_exchange_weak( - current, - next, - Ordering::Relaxed, - Ordering::Relaxed, - ) { - Ok(_) => break, - Err(actual) => current = actual, - } - } + let stats = self.get_or_create_user_stats_handle(user); + self.add_user_octets_to_handle(stats.as_ref(), bytes); } pub fn increment_user_msgs_from(&self, user: &str) { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); - stats.msgs_from_client.fetch_add(1, Ordering::Relaxed); + let stats = self.get_or_create_user_stats_handle(user); + self.increment_user_msgs_from_handle(stats.as_ref()); } pub fn increment_user_msgs_to(&self, user: &str) { if !self.telemetry_user_enabled() { return; } - self.maybe_cleanup_user_stats(); - if let Some(stats) = self.user_stats.get(user) { - Self::touch_user_stats(stats.value()); - stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); - return; - } - let stats = self.user_stats.entry(user.to_string()).or_default(); - Self::touch_user_stats(stats.value()); - stats.msgs_to_client.fetch_add(1, Ordering::Relaxed); + let stats = self.get_or_create_user_stats_handle(user); + self.increment_user_msgs_to_handle(stats.as_ref()); } pub fn get_user_total_octets(&self, user: &str) -> u64 { @@ -1950,6 +2015,13 @@ impl Stats { .unwrap_or(0) } + pub fn get_user_quota_used(&self, user: &str) -> u64 { + self.user_stats + .get(user) + .map(|s| s.quota_used.load(Ordering::Relaxed)) + .unwrap_or(0) + } + pub fn get_handshake_timeouts(&self) -> u64 { self.handshake_timeouts.load(Ordering::Relaxed) } @@ -2015,7 +2087,7 @@ impl Stats { .load(Ordering::Relaxed) } - pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, UserStats> { + pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, Arc> { self.user_stats.iter() } @@ -2163,6 +2235,22 @@ impl ReplayChecker { found } + fn check_only_internal( + &self, + data: &[u8], + shards: &[Mutex], + window: Duration, + ) -> bool { + self.checks.fetch_add(1, Ordering::Relaxed); + let idx = self.get_shard_idx(data); + let mut shard = shards[idx].lock(); + let found = shard.check(data, Instant::now(), window); + if found { + self.hits.fetch_add(1, Ordering::Relaxed); + } + found + } + fn add_only(&self, data: &[u8], shards: &[Mutex], window: Duration) { self.additions.fetch_add(1, Ordering::Relaxed); let idx = self.get_shard_idx(data); @@ -2186,7 +2274,7 @@ impl ReplayChecker { self.add_only(data, &self.handshake_shards, self.window) } pub fn check_tls_digest(&self, data: &[u8]) -> bool { - self.check_and_add_tls_digest(data) + self.check_only_internal(data, &self.tls_shards, self.tls_window) } pub fn add_tls_digest(&self, data: &[u8]) { self.add_only(data, &self.tls_shards, self.tls_window) @@ -2290,6 +2378,7 @@ mod tests { use super::*; use crate::config::MeTelemetryLevel; use std::sync::Arc; + use std::sync::atomic::{AtomicU64, Ordering}; #[test] fn test_stats_shared_counters() { @@ -2457,6 +2546,137 @@ mod tests { } assert_eq!(checker.stats().total_entries, 500); } + + #[test] + fn test_quota_reserve_under_contention_hits_limit_exactly() { + let user_stats = Arc::new(UserStats::default()); + let successes = Arc::new(AtomicU64::new(0)); + let limit = 8_192u64; + let mut workers = Vec::new(); + + for _ in 0..8 { + let user_stats = user_stats.clone(); + let successes = successes.clone(); + workers.push(std::thread::spawn(move || { + loop { + match user_stats.quota_try_reserve(1, limit) { + Ok(_) => { + successes.fetch_add(1, Ordering::Relaxed); + } + Err(QuotaReserveError::Contended) => { + std::hint::spin_loop(); + } + Err(QuotaReserveError::LimitExceeded) => { + break; + } + } + } + })); + } + + for worker in workers { + worker.join().expect("worker thread must finish"); + } + + assert_eq!( + successes.load(Ordering::Relaxed), + limit, + "successful reservations must stop exactly at limit" + ); + assert_eq!(user_stats.quota_used(), limit); + } + + #[test] + fn test_quota_reserve_200x_1k_reaches_100k_without_overshoot() { + let user_stats = Arc::new(UserStats::default()); + let successes = Arc::new(AtomicU64::new(0)); + let failures = Arc::new(AtomicU64::new(0)); + let attempts = 200usize; + let reserve_bytes = 1_024u64; + let limit = 100 * 1_024u64; + let mut workers = Vec::with_capacity(attempts); + + for _ in 0..attempts { + let user_stats = user_stats.clone(); + let successes = successes.clone(); + let failures = failures.clone(); + workers.push(std::thread::spawn(move || { + loop { + match user_stats.quota_try_reserve(reserve_bytes, limit) { + Ok(_) => { + successes.fetch_add(1, Ordering::Relaxed); + return; + } + Err(QuotaReserveError::LimitExceeded) => { + failures.fetch_add(1, Ordering::Relaxed); + return; + } + Err(QuotaReserveError::Contended) => { + std::hint::spin_loop(); + } + } + } + })); + } + + for worker in workers { + worker.join().expect("reservation worker must finish"); + } + + assert_eq!( + successes.load(Ordering::Relaxed), + 100, + "exactly 100 reservations of 1 KiB must fit into a 100 KiB quota" + ); + assert_eq!( + failures.load(Ordering::Relaxed), + 100, + "remaining workers must fail once quota is fully reserved" + ); + assert_eq!(user_stats.quota_used(), limit); + } + + #[test] + fn test_quota_used_is_authoritative_and_independent_from_octets_telemetry() { + let stats = Stats::new(); + let user = "quota-authoritative-user"; + let user_stats = stats.get_or_create_user_stats_handle(user); + + stats.add_user_octets_to_handle(&user_stats, 5); + assert_eq!(stats.get_user_total_octets(user), 5); + assert_eq!(stats.get_user_quota_used(user), 0); + + stats.quota_charge_post_write(&user_stats, 7); + assert_eq!(stats.get_user_total_octets(user), 5); + assert_eq!(stats.get_user_quota_used(user), 7); + } + + #[test] + fn test_cached_handle_survives_map_cleanup_until_last_drop() { + let stats = Stats::new(); + let user = "quota-handle-lifetime-user"; + let user_stats = stats.get_or_create_user_stats_handle(user); + let weak = Arc::downgrade(&user_stats); + + stats.user_stats.remove(user); + assert!( + stats.user_stats.get(user).is_none(), + "map cleanup should remove idle entry" + ); + assert!( + weak.upgrade().is_some(), + "cached handle must keep user stats object alive after map removal" + ); + + stats.quota_charge_post_write(user_stats.as_ref(), 3); + assert_eq!(user_stats.quota_used(), 3); + + drop(user_stats); + assert!( + weak.upgrade().is_none(), + "user stats object must be dropped after the last cached handle is released" + ); + } } #[cfg(test)] @@ -2466,7 +2686,3 @@ mod connection_lease_security_tests; #[cfg(test)] #[path = "tests/replay_checker_security_tests.rs"] mod replay_checker_security_tests; - -#[cfg(test)] -#[path = "tests/user_octets_sub_security_tests.rs"] -mod user_octets_sub_security_tests; diff --git a/src/stats/tests/user_octets_sub_security_tests.rs b/src/stats/tests/user_octets_sub_security_tests.rs deleted file mode 100644 index d4e7580..0000000 --- a/src/stats/tests/user_octets_sub_security_tests.rs +++ /dev/null @@ -1,151 +0,0 @@ -use super::*; -use std::sync::Arc; -use std::thread; - -#[test] -fn sub_user_octets_to_underflow_saturates_at_zero() { - let stats = Stats::new(); - let user = "sub-underflow-user"; - - stats.add_user_octets_to(user, 3); - stats.sub_user_octets_to(user, 100); - - assert_eq!(stats.get_user_total_octets(user), 0); -} - -#[test] -fn sub_user_octets_to_does_not_affect_octets_from_client() { - let stats = Stats::new(); - let user = "sub-isolation-user"; - - stats.add_user_octets_from(user, 17); - stats.add_user_octets_to(user, 5); - stats.sub_user_octets_to(user, 3); - - assert_eq!(stats.get_user_total_octets(user), 19); -} - -#[test] -fn light_fuzz_add_sub_model_matches_saturating_reference() { - let stats = Stats::new(); - let user = "sub-fuzz-user"; - let mut seed = 0x91D2_4CB8_EE77_1101u64; - let mut model_to = 0u64; - - for _ in 0..8192 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - - let amt = ((seed >> 8) & 0x3f) + 1; - if (seed & 1) == 0 { - stats.add_user_octets_to(user, amt); - model_to = model_to.saturating_add(amt); - } else { - stats.sub_user_octets_to(user, amt); - model_to = model_to.saturating_sub(amt); - } - } - - assert_eq!(stats.get_user_total_octets(user), model_to); -} - -#[test] -fn stress_parallel_add_sub_never_underflows_or_panics() { - let stats = Arc::new(Stats::new()); - let user = "sub-stress-user"; - // Pre-fund with a large offset so subtractions never saturate at zero. - // This guarantees commutative updates, making the final state deterministic. - let base_offset = 10_000_000u64; - stats.add_user_octets_to(user, base_offset); - - let mut workers = Vec::new(); - - for tid in 0..16u64 { - let stats_for_thread = Arc::clone(&stats); - workers.push(thread::spawn(move || { - let mut seed = 0xD00D_1000_0000_0000u64 ^ tid; - let mut net_delta = 0i64; - for _ in 0..4096 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - let amt = ((seed >> 8) & 0x1f) + 1; - - if (seed & 1) == 0 { - stats_for_thread.add_user_octets_to(user, amt); - net_delta += amt as i64; - } else { - stats_for_thread.sub_user_octets_to(user, amt); - net_delta -= amt as i64; - } - } - - net_delta - })); - } - - let mut expected_net_delta = 0i64; - for worker in workers { - expected_net_delta += worker - .join() - .expect("sub-user stress worker must not panic"); - } - - let expected_total = (base_offset as i64 + expected_net_delta) as u64; - let total = stats.get_user_total_octets(user); - assert_eq!( - total, expected_total, - "concurrent add/sub lost updates or suffered ABA races" - ); -} - -#[test] -fn sub_user_octets_to_missing_user_is_noop() { - let stats = Stats::new(); - stats.sub_user_octets_to("missing-user", 1024); - assert_eq!(stats.get_user_total_octets("missing-user"), 0); -} - -#[test] -fn stress_parallel_per_user_models_remain_exact() { - let stats = Arc::new(Stats::new()); - let mut workers = Vec::new(); - - for tid in 0..16u64 { - let stats_for_thread = Arc::clone(&stats); - workers.push(thread::spawn(move || { - let user = format!("sub-per-user-{tid}"); - let mut seed = 0xFACE_0000_0000_0000u64 ^ tid; - let mut model = 0u64; - - for _ in 0..4096 { - seed ^= seed << 7; - seed ^= seed >> 9; - seed ^= seed << 8; - let amt = ((seed >> 8) & 0x3f) + 1; - - if (seed & 1) == 0 { - stats_for_thread.add_user_octets_to(&user, amt); - model = model.saturating_add(amt); - } else { - stats_for_thread.sub_user_octets_to(&user, amt); - model = model.saturating_sub(amt); - } - } - - (user, model) - })); - } - - for worker in workers { - let (user, model) = worker - .join() - .expect("per-user subtract stress worker must not panic"); - assert_eq!( - stats.get_user_total_octets(&user), - model, - "per-user parallel model diverged" - ); - } -} \ No newline at end of file diff --git a/src/stream/frame_stream_padding_security_tests.rs b/src/stream/frame_stream_padding_security_tests.rs index 83b30f9..1ec787e 100644 --- a/src/stream/frame_stream_padding_security_tests.rs +++ b/src/stream/frame_stream_padding_security_tests.rs @@ -14,7 +14,10 @@ fn padding_rounding_equivalent_for_extensive_safe_domain() { let old = old_padding_round_up_to_4(len).expect("old expression must be safe"); let new = new_padding_round_up_to_4(len).expect("new expression must be safe"); assert_eq!(old, new, "mismatch for len={len}"); - assert!(new >= len, "rounded length must not shrink: len={len}, out={new}"); + assert!( + new >= len, + "rounded length must not shrink: len={len}, out={new}" + ); assert_eq!(new % 4, 0, "rounded length must stay 4-byte aligned"); } } diff --git a/src/tests/ip_tracker_encapsulation_adversarial_tests.rs b/src/tests/ip_tracker_encapsulation_adversarial_tests.rs index cf42e75..3fc9727 100644 --- a/src/tests/ip_tracker_encapsulation_adversarial_tests.rs +++ b/src/tests/ip_tracker_encapsulation_adversarial_tests.rs @@ -44,7 +44,10 @@ async fn encapsulation_repeated_queue_poison_recovery_preserves_forward_progress let ip_primary = ip_from_idx(10_001); let ip_alt = ip_from_idx(10_002); - tracker.check_and_add("encap-poison", ip_primary).await.unwrap(); + tracker + .check_and_add("encap-poison", ip_primary) + .await + .unwrap(); for _ in 0..128 { let queue = tracker.cleanup_queue_mutex_for_tests(); diff --git a/src/tls_front/fetcher.rs b/src/tls_front/fetcher.rs index 4408b5a..bbfc336 100644 --- a/src/tls_front/fetcher.rs +++ b/src/tls_front/fetcher.rs @@ -1,7 +1,9 @@ #![allow(clippy::too_many_arguments)] +use dashmap::DashMap; use std::sync::Arc; -use std::time::Duration; +use std::sync::OnceLock; +use std::time::{Duration, Instant}; use anyhow::{Result, anyhow}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; @@ -21,7 +23,8 @@ use rustls::{DigitallySignedStruct, Error as RustlsError}; use x509_parser::certificate::X509Certificate; use x509_parser::prelude::FromDer; -use crate::crypto::SecureRandom; +use crate::config::TlsFetchProfile; +use crate::crypto::{SecureRandom, sha256}; use crate::network::dns_overrides::resolve_socket_addr; use crate::protocol::constants::{ TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, @@ -78,6 +81,200 @@ impl ServerCertVerifier for NoVerify { } } +#[derive(Debug, Clone)] +pub struct TlsFetchStrategy { + pub profiles: Vec, + pub strict_route: bool, + pub attempt_timeout: Duration, + pub total_budget: Duration, + pub grease_enabled: bool, + pub deterministic: bool, + pub profile_cache_ttl: Duration, +} + +impl TlsFetchStrategy { + #[allow(dead_code)] + pub fn single_attempt(connect_timeout: Duration) -> Self { + Self { + profiles: vec![TlsFetchProfile::CompatTls12], + strict_route: false, + attempt_timeout: connect_timeout.max(Duration::from_millis(1)), + total_budget: connect_timeout.max(Duration::from_millis(1)), + grease_enabled: false, + deterministic: false, + profile_cache_ttl: Duration::ZERO, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct ProfileCacheKey { + host: String, + port: u16, + sni: String, + scope: Option, + proxy_protocol: u8, + route_hint: RouteHint, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum RouteHint { + Direct, + Upstream, + Unix, +} + +#[derive(Debug, Clone, Copy)] +struct ProfileCacheValue { + profile: TlsFetchProfile, + updated_at: Instant, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum FetchErrorKind { + Connect, + Route, + EarlyEof, + Timeout, + ServerHelloMissing, + TlsAlert, + Parse, + Other, +} + +static PROFILE_CACHE: OnceLock> = OnceLock::new(); + +fn profile_cache() -> &'static DashMap { + PROFILE_CACHE.get_or_init(DashMap::new) +} + +fn route_hint( + upstream: Option<&std::sync::Arc>, + unix_sock: Option<&str>, +) -> RouteHint { + if unix_sock.is_some() { + RouteHint::Unix + } else if upstream.is_some() { + RouteHint::Upstream + } else { + RouteHint::Direct + } +} + +fn profile_cache_key( + host: &str, + port: u16, + sni: &str, + upstream: Option<&std::sync::Arc>, + scope: Option<&str>, + proxy_protocol: u8, + unix_sock: Option<&str>, +) -> ProfileCacheKey { + ProfileCacheKey { + host: host.to_string(), + port, + sni: sni.to_string(), + scope: scope.map(ToString::to_string), + proxy_protocol, + route_hint: route_hint(upstream, unix_sock), + } +} + +fn classify_fetch_error(err: &anyhow::Error) -> FetchErrorKind { + for cause in err.chain() { + if let Some(io) = cause.downcast_ref::() { + return match io.kind() { + std::io::ErrorKind::TimedOut => FetchErrorKind::Timeout, + std::io::ErrorKind::UnexpectedEof => FetchErrorKind::EarlyEof, + std::io::ErrorKind::ConnectionRefused + | std::io::ErrorKind::ConnectionAborted + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::NotConnected + | std::io::ErrorKind::AddrNotAvailable => FetchErrorKind::Connect, + _ => FetchErrorKind::Other, + }; + } + } + + let message = err.to_string().to_lowercase(); + if message.contains("upstream route") { + FetchErrorKind::Route + } else if message.contains("serverhello not received") { + FetchErrorKind::ServerHelloMissing + } else if message.contains("alert") { + FetchErrorKind::TlsAlert + } else if message.contains("parse") { + FetchErrorKind::Parse + } else if message.contains("timed out") || message.contains("deadline has elapsed") { + FetchErrorKind::Timeout + } else if message.contains("eof") { + FetchErrorKind::EarlyEof + } else { + FetchErrorKind::Other + } +} + +fn order_profiles( + strategy: &TlsFetchStrategy, + cache_key: Option<&ProfileCacheKey>, + now: Instant, +) -> Vec { + let mut ordered = if strategy.profiles.is_empty() { + vec![TlsFetchProfile::CompatTls12] + } else { + strategy.profiles.clone() + }; + + if strategy.profile_cache_ttl.is_zero() { + return ordered; + } + + let Some(key) = cache_key else { + return ordered; + }; + + if let Some(cached) = profile_cache().get(key) { + let age = now.saturating_duration_since(cached.updated_at); + if age > strategy.profile_cache_ttl { + drop(cached); + profile_cache().remove(key); + return ordered; + } + + if let Some(pos) = ordered + .iter() + .position(|profile| *profile == cached.profile) + { + if pos != 0 { + ordered.swap(0, pos); + } + } + } + + ordered +} + +fn remember_profile_success( + strategy: &TlsFetchStrategy, + cache_key: Option, + profile: TlsFetchProfile, + now: Instant, +) { + if strategy.profile_cache_ttl.is_zero() { + return; + } + let Some(key) = cache_key else { + return; + }; + profile_cache().insert( + key, + ProfileCacheValue { + profile, + updated_at: now, + }, + ); +} + fn build_client_config() -> Arc { let root = rustls::RootCertStore::empty(); @@ -95,7 +292,114 @@ fn build_client_config() -> Arc { Arc::new(config) } -fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { +fn deterministic_bytes(seed: &str, len: usize) -> Vec { + let mut out = Vec::with_capacity(len); + let mut counter: u32 = 0; + while out.len() < len { + let mut chunk_seed = Vec::with_capacity(seed.len() + std::mem::size_of::()); + chunk_seed.extend_from_slice(seed.as_bytes()); + chunk_seed.extend_from_slice(&counter.to_le_bytes()); + out.extend_from_slice(&sha256(&chunk_seed)); + counter = counter.wrapping_add(1); + } + out.truncate(len); + out +} + +fn profile_cipher_suites(profile: TlsFetchProfile) -> &'static [u16] { + const MODERN_CHROME: &[u16] = &[ + 0x1301, 0x1302, 0x1303, 0xc02b, 0xc02c, 0xcca9, 0xc02f, 0xc030, 0xcca8, 0x009e, 0x00ff, + ]; + const MODERN_FIREFOX: &[u16] = &[ + 0x1301, 0x1303, 0x1302, 0xc02b, 0xcca9, 0xc02c, 0xc02f, 0xcca8, 0xc030, 0x009e, 0x00ff, + ]; + const COMPAT_TLS12: &[u16] = &[ + 0xc02b, 0xc02c, 0xc02f, 0xc030, 0xcca9, 0xcca8, 0x1301, 0x1302, 0x1303, 0x009e, 0x00ff, + ]; + const LEGACY_MINIMAL: &[u16] = &[0xc02b, 0xc02f, 0x1301, 0x1302, 0x00ff]; + + match profile { + TlsFetchProfile::ModernChromeLike => MODERN_CHROME, + TlsFetchProfile::ModernFirefoxLike => MODERN_FIREFOX, + TlsFetchProfile::CompatTls12 => COMPAT_TLS12, + TlsFetchProfile::LegacyMinimal => LEGACY_MINIMAL, + } +} + +fn profile_groups(profile: TlsFetchProfile) -> &'static [u16] { + const MODERN: &[u16] = &[0x001d, 0x0017, 0x0018]; // x25519, secp256r1, secp384r1 + const COMPAT: &[u16] = &[0x001d, 0x0017]; + const LEGACY: &[u16] = &[0x0017]; + + match profile { + TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN, + TlsFetchProfile::CompatTls12 => COMPAT, + TlsFetchProfile::LegacyMinimal => LEGACY, + } +} + +fn profile_sig_algs(profile: TlsFetchProfile) -> &'static [u16] { + const MODERN: &[u16] = &[0x0804, 0x0805, 0x0403, 0x0503, 0x0806]; + const COMPAT: &[u16] = &[0x0403, 0x0503, 0x0804, 0x0805]; + const LEGACY: &[u16] = &[0x0403, 0x0804]; + + match profile { + TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN, + TlsFetchProfile::CompatTls12 => COMPAT, + TlsFetchProfile::LegacyMinimal => LEGACY, + } +} + +fn profile_alpn(profile: TlsFetchProfile) -> &'static [&'static [u8]] { + const H2_HTTP11: &[&[u8]] = &[b"h2", b"http/1.1"]; + const HTTP11: &[&[u8]] = &[b"http/1.1"]; + match profile { + TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => H2_HTTP11, + TlsFetchProfile::CompatTls12 | TlsFetchProfile::LegacyMinimal => HTTP11, + } +} + +fn profile_supported_versions(profile: TlsFetchProfile) -> &'static [u16] { + const MODERN: &[u16] = &[0x0304, 0x0303]; + const COMPAT: &[u16] = &[0x0303, 0x0304]; + const LEGACY: &[u16] = &[0x0303]; + match profile { + TlsFetchProfile::ModernChromeLike | TlsFetchProfile::ModernFirefoxLike => MODERN, + TlsFetchProfile::CompatTls12 => COMPAT, + TlsFetchProfile::LegacyMinimal => LEGACY, + } +} + +fn profile_padding_target(profile: TlsFetchProfile) -> usize { + match profile { + TlsFetchProfile::ModernChromeLike => 220, + TlsFetchProfile::ModernFirefoxLike => 200, + TlsFetchProfile::CompatTls12 => 180, + TlsFetchProfile::LegacyMinimal => 64, + } +} + +fn grease_value(rng: &SecureRandom, deterministic: bool, seed: &str) -> u16 { + const GREASE_VALUES: [u16; 16] = [ + 0x0a0a, 0x1a1a, 0x2a2a, 0x3a3a, 0x4a4a, 0x5a5a, 0x6a6a, 0x7a7a, 0x8a8a, 0x9a9a, 0xaaaa, + 0xbaba, 0xcaca, 0xdada, 0xeaea, 0xfafa, + ]; + if deterministic { + let idx = deterministic_bytes(seed, 1)[0] as usize % GREASE_VALUES.len(); + GREASE_VALUES[idx] + } else { + let idx = (rng.bytes(1)[0] as usize) % GREASE_VALUES.len(); + GREASE_VALUES[idx] + } +} + +fn build_client_hello( + sni: &str, + rng: &SecureRandom, + profile: TlsFetchProfile, + grease_enabled: bool, + deterministic: bool, +) -> Vec { // === ClientHello body === let mut body = Vec::new(); @@ -103,21 +407,24 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { body.extend_from_slice(&[0x03, 0x03]); // Random - body.extend_from_slice(&rng.bytes(32)); + if deterministic { + body.extend_from_slice(&deterministic_bytes(&format!("tls-fetch-random:{sni}"), 32)); + } else { + body.extend_from_slice(&rng.bytes(32)); + } // Session ID: empty body.push(0); - // Cipher suites (common minimal set, TLS1.3 + a few 1.2 fallbacks) - let cipher_suites: [u8; 10] = [ - 0x13, 0x01, // TLS_AES_128_GCM_SHA256 - 0x13, 0x02, // TLS_AES_256_GCM_SHA384 - 0x13, 0x03, // TLS_CHACHA20_POLY1305_SHA256 - 0x00, 0x2f, // TLS_RSA_WITH_AES_128_CBC_SHA (legacy) - 0x00, 0xff, // RENEGOTIATION_INFO_SCSV - ]; - body.extend_from_slice(&(cipher_suites.len() as u16).to_be_bytes()); - body.extend_from_slice(&cipher_suites); + let mut cipher_suites = profile_cipher_suites(profile).to_vec(); + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("cipher:{sni}")); + cipher_suites.insert(0, grease); + } + body.extend_from_slice(&((cipher_suites.len() * 2) as u16).to_be_bytes()); + for suite in cipher_suites { + body.extend_from_slice(&suite.to_be_bytes()); + } // Compression methods: null only body.push(1); @@ -138,7 +445,11 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { exts.extend_from_slice(&sni_ext); // supported_groups - let groups: [u16; 2] = [0x001d, 0x0017]; // x25519, secp256r1 + let mut groups = profile_groups(profile).to_vec(); + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("group:{sni}")); + groups.insert(0, grease); + } exts.extend_from_slice(&0x000au16.to_be_bytes()); exts.extend_from_slice(&((2 + groups.len() * 2) as u16).to_be_bytes()); exts.extend_from_slice(&(groups.len() as u16 * 2).to_be_bytes()); @@ -147,7 +458,11 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { } // signature_algorithms - let sig_algs: [u16; 4] = [0x0804, 0x0805, 0x0403, 0x0503]; // rsa_pss_rsae_sha256/384, ecdsa_secp256r1_sha256, rsa_pkcs1_sha256 + let mut sig_algs = profile_sig_algs(profile).to_vec(); + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("sigalg:{sni}")); + sig_algs.insert(0, grease); + } exts.extend_from_slice(&0x000du16.to_be_bytes()); exts.extend_from_slice(&((2 + sig_algs.len() * 2) as u16).to_be_bytes()); exts.extend_from_slice(&(sig_algs.len() as u16 * 2).to_be_bytes()); @@ -155,8 +470,12 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { exts.extend_from_slice(&a.to_be_bytes()); } - // supported_versions (TLS1.3 + TLS1.2) - let versions: [u16; 2] = [0x0304, 0x0303]; + // supported_versions + let mut versions = profile_supported_versions(profile).to_vec(); + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("version:{sni}")); + versions.insert(0, grease); + } exts.extend_from_slice(&0x002bu16.to_be_bytes()); exts.extend_from_slice(&((1 + versions.len() * 2) as u16).to_be_bytes()); exts.push((versions.len() * 2) as u8); @@ -165,7 +484,14 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { } // key_share (x25519) - let key = gen_key_share(rng); + let key = if deterministic { + let det = deterministic_bytes(&format!("keyshare:{sni}"), 32); + let mut key = [0u8; 32]; + key.copy_from_slice(&det); + key + } else { + gen_key_share(rng) + }; let mut keyshare = Vec::with_capacity(4 + key.len()); keyshare.extend_from_slice(&0x001du16.to_be_bytes()); // group keyshare.extend_from_slice(&(key.len() as u16).to_be_bytes()); @@ -175,18 +501,29 @@ fn build_client_hello(sni: &str, rng: &SecureRandom) -> Vec { exts.extend_from_slice(&(keyshare.len() as u16).to_be_bytes()); exts.extend_from_slice(&keyshare); - // ALPN (http/1.1) - let alpn_proto = b"http/1.1"; - exts.extend_from_slice(&0x0010u16.to_be_bytes()); - exts.extend_from_slice(&((2 + 1 + alpn_proto.len()) as u16).to_be_bytes()); - exts.extend_from_slice(&((1 + alpn_proto.len()) as u16).to_be_bytes()); - exts.push(alpn_proto.len() as u8); - exts.extend_from_slice(alpn_proto); + // ALPN + let mut alpn_list = Vec::new(); + for proto in profile_alpn(profile) { + alpn_list.push(proto.len() as u8); + alpn_list.extend_from_slice(proto); + } + if !alpn_list.is_empty() { + exts.extend_from_slice(&0x0010u16.to_be_bytes()); + exts.extend_from_slice(&((2 + alpn_list.len()) as u16).to_be_bytes()); + exts.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes()); + exts.extend_from_slice(&alpn_list); + } + + if grease_enabled { + let grease = grease_value(rng, deterministic, &format!("ext:{sni}")); + exts.extend_from_slice(&grease.to_be_bytes()); + exts.extend_from_slice(&0u16.to_be_bytes()); + } // padding to reduce recognizability and keep length ~500 bytes - const TARGET_EXT_LEN: usize = 180; - if exts.len() < TARGET_EXT_LEN { - let remaining = TARGET_EXT_LEN - exts.len(); + let target_ext_len = profile_padding_target(profile); + if exts.len() < target_ext_len { + let remaining = target_ext_len - exts.len(); if remaining > 4 { let pad_len = remaining - 4; // minus type+len exts.extend_from_slice(&0x0015u16.to_be_bytes()); // padding extension @@ -402,27 +739,41 @@ async fn connect_tcp_with_upstream( connect_timeout: Duration, upstream: Option>, scope: Option<&str>, + strict_route: bool, ) -> Result { if let Some(manager) = upstream { - if let Some(addr) = resolve_socket_addr(host, port) { - match manager.connect(addr, None, scope).await { - Ok(stream) => return Ok(stream), - Err(e) => { - warn!( - host = %host, - port = port, - scope = ?scope, - error = %e, - "Upstream connect failed, using direct connect" - ); - } - } - } else if let Ok(mut addrs) = tokio::net::lookup_host((host, port)).await - && let Some(addr) = addrs.find(|a| a.is_ipv4()) - { + let resolved = if let Some(addr) = resolve_socket_addr(host, port) { + Some(addr) + } else { + match tokio::net::lookup_host((host, port)).await { + Ok(mut addrs) => addrs.find(|a| a.is_ipv4()), + Err(e) => { + if strict_route { + return Err(anyhow!( + "upstream route DNS resolution failed for {host}:{port}: {e}" + )); + } + warn!( + host = %host, + port = port, + scope = ?scope, + error = %e, + "Upstream DNS resolution failed, using direct connect" + ); + None + } + } + }; + + if let Some(addr) = resolved { match manager.connect(addr, None, scope).await { Ok(stream) => return Ok(stream), Err(e) => { + if strict_route { + return Err(anyhow!( + "upstream route connect failed for {host}:{port}: {e}" + )); + } warn!( host = %host, port = port, @@ -432,6 +783,10 @@ async fn connect_tcp_with_upstream( ); } } + } else if strict_route { + return Err(anyhow!( + "upstream route resolution produced no usable address for {host}:{port}" + )); } } Ok(UpstreamStream::Tcp( @@ -471,12 +826,15 @@ async fn fetch_via_raw_tls_stream( sni: &str, connect_timeout: Duration, proxy_protocol: u8, + profile: TlsFetchProfile, + grease_enabled: bool, + deterministic: bool, ) -> Result where S: AsyncRead + AsyncWrite + Unpin, { let rng = SecureRandom::new(); - let client_hello = build_client_hello(sni, &rng); + let client_hello = build_client_hello(sni, &rng, profile, grease_enabled, deterministic); timeout(connect_timeout, async { if proxy_protocol > 0 { let header = match proxy_protocol { @@ -550,6 +908,10 @@ async fn fetch_via_raw_tls( scope: Option<&str>, proxy_protocol: u8, unix_sock: Option<&str>, + strict_route: bool, + profile: TlsFetchProfile, + grease_enabled: bool, + deterministic: bool, ) -> Result { #[cfg(unix)] if let Some(sock_path) = unix_sock { @@ -560,8 +922,16 @@ async fn fetch_via_raw_tls( sock = %sock_path, "Raw TLS fetch using mask unix socket" ); - return fetch_via_raw_tls_stream(stream, sni, connect_timeout, proxy_protocol) - .await; + return fetch_via_raw_tls_stream( + stream, + sni, + connect_timeout, + proxy_protocol, + profile, + grease_enabled, + deterministic, + ) + .await; } Ok(Err(e)) => { warn!( @@ -584,8 +954,19 @@ async fn fetch_via_raw_tls( #[cfg(not(unix))] let _ = unix_sock; - let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope).await?; - fetch_via_raw_tls_stream(stream, sni, connect_timeout, proxy_protocol).await + let stream = + connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route) + .await?; + fetch_via_raw_tls_stream( + stream, + sni, + connect_timeout, + proxy_protocol, + profile, + grease_enabled, + deterministic, + ) + .await } async fn fetch_via_rustls_stream( @@ -691,6 +1072,7 @@ async fn fetch_via_rustls( scope: Option<&str>, proxy_protocol: u8, unix_sock: Option<&str>, + strict_route: bool, ) -> Result { #[cfg(unix)] if let Some(sock_path) = unix_sock { @@ -724,16 +1106,153 @@ async fn fetch_via_rustls( #[cfg(not(unix))] let _ = unix_sock; - let stream = connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope).await?; + let stream = + connect_tcp_with_upstream(host, port, connect_timeout, upstream, scope, strict_route) + .await?; fetch_via_rustls_stream(stream, host, sni, proxy_protocol).await } -/// Fetch real TLS metadata for the given SNI. -/// -/// Strategy: -/// 1) Probe raw TLS for realistic ServerHello and ApplicationData record sizes. -/// 2) Fetch certificate chain via rustls to build cert payload. -/// 3) Merge both when possible; otherwise auto-fallback to whichever succeeded. +/// Fetch real TLS metadata with an adaptive multi-profile strategy. +pub async fn fetch_real_tls_with_strategy( + host: &str, + port: u16, + sni: &str, + strategy: &TlsFetchStrategy, + upstream: Option>, + scope: Option<&str>, + proxy_protocol: u8, + unix_sock: Option<&str>, +) -> Result { + let attempt_timeout = strategy.attempt_timeout.max(Duration::from_millis(1)); + let total_budget = strategy.total_budget.max(Duration::from_millis(1)); + let started_at = Instant::now(); + let cache_key = profile_cache_key( + host, + port, + sni, + upstream.as_ref(), + scope, + proxy_protocol, + unix_sock, + ); + let profiles = order_profiles(strategy, Some(&cache_key), started_at); + + let mut raw_result = None; + let mut raw_last_error: Option = None; + let mut raw_last_error_kind = FetchErrorKind::Other; + let mut selected_profile = None; + + for profile in profiles { + let elapsed = started_at.elapsed(); + if elapsed >= total_budget { + break; + } + let timeout_for_attempt = attempt_timeout.min(total_budget - elapsed); + + match fetch_via_raw_tls( + host, + port, + sni, + timeout_for_attempt, + upstream.clone(), + scope, + proxy_protocol, + unix_sock, + strategy.strict_route, + profile, + strategy.grease_enabled, + strategy.deterministic, + ) + .await + { + Ok(res) => { + selected_profile = Some(profile); + raw_result = Some(res); + break; + } + Err(err) => { + let kind = classify_fetch_error(&err); + warn!( + sni = %sni, + profile = profile.as_str(), + error_kind = ?kind, + error = %err, + "Raw TLS fetch attempt failed" + ); + raw_last_error_kind = kind; + raw_last_error = Some(err); + if strategy.strict_route && matches!(kind, FetchErrorKind::Route) { + break; + } + } + } + } + + if let Some(profile) = selected_profile { + remember_profile_success(strategy, Some(cache_key), profile, Instant::now()); + } + + if raw_result.is_none() + && strategy.strict_route + && matches!(raw_last_error_kind, FetchErrorKind::Route) + { + if let Some(err) = raw_last_error { + return Err(err); + } + return Err(anyhow!("TLS fetch strict-route failure")); + } + + let elapsed = started_at.elapsed(); + if elapsed >= total_budget { + return match raw_result { + Some(raw) => Ok(raw), + None => { + Err(raw_last_error.unwrap_or_else(|| anyhow!("TLS fetch total budget exhausted"))) + } + }; + } + + let rustls_timeout = attempt_timeout.min(total_budget - elapsed); + let rustls_result = fetch_via_rustls( + host, + port, + sni, + rustls_timeout, + upstream, + scope, + proxy_protocol, + unix_sock, + strategy.strict_route, + ) + .await; + + match rustls_result { + Ok(rustls) => { + if let Some(mut raw) = raw_result { + raw.cert_info = rustls.cert_info; + raw.cert_payload = rustls.cert_payload; + raw.behavior_profile.source = TlsProfileSource::Merged; + debug!(sni = %sni, "Fetched TLS metadata via adaptive raw probe + rustls cert chain"); + Ok(raw) + } else { + Ok(rustls) + } + } + Err(err) => { + if let Some(raw) = raw_result { + warn!(sni = %sni, error = %err, "Rustls cert fetch failed, using raw TLS metadata only"); + Ok(raw) + } else if let Some(raw_err) = raw_last_error { + Err(anyhow!("TLS fetch failed (raw: {raw_err}; rustls: {err})")) + } else { + Err(err) + } + } + } +} + +/// Fetch real TLS metadata for the given SNI using a single-attempt compatibility strategy. +#[allow(dead_code)] pub async fn fetch_real_tls( host: &str, port: u16, @@ -744,62 +1263,30 @@ pub async fn fetch_real_tls( proxy_protocol: u8, unix_sock: Option<&str>, ) -> Result { - let raw_result = match fetch_via_raw_tls( + let strategy = TlsFetchStrategy::single_attempt(connect_timeout); + fetch_real_tls_with_strategy( host, port, sni, - connect_timeout, - upstream.clone(), - scope, - proxy_protocol, - unix_sock, - ) - .await - { - Ok(res) => Some(res), - Err(e) => { - warn!(sni = %sni, error = %e, "Raw TLS fetch failed"); - None - } - }; - - match fetch_via_rustls( - host, - port, - sni, - connect_timeout, + &strategy, upstream, scope, proxy_protocol, unix_sock, ) .await - { - Ok(rustls_result) => { - if let Some(mut raw) = raw_result { - raw.cert_info = rustls_result.cert_info; - raw.cert_payload = rustls_result.cert_payload; - raw.behavior_profile.source = TlsProfileSource::Merged; - debug!(sni = %sni, "Fetched TLS metadata via raw probe + rustls cert chain"); - Ok(raw) - } else { - Ok(rustls_result) - } - } - Err(e) => { - if let Some(raw) = raw_result { - warn!(sni = %sni, error = %e, "Rustls cert fetch failed, using raw TLS metadata only"); - Ok(raw) - } else { - Err(e) - } - } - } } #[cfg(test)] mod tests { - use super::{derive_behavior_profile, encode_tls13_certificate_message}; + use std::time::{Duration, Instant}; + + use super::{ + ProfileCacheValue, TlsFetchStrategy, build_client_hello, derive_behavior_profile, + encode_tls13_certificate_message, order_profiles, profile_cache, profile_cache_key, + }; + use crate::config::TlsFetchProfile; + use crate::crypto::SecureRandom; use crate::protocol::constants::{ TLS_RECORD_APPLICATION, TLS_RECORD_CHANGE_CIPHER, TLS_RECORD_HANDSHAKE, }; @@ -812,8 +1299,8 @@ mod tests { #[test] fn test_encode_tls13_certificate_message_single_cert() { let cert = vec![0x30, 0x03, 0x02, 0x01, 0x01]; - let message = encode_tls13_certificate_message(std::slice::from_ref(&cert)) - .expect("message"); + let message = + encode_tls13_certificate_message(std::slice::from_ref(&cert)).expect("message"); assert_eq!(message[0], 0x0b); assert_eq!(read_u24(&message[1..4]), message.len() - 4); @@ -848,4 +1335,93 @@ mod tests { assert_eq!(profile.ticket_record_sizes, vec![220, 180]); assert_eq!(profile.source, TlsProfileSource::Raw); } + + #[test] + fn test_order_profiles_prioritizes_fresh_cached_winner() { + let strategy = TlsFetchStrategy { + profiles: vec![ + TlsFetchProfile::ModernChromeLike, + TlsFetchProfile::CompatTls12, + TlsFetchProfile::LegacyMinimal, + ], + strict_route: true, + attempt_timeout: Duration::from_secs(1), + total_budget: Duration::from_secs(2), + grease_enabled: false, + deterministic: false, + profile_cache_ttl: Duration::from_secs(60), + }; + let cache_key = profile_cache_key( + "mask.example", + 443, + "tls.example", + None, + Some("tls"), + 0, + None, + ); + profile_cache().remove(&cache_key); + profile_cache().insert( + cache_key.clone(), + ProfileCacheValue { + profile: TlsFetchProfile::CompatTls12, + updated_at: Instant::now(), + }, + ); + + let ordered = order_profiles(&strategy, Some(&cache_key), Instant::now()); + assert_eq!(ordered[0], TlsFetchProfile::CompatTls12); + profile_cache().remove(&cache_key); + } + + #[test] + fn test_order_profiles_drops_expired_cached_winner() { + let strategy = TlsFetchStrategy { + profiles: vec![ + TlsFetchProfile::ModernFirefoxLike, + TlsFetchProfile::CompatTls12, + ], + strict_route: true, + attempt_timeout: Duration::from_secs(1), + total_budget: Duration::from_secs(2), + grease_enabled: false, + deterministic: false, + profile_cache_ttl: Duration::from_secs(5), + }; + let cache_key = + profile_cache_key("mask2.example", 443, "tls2.example", None, None, 0, None); + profile_cache().remove(&cache_key); + profile_cache().insert( + cache_key.clone(), + ProfileCacheValue { + profile: TlsFetchProfile::CompatTls12, + updated_at: Instant::now() - Duration::from_secs(6), + }, + ); + + let ordered = order_profiles(&strategy, Some(&cache_key), Instant::now()); + assert_eq!(ordered[0], TlsFetchProfile::ModernFirefoxLike); + assert!(profile_cache().get(&cache_key).is_none()); + } + + #[test] + fn test_deterministic_client_hello_is_stable() { + let rng = SecureRandom::new(); + let first = build_client_hello( + "stable.example", + &rng, + TlsFetchProfile::ModernChromeLike, + true, + true, + ); + let second = build_client_hello( + "stable.example", + &rng, + TlsFetchProfile::ModernChromeLike, + true, + true, + ); + + assert_eq!(first, second); + } } diff --git a/src/transport/middle_proxy/config_updater.rs b/src/transport/middle_proxy/config_updater.rs index 8e5a701..ba90c1a 100644 --- a/src/transport/middle_proxy/config_updater.rs +++ b/src/transport/middle_proxy/config_updater.rs @@ -11,17 +11,19 @@ use tracing::{debug, info, warn}; use crate::config::ProxyConfig; use crate::error::Result; +use crate::transport::UpstreamManager; use super::MePool; +use super::http_fetch::https_get; use super::rotation::{MeReinitTrigger, enqueue_reinit_trigger}; -use super::secret::download_proxy_secret_with_max_len; +use super::secret::download_proxy_secret_with_max_len_via_upstream; use super::selftest::record_timeskew_sample; use std::time::SystemTime; -async fn retry_fetch(url: &str) -> Option { +async fn retry_fetch(url: &str, upstream: Option>) -> Option { let delays = [1u64, 5, 15]; for (i, d) in delays.iter().enumerate() { - match fetch_proxy_config(url).await { + match fetch_proxy_config_via_upstream(url, upstream.clone()).await { Ok(cfg) => return Some(cfg), Err(e) => { if i == delays.len() - 1 { @@ -95,14 +97,19 @@ pub async fn save_proxy_config_cache(path: &str, raw_text: &str) -> Result<()> { Ok(()) } +#[allow(dead_code)] pub async fn fetch_proxy_config_with_raw(url: &str) -> Result<(ProxyConfigData, String)> { - let resp = reqwest::get(url).await.map_err(|e| { - crate::error::ProxyError::Proxy(format!("fetch_proxy_config GET failed: {e}")) - })?; - let http_status = resp.status().as_u16(); + fetch_proxy_config_with_raw_via_upstream(url, None).await +} - if let Some(date) = resp.headers().get(reqwest::header::DATE) - && let Ok(date_str) = date.to_str() +pub async fn fetch_proxy_config_with_raw_via_upstream( + url: &str, + upstream: Option>, +) -> Result<(ProxyConfigData, String)> { + let resp = https_get(url, upstream).await?; + let http_status = resp.status; + + if let Some(date_str) = resp.date_header.as_deref() && let Ok(server_time) = httpdate::parse_http_date(date_str) && let Ok(skew) = SystemTime::now() .duration_since(server_time) @@ -123,9 +130,7 @@ pub async fn fetch_proxy_config_with_raw(url: &str) -> Result<(ProxyConfigData, } } - let text = resp.text().await.map_err(|e| { - crate::error::ProxyError::Proxy(format!("fetch_proxy_config read failed: {e}")) - })?; + let text = String::from_utf8_lossy(&resp.body).into_owned(); let parsed = parse_proxy_config_text(&text, http_status); Ok((parsed, text)) } @@ -260,8 +265,16 @@ fn parse_proxy_line(line: &str) -> Option<(i32, IpAddr, u16)> { Some((dc, ip, port)) } +#[allow(dead_code)] pub async fn fetch_proxy_config(url: &str) -> Result { - fetch_proxy_config_with_raw(url) + fetch_proxy_config_via_upstream(url, None).await +} + +pub async fn fetch_proxy_config_via_upstream( + url: &str, + upstream: Option>, +) -> Result { + fetch_proxy_config_with_raw_via_upstream(url, upstream) .await .map(|(parsed, _raw)| parsed) } @@ -300,6 +313,7 @@ async fn run_update_cycle( state: &mut UpdaterState, reinit_tx: &mpsc::Sender, ) { + let upstream = pool.upstream.clone(); pool.update_runtime_reinit_policy( cfg.general.hardswap, cfg.general.me_pool_drain_ttl_secs, @@ -354,7 +368,7 @@ async fn run_update_cycle( let mut maps_changed = false; let mut ready_v4: Option<(ProxyConfigData, u64)> = None; - let cfg_v4 = retry_fetch("https://core.telegram.org/getProxyConfig").await; + let cfg_v4 = retry_fetch("https://core.telegram.org/getProxyConfig", upstream.clone()).await; if let Some(cfg_v4) = cfg_v4 && snapshot_passes_guards(cfg, &cfg_v4, "getProxyConfig") { @@ -378,7 +392,11 @@ async fn run_update_cycle( } let mut ready_v6: Option<(ProxyConfigData, u64)> = None; - let cfg_v6 = retry_fetch("https://core.telegram.org/getProxyConfigV6").await; + let cfg_v6 = retry_fetch( + "https://core.telegram.org/getProxyConfigV6", + upstream.clone(), + ) + .await; if let Some(cfg_v6) = cfg_v6 && snapshot_passes_guards(cfg, &cfg_v6, "getProxyConfigV6") { @@ -456,7 +474,12 @@ async fn run_update_cycle( pool.reset_stun_state(); if cfg.general.proxy_secret_rotate_runtime { - match download_proxy_secret_with_max_len(cfg.general.proxy_secret_len_max).await { + match download_proxy_secret_with_max_len_via_upstream( + cfg.general.proxy_secret_len_max, + upstream, + ) + .await + { Ok(secret) => { let secret_hash = hash_secret(&secret); let stable_hits = state.secret.observe(secret_hash); diff --git a/src/transport/middle_proxy/http_fetch.rs b/src/transport/middle_proxy/http_fetch.rs new file mode 100644 index 0000000..5be601e --- /dev/null +++ b/src/transport/middle_proxy/http_fetch.rs @@ -0,0 +1,183 @@ +use std::sync::Arc; +use std::time::Duration; + +use http_body_util::{BodyExt, Empty}; +use hyper::header::{CONNECTION, DATE, HOST, USER_AGENT}; +use hyper::{Method, Request}; +use hyper_util::rt::TokioIo; +use rustls::pki_types::ServerName; +use tokio::net::TcpStream; +use tokio::time::timeout; +use tokio_rustls::TlsConnector; +use tracing::debug; + +use crate::error::{ProxyError, Result}; +use crate::network::dns_overrides::resolve_socket_addr; +use crate::transport::{UpstreamManager, UpstreamStream}; + +const HTTP_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); +const HTTP_REQUEST_TIMEOUT: Duration = Duration::from_secs(15); + +pub(crate) struct HttpsGetResponse { + pub(crate) status: u16, + pub(crate) date_header: Option, + pub(crate) body: Vec, +} + +fn build_tls_client_config() -> Arc { + let mut root_store = rustls::RootCertStore::empty(); + root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let provider = rustls::crypto::ring::default_provider(); + let config = rustls::ClientConfig::builder_with_provider(Arc::new(provider)) + .with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12]) + .expect("HTTPS fetch rustls protocol versions must be valid") + .with_root_certificates(root_store) + .with_no_client_auth(); + Arc::new(config) +} + +fn extract_host_port_path(url: &str) -> Result<(String, u16, String)> { + let parsed = + url::Url::parse(url).map_err(|e| ProxyError::Proxy(format!("invalid URL '{url}': {e}")))?; + if parsed.scheme() != "https" { + return Err(ProxyError::Proxy(format!( + "unsupported URL scheme '{}': only https is supported", + parsed.scheme() + ))); + } + + let host = parsed + .host_str() + .ok_or_else(|| ProxyError::Proxy(format!("URL has no host: {url}")))? + .to_string(); + let port = parsed + .port_or_known_default() + .ok_or_else(|| ProxyError::Proxy(format!("URL has no known port: {url}")))?; + + let mut path = parsed.path().to_string(); + if path.is_empty() { + path.push('/'); + } + if let Some(query) = parsed.query() { + path.push('?'); + path.push_str(query); + } + + Ok((host, port, path)) +} + +async fn resolve_target_addr(host: &str, port: u16) -> Result { + if let Some(addr) = resolve_socket_addr(host, port) { + return Ok(addr); + } + + let addrs: Vec = tokio::net::lookup_host((host, port)) + .await + .map_err(|e| ProxyError::Proxy(format!("DNS resolve failed for {host}:{port}: {e}")))? + .collect(); + + if let Some(addr) = addrs.iter().copied().find(|addr| addr.is_ipv4()) { + return Ok(addr); + } + + addrs + .first() + .copied() + .ok_or_else(|| ProxyError::Proxy(format!("DNS returned no addresses for {host}:{port}"))) +} + +async fn connect_https_transport( + host: &str, + port: u16, + upstream: Option>, +) -> Result { + if let Some(manager) = upstream { + let target = resolve_target_addr(host, port).await?; + return timeout(HTTP_CONNECT_TIMEOUT, manager.connect(target, None, None)) + .await + .map_err(|_| ProxyError::Proxy(format!("upstream connect timeout for {host}:{port}")))? + .map_err(|e| { + ProxyError::Proxy(format!("upstream connect failed for {host}:{port}: {e}")) + }); + } + + if let Some(addr) = resolve_socket_addr(host, port) { + let stream = timeout(HTTP_CONNECT_TIMEOUT, TcpStream::connect(addr)) + .await + .map_err(|_| ProxyError::Proxy(format!("connect timeout for {host}:{port}")))? + .map_err(|e| ProxyError::Proxy(format!("connect failed for {host}:{port}: {e}")))?; + return Ok(UpstreamStream::Tcp(stream)); + } + + let stream = timeout(HTTP_CONNECT_TIMEOUT, TcpStream::connect((host, port))) + .await + .map_err(|_| ProxyError::Proxy(format!("connect timeout for {host}:{port}")))? + .map_err(|e| ProxyError::Proxy(format!("connect failed for {host}:{port}: {e}")))?; + Ok(UpstreamStream::Tcp(stream)) +} + +pub(crate) async fn https_get( + url: &str, + upstream: Option>, +) -> Result { + let (host, port, path_and_query) = extract_host_port_path(url)?; + let stream = connect_https_transport(&host, port, upstream).await?; + + let server_name = ServerName::try_from(host.clone()) + .map_err(|_| ProxyError::Proxy(format!("invalid TLS server name: {host}")))?; + let connector = TlsConnector::from(build_tls_client_config()); + let tls_stream = timeout(HTTP_REQUEST_TIMEOUT, connector.connect(server_name, stream)) + .await + .map_err(|_| ProxyError::Proxy(format!("TLS handshake timeout for {host}:{port}")))? + .map_err(|e| ProxyError::Proxy(format!("TLS handshake failed for {host}:{port}: {e}")))?; + + let (mut sender, connection) = hyper::client::conn::http1::handshake(TokioIo::new(tls_stream)) + .await + .map_err(|e| ProxyError::Proxy(format!("HTTP handshake failed for {host}:{port}: {e}")))?; + + tokio::spawn(async move { + if let Err(e) = connection.await { + debug!(error = %e, "HTTPS fetch connection task failed"); + } + }); + + let host_header = if port == 443 { + host.clone() + } else { + format!("{host}:{port}") + }; + + let request = Request::builder() + .method(Method::GET) + .uri(path_and_query) + .header(HOST, host_header) + .header(USER_AGENT, "telemt-middle-proxy/1") + .header(CONNECTION, "close") + .body(Empty::::new()) + .map_err(|e| ProxyError::Proxy(format!("build HTTP request failed for {url}: {e}")))?; + + let response = timeout(HTTP_REQUEST_TIMEOUT, sender.send_request(request)) + .await + .map_err(|_| ProxyError::Proxy(format!("HTTP request timeout for {url}")))? + .map_err(|e| ProxyError::Proxy(format!("HTTP request failed for {url}: {e}")))?; + + let status = response.status().as_u16(); + let date_header = response + .headers() + .get(DATE) + .and_then(|value| value.to_str().ok()) + .map(|value| value.to_string()); + + let body = timeout(HTTP_REQUEST_TIMEOUT, response.into_body().collect()) + .await + .map_err(|_| ProxyError::Proxy(format!("HTTP body read timeout for {url}")))? + .map_err(|e| ProxyError::Proxy(format!("HTTP body read failed for {url}: {e}")))? + .to_bytes() + .to_vec(); + + Ok(HttpsGetResponse { + status, + date_header, + body, + }) +} diff --git a/src/transport/middle_proxy/mod.rs b/src/transport/middle_proxy/mod.rs index 5536869..6dfbee6 100644 --- a/src/transport/middle_proxy/mod.rs +++ b/src/transport/middle_proxy/mod.rs @@ -13,6 +13,7 @@ mod health_integration_tests; #[cfg(test)] #[path = "tests/health_regression_tests.rs"] mod health_regression_tests; +mod http_fetch; mod ping; mod pool; mod pool_config; @@ -44,7 +45,8 @@ use bytes::Bytes; #[allow(unused_imports)] pub use config_updater::{ - ProxyConfigData, fetch_proxy_config, fetch_proxy_config_with_raw, load_proxy_config_cache, + ProxyConfigData, fetch_proxy_config, fetch_proxy_config_via_upstream, + fetch_proxy_config_with_raw, fetch_proxy_config_with_raw_via_upstream, load_proxy_config_cache, me_config_updater, save_proxy_config_cache, }; pub use health::{me_drain_timeout_enforcer, me_health_monitor, me_zombie_writer_watchdog}; @@ -57,7 +59,8 @@ pub use pool::MePool; pub use pool_nat::{detect_public_ip, stun_probe}; pub use registry::ConnRegistry; pub use rotation::{MeReinitTrigger, me_reinit_scheduler, me_rotation_task}; -pub use secret::fetch_proxy_secret; +#[allow(unused_imports)] +pub use secret::{fetch_proxy_secret, fetch_proxy_secret_with_upstream}; pub(crate) use selftest::{bnd_snapshot, timeskew_snapshot, upstream_bnd_snapshots}; pub use wire::proto_flags_for_tag; diff --git a/src/transport/middle_proxy/pool_status.rs b/src/transport/middle_proxy/pool_status.rs index 918ccd4..1ef59e1 100644 --- a/src/transport/middle_proxy/pool_status.rs +++ b/src/transport/middle_proxy/pool_status.rs @@ -293,9 +293,7 @@ impl MePool { WriterContour::Draining => "draining", }; - if !draining - && let Some(dc_idx) = dc - { + if !draining && let Some(dc_idx) = dc { *live_writers_by_dc_endpoint .entry((dc_idx, endpoint)) .or_insert(0) += 1; diff --git a/src/transport/middle_proxy/secret.rs b/src/transport/middle_proxy/secret.rs index 504270a..a167773 100644 --- a/src/transport/middle_proxy/secret.rs +++ b/src/transport/middle_proxy/secret.rs @@ -1,9 +1,12 @@ use httpdate; +use std::sync::Arc; use std::time::SystemTime; use tracing::{debug, info, warn}; +use super::http_fetch::https_get; use super::selftest::record_timeskew_sample; use crate::error::{ProxyError, Result}; +use crate::transport::UpstreamManager; pub const PROXY_SECRET_MIN_LEN: usize = 32; @@ -33,11 +36,21 @@ pub(super) fn validate_proxy_secret_len(data_len: usize, max_len: usize) -> Resu } /// Fetch Telegram proxy-secret binary. +#[allow(dead_code)] pub async fn fetch_proxy_secret(cache_path: Option<&str>, max_len: usize) -> Result> { + fetch_proxy_secret_with_upstream(cache_path, max_len, None).await +} + +/// Fetch Telegram proxy-secret binary, optionally through upstream routing. +pub async fn fetch_proxy_secret_with_upstream( + cache_path: Option<&str>, + max_len: usize, + upstream: Option>, +) -> Result> { let cache = cache_path.unwrap_or("proxy-secret"); // 1) Try fresh download first. - match download_proxy_secret_with_max_len(max_len).await { + match download_proxy_secret_with_max_len_via_upstream(max_len, upstream).await { Ok(data) => { if let Err(e) = tokio::fs::write(cache, &data).await { warn!(error = %e, "Failed to cache proxy-secret (non-fatal)"); @@ -76,20 +89,25 @@ pub async fn fetch_proxy_secret(cache_path: Option<&str>, max_len: usize) -> Res } } +#[allow(dead_code)] pub async fn download_proxy_secret_with_max_len(max_len: usize) -> Result> { - let resp = reqwest::get("https://core.telegram.org/getProxySecret") - .await - .map_err(|e| ProxyError::Proxy(format!("Failed to download proxy-secret: {e}")))?; + download_proxy_secret_with_max_len_via_upstream(max_len, None).await +} - if !resp.status().is_success() { +pub async fn download_proxy_secret_with_max_len_via_upstream( + max_len: usize, + upstream: Option>, +) -> Result> { + let resp = https_get("https://core.telegram.org/getProxySecret", upstream).await?; + + if !(200..=299).contains(&resp.status) { return Err(ProxyError::Proxy(format!( "proxy-secret download HTTP {}", - resp.status() + resp.status ))); } - if let Some(date) = resp.headers().get(reqwest::header::DATE) - && let Ok(date_str) = date.to_str() + if let Some(date_str) = resp.date_header.as_deref() && let Ok(server_time) = httpdate::parse_http_date(date_str) && let Ok(skew) = SystemTime::now() .duration_since(server_time) @@ -110,11 +128,7 @@ pub async fn download_proxy_secret_with_max_len(max_len: usize) -> Result>)> = { let pools = self.pools.read(); - pools.iter().map(|(addr, pool)| (*addr, Arc::clone(pool))).collect() + pools + .iter() + .map(|(addr, pool)| (*addr, Arc::clone(pool))) + .collect() }; for (addr, pool) in pools_snapshot {