Compare commits

...

47 Commits

Author SHA1 Message Date
Alexey 22097f8c7c Update Dockerfile 2026-03-24 11:46:49 +03:00
Alexey 1450af60a0 Update Dockerfile 2026-03-24 11:41:53 +03:00
Alexey f1cc8d65f2 Update release.yml 2026-03-24 11:12:03 +03:00
Alexey ec7e808daf Update release.yml 2026-03-24 11:05:50 +03:00
Alexey e4b7e23e76 New TLS-Fetcher + TLS SNI Validator + Upstream-driver getProxySecret/Config + Workflow Tunings + Redesign Quotas on Atomics + Tests Swap: merge pull request #569 from telemt/flow
New TLS-Fetcher + TLS SNI Validator + Upstream-driver getProxySecret/Config + Workflow Tunings + Redesign Quotas on Atomics + Tests Swap
2026-03-24 10:56:15 +03:00
Alexey 8b92b80b4a Rustks CryptoProvider fixes + Rustfmt 2026-03-24 10:33:06 +03:00
Alexey f7868aa00f Advanced TLS Fetcher
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-24 09:58:24 +03:00
Alexey 655a08fa5c TLS Fetcher fixes 2026-03-23 23:12:50 +03:00
Alexey 8bc432db49 Rustfmt 2026-03-23 23:00:46 +03:00
Alexey a40d6929e5 Upstream-driver getProxyConfig and getProxyConfig
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 22:41:17 +03:00
Alexey 8db566dbe9 TLS Validator
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 21:58:39 +03:00
Alexey bb71de0230 Missing proxy_protocol_trusted_cidrs as trust-
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 20:54:58 +03:00
Alexey 62a258f8e3 Update test.yml 2026-03-23 20:49:17 +03:00
Alexey c868eaae74 Update test.yml 2026-03-23 20:36:25 +03:00
Alexey 8e1860f912 Update test.yml 2026-03-23 20:34:59 +03:00
Alexey 814bef9d99 Rustfmt 2026-03-23 20:32:55 +03:00
Alexey 3ceda15073 Update relay_quota_model_adversarial_tests.rs
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 20:18:18 +03:00
Alexey a3a6ea2880 Update relay_quota_overflow_regression_tests.rs
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 20:06:11 +03:00
Alexey 24156b5067 Workflow for Docker and correct binary naming
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 17:42:18 +03:00
Alexey a1dfa5b11d Merge branch 'flow' of https://github.com/telemt/telemt into flow 2026-03-23 17:05:26 +03:00
Alexey 800356c751 Rewiring tests
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 17:04:47 +03:00
Alexey 1546b012a6 Merge pull request #568 from avbor/main
DOCS: Update VPS_DOUBLE_HOP.*.md - AmneziaWG 2.0
2026-03-23 16:49:57 +03:00
Alexey e6b77af931 Workflows Swap
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 16:49:23 +03:00
Alexey 8cfaab9320 Fixes in tests
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 16:39:49 +03:00
Alexey 2d69b9d0ae New wave of tests
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 16:39:23 +03:00
Alexander 41c2b4de65 Update VPS_DOUBLE_HOP.en.md
Added S3-S4 parameters for AWG and update AWG generator.
2026-03-23 16:30:37 +03:00
Alexander 0a5e8a09fd Update VPS_DOUBLE_HOP.ru.md
Added S3-S4 parameters for AWG and update AWG generator.
2026-03-23 16:29:08 +03:00
Alexey 2f9fddfa6f Old Test Deletion 2026-03-23 16:21:53 +03:00
Alexey 6f4356f72a Redesign Quotas on Atomics 2026-03-23 15:53:44 +03:00
Alexey 0c3c9009a9 Merge pull request #538 from DavidOsipov/flow
Cross-mode Quota Locks, Masking Prefetch & Tiny-Frame Debt Protection
2026-03-23 11:35:57 +03:00
Alexey 0475844701 Merge branch 'flow' into flow 2026-03-23 11:35:44 +03:00
David Osipov 1abf9bd05c Refactor CI workflows: rename build job and streamline stress testing setup 2026-03-23 12:27:57 +04:00
David Osipov 6f17d4d231 Add comprehensive security tests for quota management and relay functionality
- Introduced `relay_dual_lock_race_harness_security_tests.rs` to validate user liveness during lock hold and release cycles.
- Added `relay_quota_extended_attack_surface_security_tests.rs` to cover various quota scenarios including positive, negative, edge cases, and adversarial conditions.
- Implemented `relay_quota_lock_eviction_lifecycle_tdd_tests.rs` to ensure proper eviction of stale entries and lifecycle management of quota locks.
- Created `relay_quota_lock_eviction_stress_security_tests.rs` to stress test the eviction mechanism under high churn conditions.
- Enhanced `relay_quota_lock_pressure_adversarial_tests.rs` to verify reclaiming of unreferenced entries after explicit eviction.
- Developed `relay_quota_retry_allocation_latency_security_tests.rs` to benchmark and validate latency and allocation behavior under contention.
2026-03-23 12:04:41 +04:00
Alexey bf30e93284 Merge pull request #545 from Dimasssss/patch-1
Update CONFIG_PARAMS.en.md and FAQ
2026-03-23 11:00:08 +03:00
David Osipov 91be148b72 Security hardening, concurrency fixes, and expanded test coverage
This commit introduces a comprehensive set of improvements to enhance
the security, reliability, and configurability of the proxy server,
specifically targeting adversarial resilience and high-load concurrency.

Security & Cryptography:
- Zeroize MTProto cryptographic key material (`dec_key`, `enc_key`)
  immediately after use to prevent memory leakage on early returns.
- Move TLS handshake replay tracking after full policy/ALPN validation
  to prevent cache poisoning by unauthenticated probes.
- Add `proxy_protocol_trusted_cidrs` configuration to restrict PROXY
  protocol headers to trusted networks, rejecting spoofed IPs.

Adversarial Resilience & DoS Mitigation:
- Implement "Tiny Frame Debt" tracking in the middle-relay to prevent
  CPU exhaustion from malicious 0-byte or 1-byte frame floods.
- Add `mask_relay_max_bytes` to strictly bound unauthenticated fallback
  connections, preventing the proxy from being abused as an open relay.
- Add a 5ms prefetch window (`mask_classifier_prefetch_timeout_ms`) to
  correctly assemble and classify fragmented HTTP/1.1 and HTTP/2 probes
  (e.g., `PRI * HTTP/2.0`) before routing them to masking heuristics.
- Prevent recursive masking loops (FD exhaustion) by verifying the mask
  target is not the proxy's own listener via local interface enumeration.

Concurrency & Reliability:
- Eliminate executor waker storms during quota lock contention by replacing
  the spin-waker task with inline `Sleep` and exponential backoff.
- Roll back user quota reservations (`rollback_me2c_quota_reservation`)
  if a network write fails, preventing Head-of-Line (HoL) blocking from
  permanently burning data quotas.
- Recover gracefully from idle-registry `Mutex` poisoning instead of
  panicking, ensuring isolated thread failures do not break the proxy.
- Fix `auth_probe_scan_start_offset` modulo logic to ensure bounds safety.

Testing:
- Add extensive adversarial, timing, fuzzing, and invariant test suites
  for both the client and handshake modules.
2026-03-22 23:09:49 +04:00
Alexander e46d2cfc52 Update VPS_DOUBLE_HOP.ru.md
Fix typo
2026-03-22 21:59:20 +03:00
Dimasssss d4cda6d546 Update CONFIG_PARAMS.en.md 2026-03-22 21:56:21 +03:00
Alexey e35d69c61f Merge pull request #544 from avbor/main
DOCS: VPS doube hop manual Ru\En
2026-03-22 21:45:13 +03:00
Dimasssss a353a94175 Update FAQ.en.md 2026-03-22 21:35:39 +03:00
Dimasssss b856250b2c Update FAQ.ru.md 2026-03-22 21:30:17 +03:00
Alexander 97d1476ded Merge branch 'flow' into main 2026-03-22 20:52:58 +03:00
Alexander cde14fc1bf Create VPS_DOUBLE_HOP.en.md
Added VPS double hop with AmneziaWG manual
2026-03-22 20:35:09 +03:00
Alexander 5723d50d0b Create VPS_DOUBLE_HOP.ru.md
Added VPS double hop with AmneziaWG manual
2026-03-22 20:04:14 +03:00
Alexey 3eb384e02a Update middle_relay.rs 2026-03-22 17:53:32 +03:00
Dimasssss c960e0e245 Update CONFIG_PARAMS.en.md 2026-03-22 17:44:52 +03:00
David Osipov 6fc188f0c4 Update src/proxy/tests/handshake_more_clever_tests.rs
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-22 17:08:23 +04:00
David Osipov 5c9fea5850 Update src/proxy/tests/client_security_tests.rs
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-22 17:08:16 +04:00
108 changed files with 10659 additions and 7031 deletions
+39
View File
@@ -0,0 +1,39 @@
name: Build
on:
push:
branches: [ "*" ]
pull_request:
branches: [ "*" ]
env:
CARGO_TERM_COLOR: always
jobs:
build:
name: Build
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install latest stable Rust toolchain
uses: dtolnay/rust-toolchain@stable
- name: Cache cargo registry & build artifacts
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-
- name: Build Release
run: cargo build --release --verbose
+98 -92
View File
@@ -26,6 +26,9 @@ jobs:
name: GNU ${{ matrix.target }}
runs-on: ubuntu-latest
container:
image: rust:slim-bookworm
strategy:
fail-fast: false
matrix:
@@ -47,8 +50,8 @@ jobs:
- name: Install deps
run: |
sudo apt-get update
sudo apt-get install -y \
apt-get update
apt-get install -y \
build-essential \
clang \
lld \
@@ -69,14 +72,10 @@ jobs:
if [ "${{ matrix.target }}" = "aarch64-unknown-linux-gnu" ]; then
export CC=aarch64-linux-gnu-gcc
export CXX=aarch64-linux-gnu-g++
export CC_aarch64_unknown_linux_gnu=aarch64-linux-gnu-gcc
export CXX_aarch64_unknown_linux_gnu=aarch64-linux-gnu-g++
export RUSTFLAGS="-C linker=aarch64-linux-gnu-gcc"
else
export CC=clang
export CXX=clang++
export CC_x86_64_unknown_linux_gnu=clang
export CXX_x86_64_unknown_linux_gnu=clang++
export RUSTFLAGS="-C linker=clang -C link-arg=-fuse-ld=lld"
fi
@@ -85,20 +84,19 @@ jobs:
- name: Package
run: |
mkdir -p dist
BIN=target/${{ matrix.target }}/release/${{ env.BINARY_NAME }}
cp "$BIN" dist/${{ env.BINARY_NAME }}-${{ matrix.target }}
cp target/${{ matrix.target }}/release/${{ env.BINARY_NAME }} dist/telemt
cd dist
tar -czf ${{ matrix.asset }}.tar.gz ${{ env.BINARY_NAME }}-${{ matrix.target }}
tar -czf ${{ matrix.asset }}.tar.gz \
--owner=0 --group=0 --numeric-owner \
telemt
sha256sum ${{ matrix.asset }}.tar.gz > ${{ matrix.asset }}.sha256
- uses: actions/upload-artifact@v4
with:
name: ${{ matrix.asset }}
path: |
dist/${{ matrix.asset }}.tar.gz
dist/${{ matrix.asset }}.sha256
path: dist/*
# ==========================
# MUSL
@@ -106,10 +104,10 @@ jobs:
build-musl:
name: MUSL ${{ matrix.target }}
runs-on: ubuntu-latest
container:
image: rust:slim-bookworm
strategy:
fail-fast: false
matrix:
@@ -118,10 +116,10 @@ jobs:
asset: telemt-x86_64-linux-musl
- target: aarch64-unknown-linux-musl
asset: telemt-aarch64-linux-musl
steps:
- uses: actions/checkout@v4
- name: Install deps
run: |
apt-get update
@@ -129,43 +127,43 @@ jobs:
musl-tools \
pkg-config \
curl
- uses: actions/cache@v4
if: matrix.target == 'aarch64-unknown-linux-musl'
with:
path: ~/.musl-aarch64
key: musl-toolchain-aarch64-v1
- name: Install aarch64 musl toolchain
if: matrix.target == 'aarch64-unknown-linux-musl'
run: |
set -e
TOOLCHAIN_DIR="$HOME/.musl-aarch64"
ARCHIVE="aarch64-linux-musl-cross.tgz"
URL="https://github.com/telemt/telemt/releases/download/toolchains/$ARCHIVE"
if [ -x "$TOOLCHAIN_DIR/bin/aarch64-linux-musl-gcc" ]; then
echo "✅ MUSL toolchain already installed"
echo "✅ MUSL toolchain cached"
else
echo "⬇️ Downloading musl toolchain from Telemt GitHub Releases..."
echo "⬇️ Downloading MUSL toolchain..."
curl -fL \
--retry 5 \
--retry-delay 3 \
--connect-timeout 10 \
--max-time 120 \
-o "$ARCHIVE" "$URL"
mkdir -p "$TOOLCHAIN_DIR"
tar -xzf "$ARCHIVE" --strip-components=1 -C "$TOOLCHAIN_DIR"
fi
echo "$TOOLCHAIN_DIR/bin" >> $GITHUB_PATH
- name: Add rust target
run: rustup target add ${{ matrix.target }}
- uses: actions/cache@v4
with:
path: |
@@ -173,7 +171,7 @@ jobs:
/usr/local/cargo/git
target
key: musl-${{ matrix.target }}-${{ hashFiles('**/Cargo.lock') }}
- name: Build
run: |
if [ "${{ matrix.target }}" = "aarch64-unknown-linux-musl" ]; then
@@ -185,75 +183,25 @@ jobs:
export CC_x86_64_unknown_linux_musl=musl-gcc
export RUSTFLAGS="-C target-feature=+crt-static"
fi
cargo build --release --target ${{ matrix.target }}
- name: Package
run: |
mkdir -p dist
BIN=target/${{ matrix.target }}/release/${{ env.BINARY_NAME }}
cp "$BIN" dist/${{ env.BINARY_NAME }}-${{ matrix.target }}
cp target/${{ matrix.target }}/release/${{ env.BINARY_NAME }} dist/telemt
cd dist
tar -czf ${{ matrix.asset }}.tar.gz ${{ env.BINARY_NAME }}-${{ matrix.target }}
tar -czf ${{ matrix.asset }}.tar.gz \
--owner=0 --group=0 --numeric-owner \
telemt
sha256sum ${{ matrix.asset }}.tar.gz > ${{ matrix.asset }}.sha256
- uses: actions/upload-artifact@v4
with:
name: ${{ matrix.asset }}
path: |
dist/${{ matrix.asset }}.tar.gz
dist/${{ matrix.asset }}.sha256
# ==========================
# Docker
# ==========================
docker:
name: Docker
runs-on: ubuntu-latest
needs: [build-gnu, build-musl]
continue-on-error: true
steps:
- uses: actions/checkout@v4
- uses: actions/download-artifact@v4
with:
path: artifacts
- name: Extract binaries
run: |
mkdir dist
find artifacts -name "*.tar.gz" -exec tar -xzf {} -C dist \;
cp dist/telemt-x86_64-unknown-linux-musl dist/telemt || true
- uses: docker/setup-qemu-action@v3
- uses: docker/setup-buildx-action@v3
- name: Login to GHCR
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract version
id: vars
run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
- name: Build & Push
uses: docker/build-push-action@v6
with:
context: .
push: true
platforms: linux/amd64,linux/arm64
tags: |
ghcr.io/${{ github.repository }}:${{ steps.vars.outputs.VERSION }}
ghcr.io/${{ github.repository }}:latest
build-args: |
BINARY=dist/telemt
path: dist/*
# ==========================
# Release
@@ -271,7 +219,7 @@ jobs:
with:
path: artifacts
- name: Flatten artifacts
- name: Flatten
run: |
mkdir dist
find artifacts -type f -exec cp {} dist/ \;
@@ -281,5 +229,63 @@ jobs:
with:
files: dist/*
generate_release_notes: true
draft: false
prerelease: ${{ contains(github.ref, '-rc') || contains(github.ref, '-beta') || contains(github.ref, '-alpha') }}
prerelease: ${{ contains(github.ref, '-') }}
# ==========================
# Docker
# ==========================
docker:
name: Docker (${{ matrix.platform }})
runs-on: ubuntu-latest
needs: [build-gnu, build-musl]
strategy:
fail-fast: false
matrix:
include:
- platform: linux/amd64
artifact: telemt-x86_64-linux-musl
- platform: linux/arm64
artifact: telemt-aarch64-linux-musl
permissions:
contents: read
packages: write
steps:
- uses: actions/checkout@v4
- uses: actions/download-artifact@v4
with:
name: ${{ matrix.artifact }}
path: dist
- name: Extract binary
run: |
tar -xzf dist/${{ matrix.artifact }}.tar.gz -C dist
chmod +x dist/telemt
- uses: docker/setup-qemu-action@v3
- uses: docker/setup-buildx-action@v3
- uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract version
id: vars
run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
- name: Build & Push (per arch)
uses: docker/build-push-action@v6
with:
context: .
push: true
platforms: ${{ matrix.platform }}
tags: |
ghcr.io/${{ github.repository }}:${{ steps.vars.outputs.VERSION }}
ghcr.io/${{ github.repository }}:latest
build-args: |
BINARY=dist/telemt
-66
View File
@@ -1,66 +0,0 @@
name: Rust
on:
push:
branches: [ "*" ]
pull_request:
branches: [ "*" ]
env:
CARGO_TERM_COLOR: always
jobs:
build:
name: Build
runs-on: ubuntu-latest
permissions:
contents: read
actions: write
checks: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install latest stable Rust toolchain
uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt, clippy
- name: Cache cargo registry & build artifacts
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-
- name: Build Release
run: cargo build --release --verbose
- name: Run tests
run: cargo test --verbose
- name: Stress quota-lock suites (PR only)
if: github.event_name == 'pull_request'
env:
RUST_TEST_THREADS: 16
run: |
set -euo pipefail
for i in $(seq 1 12); do
echo "[quota-lock-stress] iteration ${i}/12"
cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16
cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16
done
# clippy dont fail on warnings because of active development of telemt
# and many warnings
- name: Run clippy
run: cargo clippy -- --cap-lints warn
- name: Check for unused dependencies
run: cargo udeps || true
+127
View File
@@ -0,0 +1,127 @@
name: Check
on:
push:
branches: [ "*" ]
pull_request:
branches: [ "*" ]
env:
CARGO_TERM_COLOR: always
concurrency:
group: test-${{ github.ref }}
cancel-in-progress: true
jobs:
# ==========================
# Formatting
# ==========================
fmt:
name: Fmt
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt
- run: cargo fmt -- --check
# ==========================
# Tests
# ==========================
test:
name: Test
runs-on: ubuntu-latest
permissions:
contents: read
actions: write
checks: write
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- name: Cache cargo
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-
- run: cargo test --verbose
# ==========================
# Clippy
# ==========================
clippy:
name: Clippy
runs-on: ubuntu-latest
permissions:
contents: read
checks: write
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
with:
components: clippy
- name: Cache cargo
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-
- run: cargo clippy -- --cap-lints warn
# ==========================
# Udeps
# ==========================
udeps:
name: Udeps
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- name: Cache cargo
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-
- name: Install cargo-udeps
run: cargo install cargo-udeps || true
# тоже не валит билд
- run: cargo udeps || true
Generated
+30 -8
View File
@@ -1454,9 +1454,9 @@ dependencies = [
[[package]]
name = "iri-string"
version = "0.7.10"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a"
checksum = "d8e7418f59cc01c88316161279a7f665217ae316b388e58a0d10e29f54f1e5eb"
dependencies = [
"memchr",
"serde",
@@ -1486,7 +1486,7 @@ dependencies = [
"cesu8",
"cfg-if",
"combine",
"jni-sys",
"jni-sys 0.3.1",
"log",
"thiserror 1.0.69",
"walkdir",
@@ -1495,9 +1495,31 @@ dependencies = [
[[package]]
name = "jni-sys"
version = "0.3.0"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258"
dependencies = [
"jni-sys 0.4.1",
]
[[package]]
name = "jni-sys"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2"
dependencies = [
"jni-sys-macros",
]
[[package]]
name = "jni-sys-macros"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264"
dependencies = [
"quote",
"syn",
]
[[package]]
name = "jobserver"
@@ -1659,9 +1681,9 @@ dependencies = [
[[package]]
name = "moka"
version = "0.12.14"
version = "0.12.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85f8024e1c8e71c778968af91d43700ce1d11b219d127d79fb2934153b82b42b"
checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046"
dependencies = [
"crossbeam-channel",
"crossbeam-epoch",
@@ -2771,7 +2793,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417"
[[package]]
name = "telemt"
version = "3.3.29"
version = "3.3.30"
dependencies = [
"aes",
"anyhow",
+4 -1
View File
@@ -1,8 +1,11 @@
[package]
name = "telemt"
version = "3.3.29"
version = "3.3.30"
edition = "2024"
[features]
redteam_offline_expected_fail = []
[dependencies]
# C
libc = "0.2"
+46 -69
View File
@@ -1,111 +1,88 @@
# syntax=docker/dockerfile:1
# ==========================
# Stage 1: Build
# ==========================
FROM rust:1.88-slim-bookworm AS builder
RUN apt-get update && apt-get install -y --no-install-recommends \
pkg-config \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /build
# Depcache
COPY Cargo.toml Cargo.lock* ./
RUN mkdir src && echo 'fn main() {}' > src/main.rs && \
cargo build --release 2>/dev/null || true && \
rm -rf src
# Build
COPY . .
RUN cargo build --release && strip target/release/telemt
ARG BINARY
ARG TARGETARCH
# ==========================
# Stage 2: Compress (strip + UPX)
# Stage: minimal
# ==========================
FROM debian:12-slim AS minimal
RUN apt-get update && apt-get install -y --no-install-recommends \
binutils \
curl \
ca-certificates \
&& rm -rf /var/lib/apt/lists/* \
ARG TARGETARCH
ARG BINARY
RUN set -eux; \
apt-get update; \
apt-get install -y --no-install-recommends \
binutils \
curl \
xz-utils \
ca-certificates; \
rm -rf /var/lib/apt/lists/*; \
\
# install UPX from Telemt releases
&& curl -fL \
case "${TARGETARCH}" in \
amd64) UPX_ARCH="amd64" ;; \
arm64) UPX_ARCH="arm64" ;; \
*) echo "Unsupported TARGETARCH: ${TARGETARCH}" >&2; exit 1 ;; \
esac; \
\
curl -fL \
--retry 5 \
--retry-delay 3 \
--connect-timeout 10 \
--max-time 120 \
-o /tmp/upx.tar.xz \
https://github.com/telemt/telemt/releases/download/toolchains/upx-amd64_linux.tar.xz \
&& tar -xf /tmp/upx.tar.xz -C /tmp \
&& mv /tmp/upx*/upx /usr/local/bin/upx \
&& chmod +x /usr/local/bin/upx \
&& rm -rf /tmp/upx*
"https://github.com/telemt/telemt/releases/download/toolchains/upx-${UPX_ARCH}_linux.tar.xz"; \
\
tar -xf /tmp/upx.tar.xz -C /tmp; \
install -m 0755 /tmp/upx*/upx /usr/local/bin/upx; \
rm -rf /tmp/upx*
COPY --from=builder /build/target/release/telemt /telemt
COPY ${BINARY} /telemt
RUN strip /telemt || true
RUN upx --best --lzma /telemt || true
RUN set -eux; \
test -f /telemt; \
strip --strip-unneeded /telemt || true; \
upx --best --lzma /telemt || true
# ==========================
# Stage 3: Debug base
# Debug image
# ==========================
FROM debian:12-slim AS debug-base
FROM debian:12-slim AS debug
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
tzdata \
curl \
iproute2 \
busybox \
&& rm -rf /var/lib/apt/lists/*
# ==========================
# Stage 4: Debug image
# ==========================
FROM debug-base AS debug
RUN set -eux; \
apt-get update; \
apt-get install -y --no-install-recommends \
ca-certificates \
tzdata \
curl \
iproute2 \
busybox; \
rm -rf /var/lib/apt/lists/*
WORKDIR /app
COPY --from=minimal /telemt /app/telemt
COPY config.toml /app/config.toml
USER root
EXPOSE 443
EXPOSE 9090
EXPOSE 9091
EXPOSE 443 9090 9091
ENTRYPOINT ["/app/telemt"]
CMD ["config.toml"]
# ==========================
# Stage 5: Production (distroless)
# Production (distroless, for static MUSL binary)
# ==========================
FROM gcr.io/distroless/base-debian12 AS prod
FROM gcr.io/distroless/static-debian12 AS prod
WORKDIR /app
COPY --from=minimal /telemt /app/telemt
COPY config.toml /app/config.toml
# TLS + timezone + shell
COPY --from=debug-base /etc/ssl/certs /etc/ssl/certs
COPY --from=debug-base /usr/share/zoneinfo /usr/share/zoneinfo
COPY --from=debug-base /bin/busybox /bin/busybox
RUN ["/bin/busybox", "--install", "-s", "/bin"]
# distroless user
USER nonroot:nonroot
EXPOSE 443
EXPOSE 9090
EXPOSE 9091
EXPOSE 443 9090 9091
ENTRYPOINT ["/app/telemt"]
CMD ["config.toml"]
+9 -2
View File
@@ -20,7 +20,7 @@ This document lists all configuration keys accepted by `config.toml`.
| Parameter | Type | Default | Constraints / validation | Description |
|---|---|---|---|---|
| data_path | `String \| null` | `null` | — | Optional runtime data directory path. |
| prefer_ipv6 | `bool` | `false` | — | Prefer IPv6 where applicable in runtime logic. |
| prefer_ipv6 | `bool` | `false` | Deprecated. Use `network.prefer`. | Deprecated legacy IPv6 preference flag migrated to `network.prefer`. |
| fast_mode | `bool` | `true` | — | Enables fast-path optimizations for traffic processing. |
| use_middle_proxy | `bool` | `true` | none | Enables ME transport mode; if `false`, runtime falls back to direct DC routing. |
| proxy_secret_path | `String \| null` | `"proxy-secret"` | Path may be `null`. | Path to Telegram infrastructure proxy-secret file used by ME handshake logic. |
@@ -44,6 +44,7 @@ This document lists all configuration keys accepted by `config.toml`.
| me_writer_cmd_channel_capacity | `usize` | `4096` | Must be `> 0`. | Capacity of per-writer command channel. |
| me_route_channel_capacity | `usize` | `768` | Must be `> 0`. | Capacity of per-connection ME response route channel. |
| me_c2me_channel_capacity | `usize` | `1024` | Must be `> 0`. | Capacity of per-client command queue (client reader -> ME sender). |
| me_c2me_send_timeout_ms | `u64` | `4000` | `0..=60000`. | Maximum wait for enqueueing client->ME commands when the per-client queue is full (`0` keeps legacy unbounded wait). |
| me_reader_route_data_wait_ms | `u64` | `2` | `0..=20`. | Bounded wait for routing ME DATA to per-connection queue (`0` = no wait). |
| me_d2c_flush_batch_max_frames | `usize` | `32` | `1..=512`. | Max ME->client frames coalesced before flush. |
| me_d2c_flush_batch_max_bytes | `usize` | `131072` | `4096..=2_097_152`. | Max ME->client payload bytes coalesced before flush. |
@@ -105,6 +106,8 @@ This document lists all configuration keys accepted by `config.toml`.
| me_warn_rate_limit_ms | `u64` | `5000` | Must be `> 0`. | Cooldown for repetitive ME warning logs (ms). |
| me_route_no_writer_mode | `"async_recovery_failfast" \| "inline_recovery_legacy" \| "hybrid_async_persistent"` | `"hybrid_async_persistent"` | — | Route behavior when no writer is immediately available. |
| me_route_no_writer_wait_ms | `u64` | `250` | `10..=5000`. | Max wait in async-recovery failfast mode (ms). |
| me_route_hybrid_max_wait_ms | `u64` | `3000` | `50..=60000`. | Maximum cumulative wait in hybrid no-writer mode before failfast fallback (ms). |
| me_route_blocking_send_timeout_ms | `u64` | `250` | `0..=5000`. | Maximum wait for blocking route-channel send fallback (`0` keeps legacy unbounded wait). |
| me_route_inline_recovery_attempts | `u32` | `3` | Must be `> 0`. | Inline recovery attempts in legacy mode. |
| me_route_inline_recovery_wait_ms | `u64` | `3000` | `10..=30000`. | Max inline recovery wait in legacy mode (ms). |
| fast_mode_min_tls_record | `usize` | `0` | — | Minimum TLS record size when fast-mode coalescing is enabled (`0` disables). |
@@ -124,6 +127,7 @@ This document lists all configuration keys accepted by `config.toml`.
| me_secret_atomic_snapshot | `bool` | `true` | — | Keeps selector and secret bytes from the same snapshot atomically. |
| proxy_secret_len_max | `usize` | `256` | Must be within `[32, 4096]`. | Upper length limit for accepted proxy-secret bytes. |
| me_pool_drain_ttl_secs | `u64` | `90` | none | Time window where stale writers remain fallback-eligible after map change. |
| me_instadrain | `bool` | `false` | — | Forces draining stale writers to be removed on the next cleanup tick, bypassing TTL/deadline waiting. |
| me_pool_drain_threshold | `u64` | `128` | — | Max draining stale writers before batch force-close (`0` disables threshold cleanup). |
| me_pool_drain_soft_evict_enabled | `bool` | `true` | — | Enables gradual soft-eviction of stale writers during drain/reinit instead of immediate hard close. |
| me_pool_drain_soft_evict_grace_secs | `u64` | `30` | `0..=3600`. | Grace period before stale writers become soft-evict candidates. |
@@ -203,6 +207,7 @@ This document lists all configuration keys accepted by `config.toml`.
| metrics_listen | `String \| null` | `null` | — | Full metrics bind address (`IP:PORT`), overrides `metrics_port`. |
| metrics_whitelist | `IpNetwork[]` | `["127.0.0.1/32", "::1/128"]` | — | CIDR whitelist for metrics endpoint access. |
| max_connections | `u32` | `10000` | — | Max concurrent client connections (`0` = unlimited). |
| accept_permit_timeout_ms | `u64` | `250` | `0..=60000`. | Maximum wait for acquiring a connection-slot permit before the accepted connection is dropped (`0` keeps legacy unbounded wait). |
Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers are parsed from the first bytes of the connection and the client source address is replaced with `src_addr` from the header. For security, the peer source IP (the direct connection address) is verified against `server.proxy_protocol_trusted_cidrs`; if this list is empty, PROXY headers are rejected and the connection is considered untrusted.
@@ -229,7 +234,7 @@ Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers a
|---|---|---|---|---|
| ip | `IpAddr` | — | — | Listener bind IP. |
| announce | `String \| null` | — | — | Public IP/domain announced in proxy links (priority over `announce_ip`). |
| announce_ip | `IpAddr \| null` | — | | Deprecated legacy announce IP (migrated to `announce` if needed). |
| announce_ip | `IpAddr \| null` | — | Deprecated. Use `announce`. | Deprecated legacy announce IP (migrated to `announce` if needed). |
| proxy_protocol | `bool \| null` | `null` | — | Per-listener override for PROXY protocol enable flag. |
| reuse_allow | `bool` | `false` | — | Enables `SO_REUSEPORT` for multi-instance bind sharing. |
@@ -269,6 +274,8 @@ Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers a
| mask_shape_bucket_cap_bytes | `usize` | `4096` | Must be `>= mask_shape_bucket_floor_bytes`. | Maximum bucket size used by shape-channel hardening; traffic above cap is not padded further. |
| mask_shape_above_cap_blur | `bool` | `false` | Requires `mask_shape_hardening = true`; requires `mask_shape_above_cap_blur_max_bytes > 0`. | Adds bounded randomized tail bytes even when forwarded size already exceeds cap. |
| mask_shape_above_cap_blur_max_bytes | `usize` | `512` | Must be `<= 1048576`; must be `> 0` when `mask_shape_above_cap_blur = true`. | Maximum randomized extra bytes appended above cap. |
| mask_relay_max_bytes | `usize` | `5242880` | Must be `> 0`; must be `<= 67108864`. | Maximum relayed bytes per direction on unauthenticated masking fallback path. |
| mask_classifier_prefetch_timeout_ms | `u64` | `5` | Must be within `[5, 50]`. | Timeout budget (ms) for extending fragmented initial classifier window on masking fallback. |
| mask_timing_normalization_enabled | `bool` | `false` | Requires `mask_timing_normalization_floor_ms > 0`; requires `ceiling >= floor`. | Enables timing envelope normalization on masking outcomes. |
| mask_timing_normalization_floor_ms | `u64` | `0` | Must be `> 0` when timing normalization is enabled; must be `<= ceiling`. | Lower bound (ms) for masking outcome normalization target. |
| mask_timing_normalization_ceiling_ms | `u64` | `0` | Must be `>= floor`; must be `<= 60000`. | Upper bound (ms) for masking outcome normalization target. |
+4 -1
View File
@@ -3,7 +3,7 @@
1. Go to @MTProxybot bot.
2. Enter the command `/newproxy`
3. Send the server IP and port. For example: 1.2.3.4:443
4. Open the config `nano /etc/telemt.toml`.
4. Open the config `nano /etc/telemt/telemt.toml`.
5. Copy and send the user secret from the [access.users] section to the bot.
6. Copy the tag received from the bot. For example 1234567890abcdef1234567890abcdef.
> [!WARNING]
@@ -33,6 +33,9 @@ hello = "ad_tag"
hello2 = "ad_tag2"
```
## Why is middle proxy (ME) needed
https://github.com/telemt/telemt/discussions/167
## How many people can use 1 link
By default, 1 link can be used by any number of people.
+5 -1
View File
@@ -3,7 +3,7 @@
1. Зайти в бота @MTProxybot.
2. Ввести команду `/newproxy`
3. Отправить IP и порт сервера. Например: 1.2.3.4:443
4. Открыть конфиг `nano /etc/telemt.toml`.
4. Открыть конфиг `nano /etc/telemt/telemt.toml`.
5. Скопировать и отправить боту секрет пользователя из раздела [access.users].
6. Скопировать полученный tag у бота. Например 1234567890abcdef1234567890abcdef.
> [!WARNING]
@@ -33,6 +33,10 @@ hello = "ad_tag"
hello2 = "ad_tag2"
```
## Зачем нужен middle proxy (ME)
https://github.com/telemt/telemt/discussions/167
## Сколько человек может пользоваться 1 ссылкой
По умолчанию 1 ссылкой может пользоваться сколько угодно человек.
+287
View File
@@ -0,0 +1,287 @@
<img src="https://gist.githubusercontent.com/avbor/1f8a128e628f47249aae6e058a57610b/raw/19013276c035e91058e0a9799ab145f8e70e3ff5/scheme.svg">
## Concept
- **Server A** (__conditionally Russian Federation_):\
Entry point, receives Telegram proxy user traffic via **HAProxy** (port `443`)\
and sends it to the tunnel to Server **B**.\
Internal IP in the tunnel — `10.10.10.2`\
Port for HAProxy clients — `443\tcp`
- **Server B** (_conditionally Netherlands_):\
Exit point, runs **telemt** and accepts client connections through Server **A**.\
The server must have unrestricted access to Telegram servers.\
Internal IP in the tunnel — `10.10.10.1`\
AmneziaWG port — `8443\udp`\
Port for telemt clients — `443\tcp`
---
## Step 1. Setting up the AmneziaWG tunnel (A <-> B)
[AmneziaWG](https://github.com/amnezia-vpn/amneziawg-linux-kernel-module) must be installed on all servers.\
All following commands are given for **Ubuntu 24.04**.\
For RHEL-based distributions, installation instructions are available at the link above.
### Installing AmneziaWG (Servers A and B)
The following steps must be performed on each server:
#### 1. Adding the AmneziaWG repository and installing required packages:
```bash
sudo apt install -y software-properties-common python3-launchpadlib gnupg2 linux-headers-$(uname -r) && \
sudo add-apt-repository ppa:amnezia/ppa && \
sudo apt-get install -y amneziawg
```
#### 2. Generating a unique key pair:
```bash
cd /etc/amnezia/amneziawg && \
awg genkey | tee private.key | awg pubkey > public.key
```
As a result, you will get two files in the `/etc/amnezia/amneziawg` folder:\
`private.key` - private, and\
`public.key` - public server keys
#### 3. Configuring network interfaces:
Obfuscation parameters `S1`, `S2`, `H1`, `H2`, `H3`, `H4` must be strictly identical on both servers.\
Parameters `Jc`, `Jmin` and `Jmax` can differ.\
Parameters `I1-I5` ([Custom Protocol Signature](https://docs.amnezia.org/documentation/amnezia-wg/)) must be specified on the client side (Server **A**).
Recommendations for choosing values:
```text
Jc — 1 ≤ Jc ≤ 128; from 4 to 12 inclusive
Jmin — Jmax > Jmin < 1280*; recommended 8
Jmax — Jmin < Jmax ≤ 1280*; recommended 80
S1 — S1 ≤ 1132* (1280* - 148 = 1132); S1 + 56 ≠ S2;
recommended range from 15 to 150 inclusive
S2 — S2 ≤ 1188* (1280* - 92 = 1188);
recommended range from 15 to 150 inclusive
H1/H2/H3/H4 — must be unique and differ from each other;
recommended range from 5 to 2147483647 inclusive
* It is assumed that the Internet connection has an MTU of 1280.
```
> [!IMPORTANT]
> It is recommended to use your own, unique values.\
> You can use the [generator](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/13f5517ca473b47c412b9a99407066de973732bd/awg-gen.html) to select parameters.
#### Server B Configuration (Netherlands):
Create the interface configuration file (`awg0`)
```bash
nano /etc/amnezia/amneziawg/awg0.conf
```
File content
```ini
[Interface]
Address = 10.10.10.1/24
ListenPort = 8443
PrivateKey = <PRIVATE_KEY_SERVER_B>
SaveConfig = true
Jc = 4
Jmin = 8
Jmax = 80
S1 = 29
S2 = 15
S3 = 18
S4 = 0
H1 = 2087563914
H2 = 188817757
H3 = 101784570
H4 = 432174303
[Peer]
PublicKey = <PUBLIC_KEY_SERVER_A>
AllowedIPs = 10.10.10.2/32
```
`ListenPort` - the port on which the server will wait for connections, you can choose any free one.\
`<PRIVATE_KEY_SERVER_B>` - the content of the `private.key` file from Server **B**.\
`<PUBLIC_KEY_SERVER_A>` - the content of the `public.key` file from Server **A**.
Open the port on the firewall (if enabled):
```bash
sudo ufw allow from <PUBLIC_IP_SERVER_A> to any port 8443 proto udp
```
`<PUBLIC_IP_SERVER_A>` - the external IP address of Server **A**.
#### Server A Configuration (Russian Federation):
Create the interface configuration file (awg0)
```bash
nano /etc/amnezia/amneziawg/awg0.conf
```
File content
```ini
[Interface]
Address = 10.10.10.2/24
PrivateKey = <PRIVATE_KEY_SERVER_A>
Jc = 4
Jmin = 8
Jmax = 80
S1 = 29
S2 = 15
S3 = 18
S4 = 0
H1 = 2087563914
H2 = 188817757
H3 = 101784570
H4 = 432174303
I1 = <b 0xc10000000108981eba846e21f74e00>
I2 = <b 0xc20000000108981eba846e21f74e00>
I3 = <b 0xc30000000108981eba846e21f74e00>
I4 = <b 0x43981eba846e21f74e>
I5 = <b 0x43981eba846e21f74e>
[Peer]
PublicKey = <PUBLIC_KEY_SERVER_B>
Endpoint = <PUBLIC_IP_SERVER_B>:8443
AllowedIPs = 10.10.10.1/32
PersistentKeepalive = 25
```
`<PRIVATE_KEY_SERVER_A>` - the content of the `private.key` file from Server **A**.\
`<PUBLIC_KEY_SERVER_B>` - the content of the `public.key` file from Server **B**.\
`<PUBLIC_IP_SERVER_B>` - the public IP address of Server **B**.
Enable the tunnel on both servers:
```bash
sudo systemctl enable --now awg-quick@awg0
```
Make sure Server B is accessible from Server A through the tunnel.
```bash
ping 10.10.10.1
PING 10.10.10.1 (10.10.10.1) 56(84) bytes of data.
64 bytes from 10.10.10.1: icmp_seq=1 ttl=64 time=35.1 ms
64 bytes from 10.10.10.1: icmp_seq=2 ttl=64 time=35.0 ms
64 bytes from 10.10.10.1: icmp_seq=3 ttl=64 time=35.1 ms
^C
```
---
## Step 2. Installing telemt on Server B (conditionally Netherlands)
Installation and configuration are described [here](https://github.com/telemt/telemt/blob/main/docs/QUICK_START_GUIDE.ru.md) or [here](https://gitlab.com/An0nX/telemt-docker#-quick-start-docker-compose).\
It is assumed that telemt expects connections on port `443\tcp`.
In the telemt config, you must enable the `Proxy` protocol and restrict connections to it only through the tunnel.
```toml
[server]
port = 443
listen_addr_ipv4 = "10.10.10.1"
proxy_protocol = true
```
Also, for correct link generation, specify the FQDN or IP address and port of Server `A`
```toml
[general.links]
show = "*"
public_host = "<FQDN_OR_IP_SERVER_A>"
public_port = 443
```
Open the port on the firewall (if enabled):
```bash
sudo ufw allow from 10.10.10.2 to any port 443 proto tcp
```
---
## Step 3. Configuring HAProxy on Server A (Russian Federation)
Since the version in the standard Ubuntu repository is relatively old, it makes sense to use the official Docker image.\
[Instructions](https://docs.docker.com/engine/install/ubuntu/) for installing Docker on Ubuntu.
> [!WARNING]
> By default, regular users do not have rights to use ports < 1024.
> Attempts to run HAProxy on port 443 can lead to errors:
> ```
> [ALERT] (8) : Binding [/usr/local/etc/haproxy/haproxy.cfg:17] for frontend tcp_in_443:
> protocol tcpv4: cannot bind socket (Permission denied) for [0.0.0.0:443].
> ```
> There are two simple ways to bypass this restriction, choose one:
> 1. At the OS level, change the net.ipv4.ip_unprivileged_port_start setting to allow users to use all ports:
> ```
> echo "net.ipv4.ip_unprivileged_port_start = 0" | sudo tee -a /etc/sysctl.conf && sudo sysctl -p
> ```
> or
>
> 2. Run HAProxy as root:
> Uncomment the `user: "root"` parameter in docker-compose.yaml.
#### Create a folder for HAProxy:
```bash
mkdir -p /opt/docker-compose/haproxy && cd $_
```
#### Create the docker-compose.yaml file
`nano docker-compose.yaml`
File content
```yaml
services:
haproxy:
image: haproxy:latest
container_name: haproxy
restart: unless-stopped
# user: "root"
network_mode: "host"
volumes:
- ./haproxy.cfg:/usr/local/etc/haproxy/haproxy.cfg:ro
logging:
driver: "json-file"
options:
max-size: "1m"
max-file: "1"
```
#### Create the haproxy.cfg config file
Accept connections on port 443\tcp and send them through the tunnel to Server `B` 10.10.10.1:443
`nano haproxy.cfg`
File content
```haproxy
global
log stdout format raw local0
maxconn 10000
defaults
log global
mode tcp
option tcplog
option clitcpka
option srvtcpka
timeout connect 5s
timeout client 2h
timeout server 2h
timeout check 5s
frontend tcp_in_443
bind *:443
maxconn 8000
option tcp-smart-accept
default_backend telemt_nodes
backend telemt_nodes
option tcp-smart-connect
server server_a 10.10.10.1:443 check inter 5s rise 2 fall 3 send-proxy-v2
```
> [!WARNING]
> **The file must end with an empty line, otherwise HAProxy will not start!**
#### Allow port 443\tcp in the firewall (if enabled)
```bash
sudo ufw allow 443/tcp
```
#### Start the HAProxy container
```bash
docker compose up -d
```
If everything is configured correctly, you can now try connecting Telegram clients using links from the telemt log\api.
+291
View File
@@ -0,0 +1,291 @@
<img src="https://gist.githubusercontent.com/avbor/1f8a128e628f47249aae6e058a57610b/raw/19013276c035e91058e0a9799ab145f8e70e3ff5/scheme.svg">
## Концепция
- **Сервер A** (_РФ_):\
Точка входа, принимает трафик пользователей Telegram-прокси через **HAProxy** (порт `443`)\
и отправляет в туннель на Сервер **B**.\
Внутренний IP в туннеле — `10.10.10.2`\
Порт для клиентов HAProxy — `443\tcp`
- **Сервер B** (_условно Нидерланды_):\
Точка выхода, на нем работает **telemt** и принимает подключения клиентов через Сервер **A**.\
На сервере должен быть неограниченный доступ до серверов Telegram.\
Внутренний IP в туннеле — `10.10.10.1`\
Порт AmneziaWG — `8443\udp`\
Порт для клиентов telemt — `443\tcp`
---
## Шаг 1. Настройка туннеля AmneziaWG (A <-> B)
На всех серверах необходимо установить [amneziawg](https://github.com/amnezia-vpn/amneziawg-linux-kernel-module).\
Далее все команды даны для **Ununtu 24.04**.\
Для RHEL-based дистрибутивов инструкция по установке есть по ссылке выше.
### Установка AmneziaWG (Сервера A и B)
На каждом из серверов необходимо выполнить следующие шаги:
#### 1. Добавление репозитория AmneziaWG и установка необходимых пакетов:
```bash
sudo apt install -y software-properties-common python3-launchpadlib gnupg2 linux-headers-$(uname -r) && \
sudo add-apt-repository ppa:amnezia/ppa && \
sudo apt-get install -y amneziawg
```
#### 2. Генерация уникальной пары ключей:
```bash
cd /etc/amnezia/amneziawg && \
awg genkey | tee private.key | awg pubkey > public.key
```
В результате вы получите в папке `/etc/amnezia/amneziawg` два файла:\
`private.key` - приватный и\
`public.key` - публичный ключи сервера
#### 3. Настройка сетевых интерфейсов:
Параметры обфускации `S1`, `S2`, `H1`, `H2`, `H3`, `H4` должны быть строго идентичными на обоих серверах.\
Параметры `Jc`, `Jmin` и `Jmax` могут отличатся.\
Параметры `I1-I5` ([Custom Protocol Signature](https://docs.amnezia.org/documentation/amnezia-wg/)) нужно указывать на стороне _клиента_ (Сервер **А**).
Рекомендации по выбору значений:
```text
Jc — 1 ≤ Jc ≤ 128; от 4 до 12 включительно
Jmin — Jmax > Jmin < 1280*; рекомендовано 8
Jmax — Jmin < Jmax ≤ 1280*; рекомендовано 80
S1 — S1 ≤ 1132* (1280* - 148 = 1132); S1 + 56 ≠ S2;
рекомендованный диапазон от 15 до 150 включительно
S2 — S2 ≤ 1188* (1280* - 92 = 1188);
рекомендованный диапазон от 15 до 150 включительно
H1/H2/H3/H4 — должны быть уникальны и отличаться друг от друга;
рекомендованный диапазон от 5 до 2147483647 включительно
* Предполагается, что подключение к Интернету имеет MTU 1280.
```
> [!IMPORTANT]
> Рекомендуется использовать собственные, уникальные значения.\
> Для выбора параметров можете воспользоваться [генератором](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/13f5517ca473b47c412b9a99407066de973732bd/awg-gen.html).
#### Конфигурация Сервера B (_Нидерланды_):
Создаем файл конфигурации интерфейса (`awg0`)
```bash
nano /etc/amnezia/amneziawg/awg0.conf
```
Содержимое файла
```ini
[Interface]
Address = 10.10.10.1/24
ListenPort = 8443
PrivateKey = <PRIVATE_KEY_SERVER_B>
SaveConfig = true
Jc = 4
Jmin = 8
Jmax = 80
S1 = 29
S2 = 15
S3 = 18
S4 = 0
H1 = 2087563914
H2 = 188817757
H3 = 101784570
H4 = 432174303
[Peer]
PublicKey = <PUBLIC_KEY_SERVER_A>
AllowedIPs = 10.10.10.2/32
```
`ListenPort` - порт, на котором сервер будет ждать подключения, можете выбрать любой свободный.\
`<PRIVATE_KEY_SERVER_B>` - содержимое файла `private.key` с сервера **B**.\
`<PUBLIC_KEY_SERVER_A>` - содержимое файла `public.key` с сервера **A**.
Открываем порт на фаерволе (если включен):
```bash
sudo ufw allow from <PUBLIC_IP_SERVER_A> to any port 8443 proto udp
```
`<PUBLIC_IP_SERVER_A>` - внешний IP адрес Сервера **A**.
#### Конфигурация Сервера A (_РФ_):
Создаем файл конфигурации интерфейса (`awg0`)
```bash
nano /etc/amnezia/amneziawg/awg0.conf
```
Содержимое файла
```ini
[Interface]
Address = 10.10.10.2/24
PrivateKey = <PRIVATE_KEY_SERVER_A>
Jc = 4
Jmin = 8
Jmax = 80
S1 = 29
S2 = 15
S3 = 18
S4 = 0
H1 = 2087563914
H2 = 188817757
H3 = 101784570
H4 = 432174303
I1 = <b 0xc10000000108981eba846e21f74e00>
I2 = <b 0xc20000000108981eba846e21f74e00>
I3 = <b 0xc30000000108981eba846e21f74e00>
I4 = <b 0x43981eba846e21f74e>
I5 = <b 0x43981eba846e21f74e>
[Peer]
PublicKey = <PUBLIC_KEY_SERVER_B>
Endpoint = <PUBLIC_IP_SERVER_B>:8443
AllowedIPs = 10.10.10.1/32
PersistentKeepalive = 25
```
`<PRIVATE_KEY_SERVER_A>` - содержимое файла `private.key` с сервера **A**.\
`<PUBLIC_KEY_SERVER_B>` - содержимое файла `public.key` с сервера **B**.\
`<PUBLIC_IP_SERVER_B>` - публичный IP адресс сервера **B**.
#### Включаем туннель на обоих серверах:
```bash
sudo systemctl enable --now awg-quick@awg0
```
Убедитесь, что с Сервера `A` доступен Сервер `B` через туннель.
```bash
ping 10.10.10.1
PING 10.10.10.1 (10.10.10.1) 56(84) bytes of data.
64 bytes from 10.10.10.1: icmp_seq=1 ttl=64 time=35.1 ms
64 bytes from 10.10.10.1: icmp_seq=2 ttl=64 time=35.0 ms
64 bytes from 10.10.10.1: icmp_seq=3 ttl=64 time=35.1 ms
^C
```
---
## Шаг 2. Установка telemt на Сервере B (_условно Нидерланды_)
Установка и настройка описаны [здесь](https://github.com/telemt/telemt/blob/main/docs/QUICK_START_GUIDE.ru.md) или [здесь](https://gitlab.com/An0nX/telemt-docker#-quick-start-docker-compose).\
Подразумевается что telemt ожидает подключения на порту `443\tcp`.
В конфиге telemt необходимо включить протокол `Proxy` и ограничить подключения к нему только через туннель.
```toml
[server]
port = 443
listen_addr_ipv4 = "10.10.10.1"
proxy_protocol = true
```
А также, для правильной генерации ссылок, указать FQDN или IP адрес и порт Сервера `A`
```toml
[general.links]
show = "*"
public_host = "<FQDN_OR_IP_SERVER_A>"
public_port = 443
```
Открываем порт на фаерволе (если включен):
```bash
sudo ufw allow from 10.10.10.2 to any port 443 proto tcp
```
---
### Шаг 3. Настройка HAProxy на Сервере A (_РФ_)
Т.к. в стандартном репозитории Ubuntu версия относительно старая, имеет смысл воспользоваться официальным образом Docker.\
[Инструкция](https://docs.docker.com/engine/install/ubuntu/) по установке Docker на Ubuntu.
> [!WARNING]
> По умолчанию у обычных пользователей нет прав на использование портов < 1024.\
> Попытки запустить HAProxy на 443 порту могут приводить к ошибкам:
> ```
> [ALERT] (8) : Binding [/usr/local/etc/haproxy/haproxy.cfg:17] for frontend tcp_in_443:
> protocol tcpv4: cannot bind socket (Permission denied) for [0.0.0.0:443].
> ```
> Есть два простых способа обойти это ограничение, выберите что-то одно:
> 1. На уровне ОС изменить настройку net.ipv4.ip_unprivileged_port_start, разрешив пользователям использовать все порты:
> ```
> echo "net.ipv4.ip_unprivileged_port_start = 0" | sudo tee -a /etc/sysctl.conf && sudo sysctl -p
> ```
> или
>
> 2. Запустить HAProxy под root:\
> Раскомментируйте в docker-compose.yaml параметр `user: "root"`.
#### Создаем папку для HAProxy:
```bash
mkdir -p /opt/docker-compose/haproxy && cd $_
```
#### Создаем файл docker-compose.yaml
`nano docker-compose.yaml`
Содержимое файла
```yaml
services:
haproxy:
image: haproxy:latest
container_name: haproxy
restart: unless-stopped
# user: "root"
network_mode: "host"
volumes:
- ./haproxy.cfg:/usr/local/etc/haproxy/haproxy.cfg:ro
logging:
driver: "json-file"
options:
max-size: "1m"
max-file: "1"
```
#### Создаем файл конфига haproxy.cfg
Принимаем подключения на порту 443\tcp и отправляем их через туннель на Сервер `B` 10.10.10.1:443
`nano haproxy.cfg`
Содержимое файла
```haproxy
global
log stdout format raw local0
maxconn 10000
defaults
log global
mode tcp
option tcplog
option clitcpka
option srvtcpka
timeout connect 5s
timeout client 2h
timeout server 2h
timeout check 5s
frontend tcp_in_443
bind *:443
maxconn 8000
option tcp-smart-accept
default_backend telemt_nodes
backend telemt_nodes
option tcp-smart-connect
server server_a 10.10.10.1:443 check inter 5s rise 2 fall 3 send-proxy-v2
```
>[!WARNING]
>**Файл должен заканчиваться пустой строкой, иначе HAProxy не запустится!**
#### Разрешаем порт 443\tcp в фаерволе (если включен)
```bash
sudo ufw allow 443/tcp
```
#### Запускаем контейнер HAProxy
```bash
docker compose up -d
```
Если все настроено верно, то теперь можно пробовать подключить клиентов Telegram с использованием ссылок из лога\api telemt.
+34
View File
@@ -71,6 +71,22 @@ pub(crate) fn default_tls_fetch_scope() -> String {
String::new()
}
pub(crate) fn default_tls_fetch_attempt_timeout_ms() -> u64 {
5_000
}
pub(crate) fn default_tls_fetch_total_budget_ms() -> u64 {
15_000
}
pub(crate) fn default_tls_fetch_strict_route() -> bool {
true
}
pub(crate) fn default_tls_fetch_profile_cache_ttl_secs() -> u64 {
600
}
pub(crate) fn default_mask_port() -> u16 {
443
}
@@ -185,6 +201,10 @@ pub(crate) fn default_proxy_protocol_header_timeout_ms() -> u64 {
500
}
pub(crate) fn default_proxy_protocol_trusted_cidrs() -> Vec<IpNetwork> {
vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]
}
pub(crate) fn default_server_max_connections() -> u32 {
10_000
}
@@ -553,6 +573,20 @@ pub(crate) fn default_mask_shape_above_cap_blur_max_bytes() -> usize {
512
}
#[cfg(not(test))]
pub(crate) fn default_mask_relay_max_bytes() -> usize {
5 * 1024 * 1024
}
#[cfg(test)]
pub(crate) fn default_mask_relay_max_bytes() -> usize {
32 * 1024
}
pub(crate) fn default_mask_classifier_prefetch_timeout_ms() -> u64 {
5
}
pub(crate) fn default_mask_timing_normalization_enabled() -> bool {
false
}
+6 -1
View File
@@ -228,7 +228,9 @@ impl HotFields {
me_d2c_flush_batch_max_delay_us: cfg.general.me_d2c_flush_batch_max_delay_us,
me_d2c_ack_flush_immediate: cfg.general.me_d2c_ack_flush_immediate,
me_quota_soft_overshoot_bytes: cfg.general.me_quota_soft_overshoot_bytes,
me_d2c_frame_buf_shrink_threshold_bytes: cfg.general.me_d2c_frame_buf_shrink_threshold_bytes,
me_d2c_frame_buf_shrink_threshold_bytes: cfg
.general
.me_d2c_frame_buf_shrink_threshold_bytes,
direct_relay_copy_buf_c2s_bytes: cfg.general.direct_relay_copy_buf_c2s_bytes,
direct_relay_copy_buf_s2c_bytes: cfg.general.direct_relay_copy_buf_s2c_bytes,
me_health_interval_ms_unhealthy: cfg.general.me_health_interval_ms_unhealthy,
@@ -600,6 +602,9 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
|| old.censorship.mask_shape_above_cap_blur != new.censorship.mask_shape_above_cap_blur
|| old.censorship.mask_shape_above_cap_blur_max_bytes
!= new.censorship.mask_shape_above_cap_blur_max_bytes
|| old.censorship.mask_relay_max_bytes != new.censorship.mask_relay_max_bytes
|| old.censorship.mask_classifier_prefetch_timeout_ms
!= new.censorship.mask_classifier_prefetch_timeout_ms
|| old.censorship.mask_timing_normalization_enabled
!= new.censorship.mask_timing_normalization_enabled
|| old.censorship.mask_timing_normalization_floor_ms
+218 -2
View File
@@ -1,6 +1,6 @@
#![allow(deprecated)]
use std::collections::{BTreeSet, HashMap};
use std::collections::{BTreeSet, HashMap, HashSet};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::net::{IpAddr, SocketAddr};
use std::path::{Path, PathBuf};
@@ -430,6 +430,24 @@ impl ProxyConfig {
));
}
if config.censorship.mask_relay_max_bytes == 0 {
return Err(ProxyError::Config(
"censorship.mask_relay_max_bytes must be > 0".to_string(),
));
}
if config.censorship.mask_relay_max_bytes > 67_108_864 {
return Err(ProxyError::Config(
"censorship.mask_relay_max_bytes must be <= 67108864".to_string(),
));
}
if !(5..=50).contains(&config.censorship.mask_classifier_prefetch_timeout_ms) {
return Err(ProxyError::Config(
"censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]".to_string(),
));
}
if config.censorship.mask_timing_normalization_ceiling_ms
< config.censorship.mask_timing_normalization_floor_ms
{
@@ -539,7 +557,9 @@ impl ProxyConfig {
));
}
if !(4096..=16 * 1024 * 1024).contains(&config.general.me_d2c_frame_buf_shrink_threshold_bytes) {
if !(4096..=16 * 1024 * 1024)
.contains(&config.general.me_d2c_frame_buf_shrink_threshold_bytes)
{
return Err(ProxyError::Config(
"general.me_d2c_frame_buf_shrink_threshold_bytes must be within [4096, 16777216]"
.to_string(),
@@ -957,6 +977,28 @@ impl ProxyConfig {
// Normalize optional TLS fetch scope: whitespace-only values disable scoped routing.
config.censorship.tls_fetch_scope = config.censorship.tls_fetch_scope.trim().to_string();
if config.censorship.tls_fetch.profiles.is_empty() {
config.censorship.tls_fetch.profiles = TlsFetchConfig::default().profiles;
} else {
let mut seen = HashSet::new();
config
.censorship
.tls_fetch
.profiles
.retain(|profile| seen.insert(*profile));
}
if config.censorship.tls_fetch.attempt_timeout_ms == 0 {
return Err(ProxyError::Config(
"censorship.tls_fetch.attempt_timeout_ms must be > 0".to_string(),
));
}
if config.censorship.tls_fetch.total_budget_ms == 0 {
return Err(ProxyError::Config(
"censorship.tls_fetch.total_budget_ms must be > 0".to_string(),
));
}
// Merge primary + extra TLS domains, deduplicate (primary always first).
if !config.censorship.tls_domains.is_empty() {
let mut all = Vec::with_capacity(1 + config.censorship.tls_domains.len());
@@ -1134,6 +1176,10 @@ mod load_security_tests;
#[path = "tests/load_mask_shape_security_tests.rs"]
mod load_mask_shape_security_tests;
#[cfg(test)]
#[path = "tests/load_mask_classifier_prefetch_timeout_security_tests.rs"]
mod load_mask_classifier_prefetch_timeout_security_tests;
#[cfg(test)]
mod tests {
use super::*;
@@ -1239,6 +1285,11 @@ mod tests {
assert_eq!(cfg.general.update_every, default_update_every());
assert_eq!(cfg.server.listen_addr_ipv4, default_listen_addr_ipv4());
assert_eq!(cfg.server.listen_addr_ipv6, default_listen_addr_ipv6_opt());
assert_eq!(
cfg.server.proxy_protocol_trusted_cidrs,
default_proxy_protocol_trusted_cidrs()
);
assert_eq!(cfg.censorship.unknown_sni_action, UnknownSniAction::Drop);
assert_eq!(cfg.server.api.listen, default_api_listen());
assert_eq!(cfg.server.api.whitelist, default_api_whitelist());
assert_eq!(
@@ -1371,6 +1422,14 @@ mod tests {
let server = ServerConfig::default();
assert_eq!(server.listen_addr_ipv6, Some(default_listen_addr_ipv6()));
assert_eq!(
server.proxy_protocol_trusted_cidrs,
default_proxy_protocol_trusted_cidrs()
);
assert_eq!(
AntiCensorshipConfig::default().unknown_sni_action,
UnknownSniAction::Drop
);
assert_eq!(server.api.listen, default_api_listen());
assert_eq!(server.api.whitelist, default_api_whitelist());
assert_eq!(
@@ -1406,6 +1465,75 @@ mod tests {
assert_eq!(access.users, default_access_users());
}
#[test]
fn proxy_protocol_trusted_cidrs_missing_uses_trust_all_but_explicit_empty_stays_empty() {
let cfg_missing: ProxyConfig = toml::from_str(
r#"
[server]
[general]
[network]
[access]
"#,
)
.unwrap();
assert_eq!(
cfg_missing.server.proxy_protocol_trusted_cidrs,
default_proxy_protocol_trusted_cidrs()
);
let cfg_explicit_empty: ProxyConfig = toml::from_str(
r#"
[server]
proxy_protocol_trusted_cidrs = []
[general]
[network]
[access]
"#,
)
.unwrap();
assert!(
cfg_explicit_empty
.server
.proxy_protocol_trusted_cidrs
.is_empty()
);
}
#[test]
fn unknown_sni_action_parses_and_defaults_to_drop() {
let cfg_default: ProxyConfig = toml::from_str(
r#"
[server]
[general]
[network]
[access]
[censorship]
"#,
)
.unwrap();
assert_eq!(
cfg_default.censorship.unknown_sni_action,
UnknownSniAction::Drop
);
let cfg_mask: ProxyConfig = toml::from_str(
r#"
[server]
[general]
[network]
[access]
[censorship]
unknown_sni_action = "mask"
"#,
)
.unwrap();
assert_eq!(
cfg_mask.censorship.unknown_sni_action,
UnknownSniAction::Mask
);
}
#[test]
fn dc_overrides_allow_string_and_array() {
let toml = r#"
@@ -2353,6 +2481,94 @@ mod tests {
let _ = std::fs::remove_file(path);
}
#[test]
fn tls_fetch_defaults_are_applied() {
let toml = r#"
[censorship]
tls_domain = "example.com"
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_tls_fetch_defaults_test.toml");
std::fs::write(&path, toml).unwrap();
let cfg = ProxyConfig::load(&path).unwrap();
assert_eq!(
cfg.censorship.tls_fetch.profiles,
TlsFetchConfig::default().profiles
);
assert!(cfg.censorship.tls_fetch.strict_route);
assert_eq!(cfg.censorship.tls_fetch.attempt_timeout_ms, 5_000);
assert_eq!(cfg.censorship.tls_fetch.total_budget_ms, 15_000);
assert_eq!(cfg.censorship.tls_fetch.profile_cache_ttl_secs, 600);
let _ = std::fs::remove_file(path);
}
#[test]
fn tls_fetch_profiles_are_deduplicated_preserving_order() {
let toml = r#"
[censorship]
tls_domain = "example.com"
[censorship.tls_fetch]
profiles = ["compat_tls12", "modern_chrome_like", "compat_tls12", "legacy_minimal"]
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_tls_fetch_profiles_dedup_test.toml");
std::fs::write(&path, toml).unwrap();
let cfg = ProxyConfig::load(&path).unwrap();
assert_eq!(
cfg.censorship.tls_fetch.profiles,
vec![
TlsFetchProfile::CompatTls12,
TlsFetchProfile::ModernChromeLike,
TlsFetchProfile::LegacyMinimal
]
);
let _ = std::fs::remove_file(path);
}
#[test]
fn tls_fetch_attempt_timeout_zero_is_rejected() {
let toml = r#"
[censorship]
tls_domain = "example.com"
[censorship.tls_fetch]
attempt_timeout_ms = 0
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_tls_fetch_attempt_timeout_zero_test.toml");
std::fs::write(&path, toml).unwrap();
let err = ProxyConfig::load(&path).unwrap_err().to_string();
assert!(err.contains("censorship.tls_fetch.attempt_timeout_ms must be > 0"));
let _ = std::fs::remove_file(path);
}
#[test]
fn tls_fetch_total_budget_zero_is_rejected() {
let toml = r#"
[censorship]
tls_domain = "example.com"
[censorship.tls_fetch]
total_budget_ms = 0
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_tls_fetch_total_budget_zero_test.toml");
std::fs::write(&path, toml).unwrap();
let err = ProxyConfig::load(&path).unwrap_err().to_string();
assert!(err.contains("censorship.tls_fetch.total_budget_ms must be > 0"));
let _ = std::fs::remove_file(path);
}
#[test]
fn invalid_ad_tag_is_disabled_during_load() {
let toml = r#"
@@ -0,0 +1,76 @@
use super::*;
use std::fs;
use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};
fn write_temp_config(contents: &str) -> PathBuf {
let nonce = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time must be after unix epoch")
.as_nanos();
let path = std::env::temp_dir().join(format!(
"telemt-load-mask-prefetch-timeout-security-{nonce}.toml"
));
fs::write(&path, contents).expect("temp config write must succeed");
path
}
fn remove_temp_config(path: &PathBuf) {
let _ = fs::remove_file(path);
}
#[test]
fn load_rejects_mask_classifier_prefetch_timeout_below_min_bound() {
let path = write_temp_config(
r#"
[censorship]
mask_classifier_prefetch_timeout_ms = 4
"#,
);
let err = ProxyConfig::load(&path)
.expect_err("prefetch timeout below minimum security bound must be rejected");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"),
"error must explain timeout bound invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_rejects_mask_classifier_prefetch_timeout_above_max_bound() {
let path = write_temp_config(
r#"
[censorship]
mask_classifier_prefetch_timeout_ms = 51
"#,
);
let err = ProxyConfig::load(&path)
.expect_err("prefetch timeout above max security bound must be rejected");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"),
"error must explain timeout bound invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_accepts_mask_classifier_prefetch_timeout_within_bounds() {
let path = write_temp_config(
r#"
[censorship]
mask_classifier_prefetch_timeout_ms = 20
"#,
);
let cfg =
ProxyConfig::load(&path).expect("prefetch timeout within security bounds must be accepted");
assert_eq!(cfg.censorship.mask_classifier_prefetch_timeout_ms, 20);
remove_temp_config(&path);
}
@@ -236,3 +236,57 @@ mask_shape_above_cap_blur_max_bytes = 8
remove_temp_config(&path);
}
#[test]
fn load_rejects_zero_mask_relay_max_bytes() {
let path = write_temp_config(
r#"
[censorship]
mask_relay_max_bytes = 0
"#,
);
let err = ProxyConfig::load(&path).expect_err("mask_relay_max_bytes must be > 0");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_relay_max_bytes must be > 0"),
"error must explain non-zero relay cap invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_rejects_mask_relay_max_bytes_above_upper_bound() {
let path = write_temp_config(
r#"
[censorship]
mask_relay_max_bytes = 67108865
"#,
);
let err =
ProxyConfig::load(&path).expect_err("mask_relay_max_bytes above hard cap must be rejected");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_relay_max_bytes must be <= 67108864"),
"error must explain relay cap upper bound invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_accepts_valid_mask_relay_max_bytes() {
let path = write_temp_config(
r#"
[censorship]
mask_relay_max_bytes = 8388608
"#,
);
let cfg = ProxyConfig::load(&path).expect("valid mask_relay_max_bytes must be accepted");
assert_eq!(cfg.censorship.mask_relay_max_bytes, 8_388_608);
remove_temp_config(&path);
}
+111 -5
View File
@@ -954,7 +954,8 @@ impl Default for GeneralConfig {
me_d2c_flush_batch_max_delay_us: default_me_d2c_flush_batch_max_delay_us(),
me_d2c_ack_flush_immediate: default_me_d2c_ack_flush_immediate(),
me_quota_soft_overshoot_bytes: default_me_quota_soft_overshoot_bytes(),
me_d2c_frame_buf_shrink_threshold_bytes: default_me_d2c_frame_buf_shrink_threshold_bytes(),
me_d2c_frame_buf_shrink_threshold_bytes:
default_me_d2c_frame_buf_shrink_threshold_bytes(),
direct_relay_copy_buf_c2s_bytes: default_direct_relay_copy_buf_c2s_bytes(),
direct_relay_copy_buf_s2c_bytes: default_direct_relay_copy_buf_s2c_bytes(),
me_warmup_stagger_enabled: default_true(),
@@ -1239,9 +1240,10 @@ pub struct ServerConfig {
/// Trusted source CIDRs allowed to send incoming PROXY protocol headers.
///
/// When non-empty, connections from addresses outside this allowlist are
/// rejected before `src_addr` is applied.
#[serde(default)]
/// If this field is omitted in config, it defaults to trust-all CIDRs
/// (`0.0.0.0/0` and `::/0`). If it is explicitly set to an empty list,
/// all PROXY protocol headers are rejected.
#[serde(default = "default_proxy_protocol_trusted_cidrs")]
pub proxy_protocol_trusted_cidrs: Vec<IpNetwork>,
/// Port for the Prometheus-compatible metrics endpoint.
@@ -1286,7 +1288,7 @@ impl Default for ServerConfig {
listen_tcp: None,
proxy_protocol: false,
proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(),
proxy_protocol_trusted_cidrs: Vec::new(),
proxy_protocol_trusted_cidrs: default_proxy_protocol_trusted_cidrs(),
metrics_port: None,
metrics_listen: None,
metrics_whitelist: default_metrics_whitelist(),
@@ -1357,6 +1359,90 @@ impl Default for TimeoutsConfig {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum UnknownSniAction {
#[default]
Drop,
Mask,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TlsFetchProfile {
ModernChromeLike,
ModernFirefoxLike,
CompatTls12,
LegacyMinimal,
}
impl TlsFetchProfile {
pub fn as_str(self) -> &'static str {
match self {
TlsFetchProfile::ModernChromeLike => "modern_chrome_like",
TlsFetchProfile::ModernFirefoxLike => "modern_firefox_like",
TlsFetchProfile::CompatTls12 => "compat_tls12",
TlsFetchProfile::LegacyMinimal => "legacy_minimal",
}
}
}
fn default_tls_fetch_profiles() -> Vec<TlsFetchProfile> {
vec![
TlsFetchProfile::ModernChromeLike,
TlsFetchProfile::ModernFirefoxLike,
TlsFetchProfile::CompatTls12,
TlsFetchProfile::LegacyMinimal,
]
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsFetchConfig {
/// Ordered list of ClientHello profiles used for adaptive fallback.
#[serde(default = "default_tls_fetch_profiles")]
pub profiles: Vec<TlsFetchProfile>,
/// When true and upstream route is configured, TLS fetch fails closed on
/// upstream connect errors and does not fallback to direct TCP.
#[serde(default = "default_tls_fetch_strict_route")]
pub strict_route: bool,
/// Timeout per one profile attempt in milliseconds.
#[serde(default = "default_tls_fetch_attempt_timeout_ms")]
pub attempt_timeout_ms: u64,
/// Total wall-clock budget in milliseconds across all profile attempts.
#[serde(default = "default_tls_fetch_total_budget_ms")]
pub total_budget_ms: u64,
/// Adds GREASE-style values into selected ClientHello extensions.
#[serde(default)]
pub grease_enabled: bool,
/// Produces deterministic ClientHello randomness for debugging/tests.
#[serde(default)]
pub deterministic: bool,
/// TTL for winner-profile cache entries in seconds.
/// Set to 0 to disable profile cache.
#[serde(default = "default_tls_fetch_profile_cache_ttl_secs")]
pub profile_cache_ttl_secs: u64,
}
impl Default for TlsFetchConfig {
fn default() -> Self {
Self {
profiles: default_tls_fetch_profiles(),
strict_route: default_tls_fetch_strict_route(),
attempt_timeout_ms: default_tls_fetch_attempt_timeout_ms(),
total_budget_ms: default_tls_fetch_total_budget_ms(),
grease_enabled: false,
deterministic: false,
profile_cache_ttl_secs: default_tls_fetch_profile_cache_ttl_secs(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AntiCensorshipConfig {
#[serde(default = "default_tls_domain")]
@@ -1366,11 +1452,19 @@ pub struct AntiCensorshipConfig {
#[serde(default)]
pub tls_domains: Vec<String>,
/// Policy for TLS ClientHello with unknown (non-configured) SNI.
#[serde(default)]
pub unknown_sni_action: UnknownSniAction,
/// Upstream scope used for TLS front metadata fetches.
/// Empty value keeps default upstream routing behavior.
#[serde(default = "default_tls_fetch_scope")]
pub tls_fetch_scope: String,
/// Fetch strategy for TLS front metadata bootstrap and periodic refresh.
#[serde(default)]
pub tls_fetch: TlsFetchConfig,
#[serde(default = "default_true")]
pub mask: bool,
@@ -1450,6 +1544,14 @@ pub struct AntiCensorshipConfig {
#[serde(default = "default_mask_shape_above_cap_blur_max_bytes")]
pub mask_shape_above_cap_blur_max_bytes: usize,
/// Maximum bytes relayed per direction on unauthenticated masking fallback paths.
#[serde(default = "default_mask_relay_max_bytes")]
pub mask_relay_max_bytes: usize,
/// Prefetch timeout (ms) for extending fragmented masking classifier window.
#[serde(default = "default_mask_classifier_prefetch_timeout_ms")]
pub mask_classifier_prefetch_timeout_ms: u64,
/// Enable outcome-time normalization envelope for masking fallback.
#[serde(default = "default_mask_timing_normalization_enabled")]
pub mask_timing_normalization_enabled: bool,
@@ -1468,7 +1570,9 @@ impl Default for AntiCensorshipConfig {
Self {
tls_domain: default_tls_domain(),
tls_domains: Vec::new(),
unknown_sni_action: UnknownSniAction::Drop,
tls_fetch_scope: default_tls_fetch_scope(),
tls_fetch: TlsFetchConfig::default(),
mask: default_true(),
mask_host: None,
mask_port: default_mask_port(),
@@ -1488,6 +1592,8 @@ impl Default for AntiCensorshipConfig {
mask_shape_bucket_cap_bytes: default_mask_shape_bucket_cap_bytes(),
mask_shape_above_cap_blur: default_mask_shape_above_cap_blur(),
mask_shape_above_cap_blur_max_bytes: default_mask_shape_above_cap_blur_max_bytes(),
mask_relay_max_bytes: default_mask_relay_max_bytes(),
mask_classifier_prefetch_timeout_ms: default_mask_classifier_prefetch_timeout_ms(),
mask_timing_normalization_enabled: default_mask_timing_normalization_enabled(),
mask_timing_normalization_floor_ms: default_mask_timing_normalization_floor_ms(),
mask_timing_normalization_ceiling_ms: default_mask_timing_normalization_ceiling_ms(),
+3
View File
@@ -216,6 +216,9 @@ pub enum ProxyError {
#[error("Invalid proxy protocol header")]
InvalidProxyProtocol,
#[error("Unknown TLS SNI")]
UnknownTlsSni,
#[error("Proxy error: {0}")]
Proxy(String),
+5 -2
View File
@@ -8,8 +8,10 @@ use tracing::{debug, error, info, warn};
use crate::cli;
use crate::config::ProxyConfig;
use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::{
ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache,
ProxyConfigData, fetch_proxy_config_with_raw_via_upstream, load_proxy_config_cache,
save_proxy_config_cache,
};
pub(crate) fn resolve_runtime_config_path(
@@ -288,9 +290,10 @@ pub(crate) async fn load_startup_proxy_config_snapshot(
cache_path: Option<&str>,
me2dc_fallback: bool,
label: &'static str,
upstream: Option<std::sync::Arc<UpstreamManager>>,
) -> Option<ProxyConfigData> {
loop {
match fetch_proxy_config_with_raw(url).await {
match fetch_proxy_config_with_raw_via_upstream(url, upstream.clone()).await {
Ok((cfg, raw)) => {
if !cfg.map.is_empty() {
if let Some(path) = cache_path
+4 -1
View File
@@ -63,9 +63,10 @@ pub(crate) async fn initialize_me_pool(
let proxy_secret_path = config.general.proxy_secret_path.as_deref();
let pool_size = config.general.middle_proxy_pool_size.max(1);
let proxy_secret = loop {
match crate::transport::middle_proxy::fetch_proxy_secret(
match crate::transport::middle_proxy::fetch_proxy_secret_with_upstream(
proxy_secret_path,
config.general.proxy_secret_len_max,
Some(upstream_manager.clone()),
)
.await
{
@@ -129,6 +130,7 @@ pub(crate) async fn initialize_me_pool(
config.general.proxy_config_v4_cache_path.as_deref(),
me2dc_fallback,
"getProxyConfig",
Some(upstream_manager.clone()),
)
.await;
if cfg_v4.is_some() {
@@ -160,6 +162,7 @@ pub(crate) async fn initialize_me_pool(
config.general.proxy_config_v6_cache_path.as_deref(),
me2dc_fallback,
"getProxyConfigV6",
Some(upstream_manager.clone()),
)
.await;
if cfg_v6.is_some() {
+20 -5
View File
@@ -7,6 +7,7 @@ use tracing::warn;
use crate::config::ProxyConfig;
use crate::startup::{COMPONENT_TLS_FRONT_BOOTSTRAP, StartupTracker};
use crate::tls_front::TlsFrontCache;
use crate::tls_front::fetcher::TlsFetchStrategy;
use crate::transport::UpstreamManager;
pub(crate) async fn bootstrap_tls_front(
@@ -40,7 +41,17 @@ pub(crate) async fn bootstrap_tls_front(
let mask_unix_sock = config.censorship.mask_unix_sock.clone();
let tls_fetch_scope = (!config.censorship.tls_fetch_scope.is_empty())
.then(|| config.censorship.tls_fetch_scope.clone());
let fetch_timeout = Duration::from_secs(5);
let tls_fetch = config.censorship.tls_fetch.clone();
let fetch_strategy = TlsFetchStrategy {
profiles: tls_fetch.profiles,
strict_route: tls_fetch.strict_route,
attempt_timeout: Duration::from_millis(tls_fetch.attempt_timeout_ms.max(1)),
total_budget: Duration::from_millis(tls_fetch.total_budget_ms.max(1)),
grease_enabled: tls_fetch.grease_enabled,
deterministic: tls_fetch.deterministic,
profile_cache_ttl: Duration::from_secs(tls_fetch.profile_cache_ttl_secs),
};
let fetch_timeout = fetch_strategy.total_budget;
let cache_initial = cache.clone();
let domains_initial = tls_domains.to_vec();
@@ -48,6 +59,7 @@ pub(crate) async fn bootstrap_tls_front(
let unix_sock_initial = mask_unix_sock.clone();
let scope_initial = tls_fetch_scope.clone();
let upstream_initial = upstream_manager.clone();
let strategy_initial = fetch_strategy.clone();
tokio::spawn(async move {
let mut join = tokio::task::JoinSet::new();
for domain in domains_initial {
@@ -56,12 +68,13 @@ pub(crate) async fn bootstrap_tls_front(
let unix_sock_domain = unix_sock_initial.clone();
let scope_domain = scope_initial.clone();
let upstream_domain = upstream_initial.clone();
let strategy_domain = strategy_initial.clone();
join.spawn(async move {
match crate::tls_front::fetcher::fetch_real_tls(
match crate::tls_front::fetcher::fetch_real_tls_with_strategy(
&host_domain,
port,
&domain,
fetch_timeout,
&strategy_domain,
Some(upstream_domain),
scope_domain.as_deref(),
proxy_protocol,
@@ -107,6 +120,7 @@ pub(crate) async fn bootstrap_tls_front(
let unix_sock_refresh = mask_unix_sock.clone();
let scope_refresh = tls_fetch_scope.clone();
let upstream_refresh = upstream_manager.clone();
let strategy_refresh = fetch_strategy.clone();
tokio::spawn(async move {
loop {
let base_secs = rand::rng().random_range(4 * 3600..=6 * 3600);
@@ -120,12 +134,13 @@ pub(crate) async fn bootstrap_tls_front(
let unix_sock_domain = unix_sock_refresh.clone();
let scope_domain = scope_refresh.clone();
let upstream_domain = upstream_refresh.clone();
let strategy_domain = strategy_refresh.clone();
join.spawn(async move {
match crate::tls_front::fetcher::fetch_real_tls(
match crate::tls_front::fetcher::fetch_real_tls_with_strategy(
&host_domain,
port,
&domain,
fetch_timeout,
&strategy_domain,
Some(upstream_domain),
scope_domain.as_deref(),
proxy_protocol,
+4 -3
View File
@@ -7,12 +7,12 @@ mod crypto;
mod error;
mod ip_tracker;
#[cfg(test)]
#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"]
mod ip_tracker_hotpath_adversarial_tests;
#[cfg(test)]
#[path = "tests/ip_tracker_encapsulation_adversarial_tests.rs"]
mod ip_tracker_encapsulation_adversarial_tests;
#[cfg(test)]
#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"]
mod ip_tracker_hotpath_adversarial_tests;
#[cfg(test)]
#[path = "tests/ip_tracker_regression_tests.rs"]
mod ip_tracker_regression_tests;
mod maestro;
@@ -29,5 +29,6 @@ mod util;
#[tokio::main]
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let _ = rustls::crypto::ring::default_provider().install_default();
maestro::run().await
}
+1 -4
View File
@@ -1233,10 +1233,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out,
"# HELP telemt_me_d2c_batch_bytes_bucket_total DC->Client batch byte size buckets"
);
let _ = writeln!(
out,
"# TYPE telemt_me_d2c_batch_bytes_bucket_total counter"
);
let _ = writeln!(out, "# TYPE telemt_me_d2c_batch_bytes_bucket_total counter");
let _ = writeln!(
out,
"telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"0_1k\"}} {}",
+125 -5
View File
@@ -186,6 +186,72 @@ fn handshake_timeout_with_mask_grace(config: &ProxyConfig) -> Duration {
}
}
const MASK_CLASSIFIER_PREFETCH_WINDOW: usize = 16;
#[cfg(test)]
const MASK_CLASSIFIER_PREFETCH_TIMEOUT: Duration = Duration::from_millis(5);
fn mask_classifier_prefetch_timeout(config: &ProxyConfig) -> Duration {
Duration::from_millis(config.censorship.mask_classifier_prefetch_timeout_ms)
}
fn should_prefetch_mask_classifier_window(initial_data: &[u8]) -> bool {
if initial_data.len() >= MASK_CLASSIFIER_PREFETCH_WINDOW {
return false;
}
if initial_data.is_empty() {
// Empty initial_data means there is no client probe prefix to refine.
// Prefetching in this case can consume fallback relay payload bytes and
// accidentally route them through shaping heuristics.
return false;
}
if initial_data[0] == 0x16 || initial_data.starts_with(b"SSH-") {
return false;
}
initial_data
.iter()
.all(|b| b.is_ascii_alphabetic() || *b == b' ')
}
#[cfg(test)]
async fn extend_masking_initial_window<R>(reader: &mut R, initial_data: &mut Vec<u8>)
where
R: AsyncRead + Unpin,
{
extend_masking_initial_window_with_timeout(
reader,
initial_data,
MASK_CLASSIFIER_PREFETCH_TIMEOUT,
)
.await;
}
async fn extend_masking_initial_window_with_timeout<R>(
reader: &mut R,
initial_data: &mut Vec<u8>,
prefetch_timeout: Duration,
) where
R: AsyncRead + Unpin,
{
if !should_prefetch_mask_classifier_window(initial_data) {
return;
}
let need = MASK_CLASSIFIER_PREFETCH_WINDOW.saturating_sub(initial_data.len());
if need == 0 {
return;
}
let mut extra = [0u8; MASK_CLASSIFIER_PREFETCH_WINDOW];
if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await
&& n > 0
{
initial_data.extend_from_slice(&extra[..n]);
}
}
fn masking_outcome<R, W>(
reader: R,
writer: W,
@@ -200,6 +266,15 @@ where
W: AsyncWrite + Unpin + Send + 'static,
{
HandshakeOutcome::NeedsMasking(Box::pin(async move {
let mut reader = reader;
let mut initial_data = initial_data;
extend_masking_initial_window_with_timeout(
&mut reader,
&mut initial_data,
mask_classifier_prefetch_timeout(&config),
)
.await;
handle_bad_client(
reader,
writer,
@@ -242,13 +317,20 @@ fn record_handshake_failure_class(
record_beobachten_class(beobachten, config, peer_ip, class);
}
#[inline]
fn increment_bad_on_unknown_tls_sni(stats: &Stats, error: &ProxyError) {
if matches!(error, ProxyError::UnknownTlsSni) {
stats.increment_connects_bad();
}
}
fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool {
if trusted.is_empty() {
static EMPTY_PROXY_TRUST_WARNED: OnceLock<AtomicBool> = OnceLock::new();
let warned = EMPTY_PROXY_TRUST_WARNED.get_or_init(|| AtomicBool::new(false));
if !warned.swap(true, Ordering::Relaxed) {
warn!(
"PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers by default"
"PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers"
);
}
return false;
@@ -433,7 +515,10 @@ where
beobachten.clone(),
));
}
HandshakeResult::Error(e) => return Err(e),
HandshakeResult::Error(e) => {
increment_bad_on_unknown_tls_sni(stats.as_ref(), &e);
return Err(e);
}
};
debug!(peer = %peer, "Reading MTProto handshake through TLS");
@@ -884,7 +969,10 @@ impl RunningClientHandler {
self.beobachten.clone(),
));
}
HandshakeResult::Error(e) => return Err(e),
HandshakeResult::Error(e) => {
increment_bad_on_unknown_tls_sni(stats.as_ref(), &e);
return Err(e);
}
};
debug!(peer = %peer, "Reading MTProto handshake through TLS");
@@ -1153,7 +1241,7 @@ impl RunningClientHandler {
}
if let Some(quota) = config.access.user_data_quota.get(user)
&& stats.get_user_total_octets(user) >= *quota
&& stats.get_user_quota_used(user) >= *quota
{
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
@@ -1212,7 +1300,7 @@ impl RunningClientHandler {
}
if let Some(quota) = config.access.user_data_quota.get(user)
&& stats.get_user_total_octets(user) >= *quota
&& stats.get_user_quota_used(user) >= *quota
{
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
@@ -1321,6 +1409,38 @@ mod masking_shape_classifier_fuzz_redteam_expected_fail_tests;
#[path = "tests/client_masking_probe_evasion_blackhat_tests.rs"]
mod masking_probe_evasion_blackhat_tests;
#[cfg(test)]
#[path = "tests/client_masking_fragmented_classifier_security_tests.rs"]
mod masking_fragmented_classifier_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_replay_timing_security_tests.rs"]
mod masking_replay_timing_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_http2_fragmented_preface_security_tests.rs"]
mod masking_http2_fragmented_preface_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_invariant_security_tests.rs"]
mod masking_prefetch_invariant_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_timing_matrix_security_tests.rs"]
mod masking_prefetch_timing_matrix_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_config_runtime_security_tests.rs"]
mod masking_prefetch_config_runtime_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs"]
mod masking_prefetch_config_pipeline_integration_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_strict_boundary_security_tests.rs"]
mod masking_prefetch_strict_boundary_security_tests;
#[cfg(test)]
#[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"]
mod beobachten_ttl_bounds_security_tests;
+165 -89
View File
@@ -16,7 +16,7 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, trace, warn};
use zeroize::{Zeroize, Zeroizing};
use crate::config::ProxyConfig;
use crate::config::{ProxyConfig, UnknownSniAction};
use crate::crypto::{AesCtr, SecureRandom, sha256};
use crate::error::{HandshakeResult, ProxyError};
use crate::protocol::constants::*;
@@ -121,6 +121,19 @@ fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize {
hasher.finish() as usize
}
fn auth_probe_scan_start_offset(
peer_ip: IpAddr,
now: Instant,
state_len: usize,
scan_limit: usize,
) -> usize {
if state_len == 0 || scan_limit == 0 {
return 0;
}
auth_probe_eviction_offset(peer_ip, now) % state_len
}
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
@@ -269,34 +282,9 @@ fn auth_probe_record_failure_with_state(
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
let state_len = state.len();
let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT);
let start_offset = if state_len == 0 {
0
} else {
auth_probe_eviction_offset(peer_ip, now) % state_len
};
let mut scanned = 0usize;
for entry in state.iter().skip(start_offset) {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) => {}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(key);
}
scanned += 1;
if scanned >= scan_limit {
break;
}
}
if scanned < scan_limit {
for entry in state.iter().take(scan_limit - scanned) {
if state_len <= AUTH_PROBE_PRUNE_SCAN_LIMIT {
for entry in state.iter() {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
@@ -310,6 +298,46 @@ fn auth_probe_record_failure_with_state(
stale_keys.push(key);
}
}
} else {
let start_offset =
auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit);
let mut scanned = 0usize;
for entry in state.iter().skip(start_offset) {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) => {}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(key);
}
scanned += 1;
if scanned >= scan_limit {
break;
}
}
if scanned < scan_limit {
for entry in state.iter().take(scan_limit - scanned) {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail
&& last_seen >= current_seen) => {}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(key);
}
}
}
}
for stale_key in stale_keys {
@@ -501,6 +529,21 @@ fn decode_user_secrets(
secrets
}
#[inline]
fn find_matching_tls_domain<'a>(config: &'a ProxyConfig, sni: &str) -> Option<&'a str> {
if config.censorship.tls_domain.eq_ignore_ascii_case(sni) {
return Some(config.censorship.tls_domain.as_str());
}
for domain in &config.censorship.tls_domains {
if domain.eq_ignore_ascii_case(sni) {
return Some(domain.as_str());
}
}
None
}
async fn maybe_apply_server_hello_delay(config: &ProxyConfig) {
if config.censorship.server_hello_delay_max_ms == 0 {
return;
@@ -584,60 +627,12 @@ where
}
let client_sni = tls::extract_sni_from_client_hello(handshake);
let secrets = decode_user_secrets(config, client_sni.as_deref());
let validation = match tls::validate_tls_handshake_with_replay_window(
handshake,
&secrets,
config.access.ignore_time_skew,
config.access.replay_window_secs,
) {
Some(v) => v,
None => {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(
peer = %peer,
ignore_time_skew = config.access.ignore_time_skew,
"TLS handshake validation failed - no matching user or time skew"
);
return HandshakeResult::BadClient { reader, writer };
}
};
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s,
None => {
maybe_apply_server_hello_delay(config).await;
return HandshakeResult::BadClient { reader, writer };
}
};
let cached = if config.censorship.tls_emulation {
if let Some(cache) = tls_cache.as_ref() {
let selected_domain = if let Some(sni) = client_sni.as_ref() {
if cache.contains_domain(sni).await {
sni.clone()
} else {
config.censorship.tls_domain.clone()
}
} else {
config.censorship.tls_domain.clone()
};
let cached_entry = cache.get(&selected_domain).await;
let use_full_cert_payload = cache
.take_full_cert_budget_for_ip(
peer.ip(),
Duration::from_secs(config.censorship.tls_full_cert_ttl_secs),
)
.await;
Some((cached_entry, use_full_cert_payload))
} else {
None
}
} else {
None
};
let preferred_user_hint = client_sni
.as_deref()
.filter(|sni| config.access.users.contains_key(*sni));
let matched_tls_domain = client_sni
.as_deref()
.and_then(|sni| find_matching_tls_domain(config, sni));
let alpn_list = if config.censorship.alpn_enforce {
tls::extract_alpn_from_client_hello(handshake)
@@ -660,16 +655,81 @@ where
None
};
// Replay tracking is applied only after full policy validation (including
// ALPN checks) so rejected handshakes cannot poison replay state.
if client_sni.is_some() && matched_tls_domain.is_none() && preferred_user_hint.is_none() {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(
peer = %peer,
sni = ?client_sni,
action = ?config.censorship.unknown_sni_action,
"TLS handshake rejected by unknown SNI policy"
);
return match config.censorship.unknown_sni_action {
UnknownSniAction::Drop => HandshakeResult::Error(ProxyError::UnknownTlsSni),
UnknownSniAction::Mask => HandshakeResult::BadClient { reader, writer },
};
}
let secrets = decode_user_secrets(config, preferred_user_hint);
let validation = match tls::validate_tls_handshake_with_replay_window(
handshake,
&secrets,
config.access.ignore_time_skew,
config.access.replay_window_secs,
) {
Some(v) => v,
None => {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(
peer = %peer,
ignore_time_skew = config.access.ignore_time_skew,
"TLS handshake validation failed - no matching user or time skew"
);
return HandshakeResult::BadClient { reader, writer };
}
};
// Reject known replay digests before expensive cache/domain/ALPN policy work.
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_and_add_tls_digest(digest_half) {
if replay_checker.check_tls_digest(digest_half) {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer };
}
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s,
None => {
maybe_apply_server_hello_delay(config).await;
return HandshakeResult::BadClient { reader, writer };
}
};
let cached = if config.censorship.tls_emulation {
if let Some(cache) = tls_cache.as_ref() {
let selected_domain =
matched_tls_domain.unwrap_or(config.censorship.tls_domain.as_str());
let cached_entry = cache.get(selected_domain).await;
let use_full_cert_payload = cache
.take_full_cert_budget_for_ip(
peer.ip(),
Duration::from_secs(config.censorship.tls_full_cert_ttl_secs),
)
.await;
Some((cached_entry, use_full_cert_payload))
} else {
None
}
} else {
None
};
// Add replay digest only for policy-valid handshakes.
replay_checker.add_tls_digest(digest_half);
let response = if let Some((cached_entry, use_full_cert_payload)) = cached {
emulator::build_emulated_server_hello(
secret,
@@ -769,7 +829,7 @@ where
let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let dec_key = Zeroizing::new(sha256(&dec_key_input));
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
@@ -805,7 +865,7 @@ where
let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
enc_key_input.extend_from_slice(enc_prekey);
enc_key_input.extend_from_slice(&secret);
let enc_key = sha256(&enc_key_input);
let enc_key = Zeroizing::new(sha256(&enc_key_input));
let mut enc_iv_arr = [0u8; IV_LEN];
enc_iv_arr.copy_from_slice(enc_iv_bytes);
@@ -830,9 +890,9 @@ where
user: user.clone(),
dc_idx,
proto_tag,
dec_key,
dec_key: *dec_key,
dec_iv,
enc_key,
enc_key: *enc_key,
enc_iv,
peer,
is_tls,
@@ -979,6 +1039,18 @@ mod saturation_poison_security_tests;
#[path = "tests/handshake_auth_probe_hardening_adversarial_tests.rs"]
mod auth_probe_hardening_adversarial_tests;
#[cfg(test)]
#[path = "tests/handshake_auth_probe_scan_budget_security_tests.rs"]
mod auth_probe_scan_budget_security_tests;
#[cfg(test)]
#[path = "tests/handshake_auth_probe_scan_offset_stress_tests.rs"]
mod auth_probe_scan_offset_stress_tests;
#[cfg(test)]
#[path = "tests/handshake_auth_probe_eviction_bias_security_tests.rs"]
mod auth_probe_eviction_bias_security_tests;
#[cfg(test)]
#[path = "tests/handshake_advanced_clever_tests.rs"]
mod advanced_clever_tests;
@@ -995,6 +1067,10 @@ mod real_bug_stress_tests;
#[path = "tests/handshake_timing_manual_bench_tests.rs"]
mod timing_manual_bench_tests;
#[cfg(test)]
#[path = "tests/handshake_key_material_zeroization_security_tests.rs"]
mod handshake_key_material_zeroization_security_tests;
/// Compile-time guard: HandshakeSuccess holds cryptographic key material and
/// must never be Copy. A Copy impl would allow silent key duplication,
/// undermining the zeroize-on-drop guarantee.
+489 -34
View File
@@ -4,14 +4,23 @@ use crate::config::ProxyConfig;
use crate::network::dns_overrides::resolve_socket_addr;
use crate::stats::beobachten::BeobachtenStore;
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
use rand::{Rng, RngExt};
use std::net::SocketAddr;
#[cfg(unix)]
use nix::ifaddrs::getifaddrs;
use rand::rngs::StdRng;
use rand::{Rng, RngExt, SeedableRng};
use std::net::{IpAddr, SocketAddr};
use std::str;
use std::time::Duration;
#[cfg(test)]
use std::sync::atomic::{AtomicUsize, Ordering};
#[cfg(unix)]
use std::sync::{Mutex, OnceLock};
use std::time::{Duration, Instant as StdInstant};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
#[cfg(unix)]
use tokio::sync::Mutex as AsyncMutex;
use tokio::time::{Instant, timeout};
use tracing::debug;
@@ -30,28 +39,55 @@ const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5);
#[cfg(test)]
const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100);
const MASK_BUFFER_SIZE: usize = 8192;
#[cfg(unix)]
#[cfg(not(test))]
const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(300);
#[cfg(all(unix, test))]
const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(1);
struct CopyOutcome {
total: usize,
ended_by_eof: bool,
}
async fn copy_with_idle_timeout<R, W>(reader: &mut R, writer: &mut W) -> CopyOutcome
async fn copy_with_idle_timeout<R, W>(
reader: &mut R,
writer: &mut W,
byte_cap: usize,
shutdown_on_eof: bool,
) -> CopyOutcome
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
let mut buf = [0u8; MASK_BUFFER_SIZE];
let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]);
let mut total = 0usize;
let mut ended_by_eof = false;
if byte_cap == 0 {
return CopyOutcome {
total,
ended_by_eof,
};
}
loop {
let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await;
let remaining_budget = byte_cap.saturating_sub(total);
if remaining_budget == 0 {
break;
}
let read_len = remaining_budget.min(MASK_BUFFER_SIZE);
let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await;
let n = match read_res {
Ok(Ok(n)) => n,
Ok(Err(_)) | Err(_) => break,
};
if n == 0 {
ended_by_eof = true;
if shutdown_on_eof {
let _ = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.shutdown()).await;
}
break;
}
total = total.saturating_add(n);
@@ -68,6 +104,31 @@ where
}
}
fn is_http_probe(data: &[u8]) -> bool {
// RFC 7540 section 3.5: HTTP/2 client preface starts with "PRI ".
const HTTP_METHODS: [&[u8]; 10] = [
b"GET ", b"POST", b"HEAD", b"PUT ", b"DELETE", b"OPTIONS", b"CONNECT", b"TRACE", b"PATCH",
b"PRI ",
];
if data.is_empty() {
return false;
}
let window = &data[..data.len().min(16)];
for method in HTTP_METHODS {
if data.len() >= method.len() && window.starts_with(method) {
return true;
}
if (2..=3).contains(&window.len()) && method.starts_with(window) {
return true;
}
}
false
}
fn next_mask_shape_bucket(total: usize, floor: usize, cap: usize) -> usize {
if total == 0 || floor == 0 || cap < floor {
return total;
@@ -125,6 +186,11 @@ async fn maybe_write_shape_padding<W>(
let mut remaining = target_total - total_sent;
let mut pad_chunk = [0u8; 1024];
let deadline = Instant::now() + MASK_TIMEOUT;
// Use a Send RNG so relay futures remain spawn-safe under Tokio.
let mut rng = {
let mut seed_source = rand::rng();
StdRng::from_rng(&mut seed_source)
};
while remaining > 0 {
let now = Instant::now();
@@ -133,10 +199,7 @@ async fn maybe_write_shape_padding<W>(
}
let write_len = remaining.min(pad_chunk.len());
{
let mut rng = rand::rng();
rng.fill_bytes(&mut pad_chunk[..write_len]);
}
rng.fill_bytes(&mut pad_chunk[..write_len]);
let write_budget = deadline.saturating_duration_since(now);
match timeout(write_budget, mask_write.write_all(&pad_chunk[..write_len])).await {
Ok(Ok(())) => {}
@@ -167,11 +230,11 @@ where
}
}
async fn consume_client_data_with_timeout<R>(reader: R)
async fn consume_client_data_with_timeout_and_cap<R>(reader: R, byte_cap: usize)
where
R: AsyncRead + Unpin,
{
if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader))
if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, byte_cap))
.await
.is_err()
{
@@ -190,6 +253,13 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration {
if config.censorship.mask_timing_normalization_enabled {
let floor = config.censorship.mask_timing_normalization_floor_ms;
let ceiling = config.censorship.mask_timing_normalization_ceiling_ms;
if floor == 0 {
if ceiling == 0 {
return Duration::from_millis(0);
}
let mut rng = rand::rng();
return Duration::from_millis(rng.random_range(0..=ceiling));
}
if ceiling > floor {
let mut rng = rand::rng();
return Duration::from_millis(rng.random_range(floor..=ceiling));
@@ -219,14 +289,7 @@ async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) {
/// Detect client type based on initial data
fn detect_client_type(data: &[u8]) -> &'static str {
// Check for HTTP request
if data.len() > 4
&& (data.starts_with(b"GET ")
|| data.starts_with(b"POST")
|| data.starts_with(b"HEAD")
|| data.starts_with(b"PUT ")
|| data.starts_with(b"DELETE")
|| data.starts_with(b"OPTIONS"))
{
if is_http_probe(data) {
return "HTTP";
}
@@ -248,6 +311,247 @@ fn detect_client_type(data: &[u8]) -> &'static str {
"unknown"
}
fn parse_mask_host_ip_literal(host: &str) -> Option<IpAddr> {
if host.starts_with('[') && host.ends_with(']') {
return host[1..host.len() - 1].parse::<IpAddr>().ok();
}
host.parse::<IpAddr>().ok()
}
fn canonical_ip(ip: IpAddr) -> IpAddr {
match ip {
IpAddr::V6(v6) => v6
.to_ipv4_mapped()
.map(IpAddr::V4)
.unwrap_or(IpAddr::V6(v6)),
IpAddr::V4(v4) => IpAddr::V4(v4),
}
}
#[cfg(unix)]
fn collect_local_interface_ips() -> Vec<IpAddr> {
#[cfg(test)]
LOCAL_INTERFACE_ENUMERATIONS.fetch_add(1, Ordering::Relaxed);
let mut out = Vec::new();
if let Ok(addrs) = getifaddrs() {
for iface in addrs {
if let Some(address) = iface.address {
if let Some(v4) = address.as_sockaddr_in() {
out.push(canonical_ip(IpAddr::V4(v4.ip())));
} else if let Some(v6) = address.as_sockaddr_in6() {
out.push(canonical_ip(IpAddr::V6(v6.ip())));
}
}
}
}
out
}
fn choose_interface_snapshot(previous: &[IpAddr], refreshed: Vec<IpAddr>) -> Vec<IpAddr> {
if refreshed.is_empty() && !previous.is_empty() {
return previous.to_vec();
}
refreshed
}
#[cfg(unix)]
#[derive(Default)]
struct LocalInterfaceCache {
ips: Vec<IpAddr>,
refreshed_at: Option<StdInstant>,
}
#[cfg(unix)]
static LOCAL_INTERFACE_CACHE: OnceLock<Mutex<LocalInterfaceCache>> = OnceLock::new();
#[cfg(unix)]
static LOCAL_INTERFACE_REFRESH_LOCK: OnceLock<AsyncMutex<()>> = OnceLock::new();
#[cfg(all(unix, test))]
fn local_interface_ips() -> Vec<IpAddr> {
let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default()));
let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
let stale = guard
.refreshed_at
.is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
if stale {
let refreshed = collect_local_interface_ips();
guard.ips = choose_interface_snapshot(&guard.ips, refreshed);
guard.refreshed_at = Some(StdInstant::now());
}
guard.ips.clone()
}
#[cfg(unix)]
async fn local_interface_ips_async() -> Vec<IpAddr> {
let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default()));
{
let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
let stale = guard
.refreshed_at
.is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
if !stale {
return guard.ips.clone();
}
}
let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(()));
let _refresh_guard = refresh_lock.lock().await;
{
let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
let stale = guard
.refreshed_at
.is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
if !stale {
return guard.ips.clone();
}
}
let refreshed = tokio::task::spawn_blocking(collect_local_interface_ips)
.await
.unwrap_or_default();
let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
let stale = guard
.refreshed_at
.is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
if stale {
guard.ips = choose_interface_snapshot(&guard.ips, refreshed);
guard.refreshed_at = Some(StdInstant::now());
}
guard.ips.clone()
}
#[cfg(all(not(unix), test))]
fn local_interface_ips() -> Vec<IpAddr> {
Vec::new()
}
#[cfg(not(unix))]
async fn local_interface_ips_async() -> Vec<IpAddr> {
Vec::new()
}
#[cfg(test)]
static LOCAL_INTERFACE_ENUMERATIONS: AtomicUsize = AtomicUsize::new(0);
#[cfg(test)]
fn reset_local_interface_enumerations_for_tests() {
LOCAL_INTERFACE_ENUMERATIONS.store(0, Ordering::Relaxed);
#[cfg(unix)]
if let Some(cache) = LOCAL_INTERFACE_CACHE.get() {
let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
guard.ips.clear();
guard.refreshed_at = None;
}
}
#[cfg(test)]
fn local_interface_enumerations_for_tests() -> usize {
LOCAL_INTERFACE_ENUMERATIONS.load(Ordering::Relaxed)
}
fn is_mask_target_local_listener_with_interfaces(
mask_host: &str,
mask_port: u16,
local_addr: SocketAddr,
resolved_override: Option<SocketAddr>,
interface_ips: &[IpAddr],
) -> bool {
if mask_port != local_addr.port() {
return false;
}
let local_ip = canonical_ip(local_addr.ip());
let literal_mask_ip = parse_mask_host_ip_literal(mask_host).map(canonical_ip);
if let Some(addr) = resolved_override {
let resolved_ip = canonical_ip(addr.ip());
if resolved_ip == local_ip {
return true;
}
if local_ip.is_unspecified()
&& (resolved_ip.is_loopback()
|| resolved_ip.is_unspecified()
|| interface_ips.contains(&resolved_ip))
{
return true;
}
}
if let Some(mask_ip) = literal_mask_ip {
if mask_ip == local_ip {
return true;
}
if local_ip.is_unspecified()
&& (mask_ip.is_loopback()
|| mask_ip.is_unspecified()
|| interface_ips.contains(&mask_ip))
{
return true;
}
}
false
}
#[cfg(test)]
fn is_mask_target_local_listener(
mask_host: &str,
mask_port: u16,
local_addr: SocketAddr,
resolved_override: Option<SocketAddr>,
) -> bool {
if mask_port != local_addr.port() {
return false;
}
let interfaces = local_interface_ips();
is_mask_target_local_listener_with_interfaces(
mask_host,
mask_port,
local_addr,
resolved_override,
&interfaces,
)
}
async fn is_mask_target_local_listener_async(
mask_host: &str,
mask_port: u16,
local_addr: SocketAddr,
resolved_override: Option<SocketAddr>,
) -> bool {
if mask_port != local_addr.port() {
return false;
}
let interfaces = local_interface_ips_async().await;
is_mask_target_local_listener_with_interfaces(
mask_host,
mask_port,
local_addr,
resolved_override,
&interfaces,
)
}
fn masking_beobachten_ttl(config: &ProxyConfig) -> Duration {
let minutes = config.general.beobachten_minutes;
let clamped = minutes.clamp(1, 24 * 60);
Duration::from_secs(clamped.saturating_mul(60))
}
fn build_mask_proxy_header(
version: u8,
peer: SocketAddr,
@@ -290,13 +594,14 @@ pub async fn handle_bad_client<R, W>(
{
let client_type = detect_client_type(initial_data);
if config.general.beobachten {
let ttl = Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60));
let ttl = masking_beobachten_ttl(config);
beobachten.record(client_type, peer.ip(), ttl);
}
if !config.censorship.mask {
// Masking disabled, just consume data
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes)
.await;
return;
}
@@ -341,6 +646,7 @@ pub async fn handle_bad_client<R, W>(
config.censorship.mask_shape_above_cap_blur,
config.censorship.mask_shape_above_cap_blur_max_bytes,
config.censorship.mask_shape_hardening_aggressive_mode,
config.censorship.mask_relay_max_bytes,
),
)
.await
@@ -353,12 +659,20 @@ pub async fn handle_bad_client<R, W>(
Ok(Err(e)) => {
wait_mask_connect_budget_if_needed(connect_started, config).await;
debug!(error = %e, "Failed to connect to mask unix socket");
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(
reader,
config.censorship.mask_relay_max_bytes,
)
.await;
wait_mask_outcome_budget(outcome_started, config).await;
}
Err(_) => {
debug!("Timeout connecting to mask unix socket");
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(
reader,
config.censorship.mask_relay_max_bytes,
)
.await;
wait_mask_outcome_budget(outcome_started, config).await;
}
}
@@ -372,6 +686,29 @@ pub async fn handle_bad_client<R, W>(
.unwrap_or(&config.censorship.tls_domain);
let mask_port = config.censorship.mask_port;
// Fail closed when fallback points at our own listener endpoint.
// Self-referential masking can create recursive proxy loops under
// misconfiguration and leak distinguishable load spikes to adversaries.
let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port);
if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, resolved_mask_addr)
.await
{
let outcome_started = Instant::now();
debug!(
client_type = client_type,
host = %mask_host,
port = mask_port,
local = %local_addr,
"Mask target resolves to local listener; refusing self-referential masking fallback"
);
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes)
.await;
wait_mask_outcome_budget(outcome_started, config).await;
return;
}
let outcome_started = Instant::now();
debug!(
client_type = client_type,
host = %mask_host,
@@ -381,10 +718,9 @@ pub async fn handle_bad_client<R, W>(
);
// Apply runtime DNS override for mask target when configured.
let mask_addr = resolve_socket_addr(mask_host, mask_port)
let mask_addr = resolved_mask_addr
.map(|addr| addr.to_string())
.unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
let outcome_started = Instant::now();
let connect_started = Instant::now();
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
match connect_result {
@@ -413,6 +749,7 @@ pub async fn handle_bad_client<R, W>(
config.censorship.mask_shape_above_cap_blur,
config.censorship.mask_shape_above_cap_blur_max_bytes,
config.censorship.mask_shape_hardening_aggressive_mode,
config.censorship.mask_relay_max_bytes,
),
)
.await
@@ -425,12 +762,20 @@ pub async fn handle_bad_client<R, W>(
Ok(Err(e)) => {
wait_mask_connect_budget_if_needed(connect_started, config).await;
debug!(error = %e, "Failed to connect to mask host");
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(
reader,
config.censorship.mask_relay_max_bytes,
)
.await;
wait_mask_outcome_budget(outcome_started, config).await;
}
Err(_) => {
debug!("Timeout connecting to mask host");
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(
reader,
config.censorship.mask_relay_max_bytes,
)
.await;
wait_mask_outcome_budget(outcome_started, config).await;
}
}
@@ -449,6 +794,7 @@ async fn relay_to_mask<R, W, MR, MW>(
shape_above_cap_blur: bool,
shape_above_cap_blur_max_bytes: usize,
shape_hardening_aggressive_mode: bool,
mask_relay_max_bytes: usize,
) where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
@@ -464,8 +810,18 @@ async fn relay_to_mask<R, W, MR, MW>(
}
let (upstream_copy, downstream_copy) = tokio::join!(
async { copy_with_idle_timeout(&mut reader, &mut mask_write).await },
async { copy_with_idle_timeout(&mut mask_read, &mut writer).await }
async {
copy_with_idle_timeout(
&mut reader,
&mut mask_write,
mask_relay_max_bytes,
!shape_hardening_enabled,
)
.await
},
async {
copy_with_idle_timeout(&mut mask_read, &mut writer, mask_relay_max_bytes, true).await
}
);
let total_sent = initial_data.len().saturating_add(upstream_copy.total);
@@ -491,13 +847,36 @@ async fn relay_to_mask<R, W, MR, MW>(
let _ = writer.shutdown().await;
}
/// Just consume all data from client without responding
async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R) {
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
while let Ok(n) = reader.read(&mut buf).await {
/// Just consume all data from client without responding.
async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R, byte_cap: usize) {
if byte_cap == 0 {
return;
}
// Keep drain path fail-closed under slow-loris stalls.
let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]);
let mut total = 0usize;
loop {
let remaining_budget = byte_cap.saturating_sub(total);
if remaining_budget == 0 {
break;
}
let read_len = remaining_budget.min(MASK_BUFFER_SIZE);
let n = match timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await {
Ok(Ok(n)) => n,
Ok(Err(_)) | Err(_) => break,
};
if n == 0 {
break;
}
total = total.saturating_add(n);
if total >= byte_cap {
break;
}
}
}
@@ -521,6 +900,10 @@ mod masking_shape_above_cap_blur_security_tests;
#[path = "tests/masking_timing_normalization_security_tests.rs"]
mod masking_timing_normalization_security_tests;
#[cfg(test)]
#[path = "tests/masking_timing_budget_coupling_security_tests.rs"]
mod masking_timing_budget_coupling_security_tests;
#[cfg(test)]
#[path = "tests/masking_ab_envelope_blur_integration_security_tests.rs"]
mod masking_ab_envelope_blur_integration_security_tests;
@@ -548,3 +931,75 @@ mod masking_aggressive_mode_security_tests;
#[cfg(test)]
#[path = "tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs"]
mod masking_timing_sidechannel_redteam_expected_fail_tests;
#[cfg(test)]
#[path = "tests/masking_self_target_loop_security_tests.rs"]
mod masking_self_target_loop_security_tests;
#[cfg(test)]
#[path = "tests/masking_classification_completeness_security_tests.rs"]
mod masking_classification_completeness_security_tests;
#[cfg(test)]
#[path = "tests/masking_relay_guardrails_security_tests.rs"]
mod masking_relay_guardrails_security_tests;
#[cfg(test)]
#[path = "tests/masking_connect_failure_close_matrix_security_tests.rs"]
mod masking_connect_failure_close_matrix_security_tests;
#[cfg(test)]
#[path = "tests/masking_additional_hardening_security_tests.rs"]
mod masking_additional_hardening_security_tests;
#[cfg(test)]
#[path = "tests/masking_consume_idle_timeout_security_tests.rs"]
mod masking_consume_idle_timeout_security_tests;
#[cfg(test)]
#[path = "tests/masking_http2_probe_classification_security_tests.rs"]
mod masking_http2_probe_classification_security_tests;
#[cfg(test)]
#[path = "tests/masking_http_probe_boundary_security_tests.rs"]
mod masking_http_probe_boundary_security_tests;
#[cfg(test)]
#[path = "tests/masking_rng_hoist_perf_regression_tests.rs"]
mod masking_rng_hoist_perf_regression_tests;
#[cfg(test)]
#[path = "tests/masking_http2_preface_integration_security_tests.rs"]
mod masking_http2_preface_integration_security_tests;
#[cfg(test)]
#[path = "tests/masking_consume_stress_adversarial_tests.rs"]
mod masking_consume_stress_adversarial_tests;
#[cfg(test)]
#[path = "tests/masking_interface_cache_security_tests.rs"]
mod masking_interface_cache_security_tests;
#[cfg(test)]
#[path = "tests/masking_interface_cache_defense_in_depth_security_tests.rs"]
mod masking_interface_cache_defense_in_depth_security_tests;
#[cfg(test)]
#[path = "tests/masking_interface_cache_concurrency_security_tests.rs"]
mod masking_interface_cache_concurrency_security_tests;
#[cfg(test)]
#[path = "tests/masking_production_cap_regression_security_tests.rs"]
mod masking_production_cap_regression_security_tests;
#[cfg(test)]
#[path = "tests/masking_extended_attack_surface_security_tests.rs"]
mod masking_extended_attack_surface_security_tests;
#[cfg(test)]
#[path = "tests/masking_padding_timeout_adversarial_tests.rs"]
mod masking_padding_timeout_adversarial_tests;
#[cfg(all(test, feature = "redteam_offline_expected_fail"))]
#[path = "tests/masking_offline_target_redteam_expected_fail_tests.rs"]
mod masking_offline_target_redteam_expected_fail_tests;
+188 -165
View File
@@ -1,5 +1,7 @@
use std::collections::hash_map::RandomState;
use std::collections::{BTreeSet, HashMap};
#[cfg(test)]
use std::future::Future;
use std::hash::{BuildHasher, Hash};
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
@@ -8,7 +10,7 @@ use std::time::{Duration, Instant};
use dashmap::DashMap;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{Mutex as AsyncMutex, mpsc, oneshot, watch};
use tokio::sync::{mpsc, oneshot, watch};
use tokio::time::timeout;
use tracing::{debug, info, trace, warn};
@@ -21,7 +23,9 @@ use crate::proxy::route_mode::{
ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state,
cutover_stagger_delay,
};
use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, Stats};
use crate::stats::{
MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, QuotaReserveError, Stats, UserStats,
};
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
@@ -39,28 +43,23 @@ const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1);
const TINY_FRAME_DEBT_PER_TINY: u32 = 8;
const TINY_FRAME_DEBT_LIMIT: u32 = 512;
#[cfg(test)]
const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50);
#[cfg(not(test))]
const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5);
#[cfg(test)]
const RELAY_TEST_STEP_TIMEOUT: Duration = Duration::from_secs(1);
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2;
const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024;
#[cfg(test)]
const QUOTA_USER_LOCKS_MAX: usize = 64;
#[cfg(not(test))]
const QUOTA_USER_LOCKS_MAX: usize = 4_096;
#[cfg(test)]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
#[cfg(not(test))]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
const QUOTA_RESERVE_SPIN_RETRIES: usize = 32;
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new();
static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();
static DESYNC_DEDUP_EVER_SATURATED: OnceLock<AtomicBool> = OnceLock::new();
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<AsyncMutex<()>>>> = OnceLock::new();
static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<AsyncMutex<()>>>> = OnceLock::new();
static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock<Mutex<RelayIdleCandidateRegistry>> = OnceLock::new();
static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0);
@@ -94,10 +93,24 @@ fn relay_idle_candidate_registry() -> &'static Mutex<RelayIdleCandidateRegistry>
RELAY_IDLE_CANDIDATE_REGISTRY.get_or_init(|| Mutex::new(RelayIdleCandidateRegistry::default()))
}
fn relay_idle_candidate_registry_lock() -> std::sync::MutexGuard<'static, RelayIdleCandidateRegistry>
{
let registry = relay_idle_candidate_registry();
match registry.lock() {
Ok(guard) => guard,
Err(poisoned) => {
let mut guard = poisoned.into_inner();
// Fail closed after panic while holding registry lock: drop all
// candidates and pressure cursors to avoid stale cross-session state.
*guard = RelayIdleCandidateRegistry::default();
registry.clear_poison();
guard
}
}
}
fn mark_relay_idle_candidate(conn_id: u64) -> bool {
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
return false;
};
let mut guard = relay_idle_candidate_registry_lock();
if guard.by_conn_id.contains_key(&conn_id) {
return false;
@@ -116,9 +129,7 @@ fn mark_relay_idle_candidate(conn_id: u64) -> bool {
}
fn clear_relay_idle_candidate(conn_id: u64) {
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
return;
};
let mut guard = relay_idle_candidate_registry_lock();
if let Some(meta) = guard.by_conn_id.remove(&conn_id) {
guard.ordered.remove(&(meta.mark_order_seq, conn_id));
@@ -127,23 +138,17 @@ fn clear_relay_idle_candidate(conn_id: u64) {
#[cfg(test)]
fn oldest_relay_idle_candidate() -> Option<u64> {
let Ok(guard) = relay_idle_candidate_registry().lock() else {
return None;
};
let guard = relay_idle_candidate_registry_lock();
guard.ordered.iter().next().map(|(_, conn_id)| *conn_id)
}
fn note_relay_pressure_event() {
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
return;
};
let mut guard = relay_idle_candidate_registry_lock();
guard.pressure_event_seq = guard.pressure_event_seq.wrapping_add(1);
}
fn relay_pressure_event_seq() -> u64 {
let Ok(guard) = relay_idle_candidate_registry().lock() else {
return 0;
};
let guard = relay_idle_candidate_registry_lock();
guard.pressure_event_seq
}
@@ -152,9 +157,7 @@ fn maybe_evict_idle_candidate_on_pressure(
seen_pressure_seq: &mut u64,
stats: &Stats,
) -> bool {
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
return false;
};
let mut guard = relay_idle_candidate_registry_lock();
let latest_pressure_seq = guard.pressure_event_seq;
if latest_pressure_seq == *seen_pressure_seq {
@@ -199,13 +202,9 @@ fn maybe_evict_idle_candidate_on_pressure(
#[cfg(test)]
fn clear_relay_idle_pressure_state_for_testing() {
if let Some(registry) = RELAY_IDLE_CANDIDATE_REGISTRY.get()
&& let Ok(mut guard) = registry.lock()
{
guard.by_conn_id.clear();
guard.ordered.clear();
guard.pressure_event_seq = 0;
guard.pressure_consumed_seq = 0;
if RELAY_IDLE_CANDIDATE_REGISTRY.get().is_some() {
let mut guard = relay_idle_candidate_registry_lock();
*guard = RelayIdleCandidateRegistry::default();
}
RELAY_IDLE_MARK_SEQ.store(0, Ordering::Relaxed);
}
@@ -259,6 +258,7 @@ impl RelayClientIdlePolicy {
struct RelayClientIdleState {
last_client_frame_at: Instant,
soft_idle_marked: bool,
tiny_frame_debt: u32,
}
impl RelayClientIdleState {
@@ -266,6 +266,7 @@ impl RelayClientIdleState {
Self {
last_client_frame_at: now,
soft_idle_marked: false,
tiny_frame_debt: 0,
}
}
@@ -531,47 +532,28 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
}
fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option<u64>) -> bool {
quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota)
}
fn quota_would_be_exceeded_for_user(
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
bytes: u64,
) -> bool {
quota_limit.is_some_and(|quota| {
let used = stats.get_user_total_octets(user);
used >= quota || bytes > quota.saturating_sub(used)
})
}
fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 {
limit.saturating_add(overshoot)
}
fn quota_exceeded_for_user_soft(
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
overshoot: u64,
) -> bool {
quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota_soft_cap(quota, overshoot))
}
fn quota_would_be_exceeded_for_user_soft(
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
async fn reserve_user_quota_with_yield(
user_stats: &UserStats,
bytes: u64,
overshoot: u64,
) -> bool {
quota_limit.is_some_and(|quota| {
let cap = quota_soft_cap(quota, overshoot);
let used = stats.get_user_total_octets(user);
used >= cap || bytes > cap.saturating_sub(used)
})
limit: u64,
) -> std::result::Result<u64, QuotaReserveError> {
loop {
for _ in 0..QUOTA_RESERVE_SPIN_RETRIES {
match user_stats.quota_try_reserve(bytes, limit) {
Ok(total) => return Ok(total),
Err(QuotaReserveError::LimitExceeded) => {
return Err(QuotaReserveError::LimitExceeded);
}
Err(QuotaReserveError::Contended) => std::hint::spin_loop(),
}
}
tokio::task::yield_now().await;
}
}
fn classify_me_d2c_flush_reason(
@@ -618,53 +600,18 @@ fn observe_me_d2c_flush_event(
}
#[cfg(test)]
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
fn relay_idle_pressure_test_guard() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[cfg(test)]
fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
quota_user_lock_test_guard()
pub(crate) fn relay_idle_pressure_test_scope() -> std::sync::MutexGuard<'static, ()> {
relay_idle_pressure_test_guard()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn quota_overflow_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
(0..QUOTA_OVERFLOW_LOCK_STRIPES)
.map(|_| Arc::new(AsyncMutex::new(())))
.collect()
});
let hash = crc32fast::hash(user.as_bytes()) as usize;
Arc::clone(&stripes[hash % stripes.len()])
}
fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
if let Some(existing) = locks.get(user) {
return Arc::clone(existing.value());
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
locks.retain(|_, value| Arc::strong_count(value) > 1);
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
return quota_overflow_user_lock(user);
}
let created = Arc::new(AsyncMutex::new(()));
match locks.entry(user.to_string()) {
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
dashmap::mapref::entry::Entry::Vacant(entry) => {
entry.insert(Arc::clone(&created));
created
}
}
}
async fn enqueue_c2me_command(
tx: &mpsc::Sender<C2MeCommand>,
cmd: C2MeCommand,
@@ -690,6 +637,16 @@ async fn enqueue_c2me_command(
}
}
#[cfg(test)]
async fn run_relay_test_step_timeout<F, T>(context: &'static str, fut: F) -> T
where
F: Future<Output = T>,
{
timeout(RELAY_TEST_STEP_TIMEOUT, fut)
.await
.unwrap_or_else(|_| panic!("{context} exceeded {}s", RELAY_TEST_STEP_TIMEOUT.as_secs()))
}
pub(crate) async fn handle_via_middle_proxy<R, W>(
mut crypto_reader: CryptoReader<R>,
crypto_writer: CryptoWriter<W>,
@@ -710,6 +667,7 @@ where
{
let user = success.user.clone();
let quota_limit = config.access.user_data_quota.get(&user).copied();
let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user));
let peer = success.peer;
let proto_tag = success.proto_tag;
let pool_generation = me_pool.current_generation();
@@ -836,6 +794,7 @@ where
let stats_clone = stats.clone();
let rng_clone = rng.clone();
let user_clone = user.clone();
let quota_user_stats_me_writer = quota_user_stats.clone();
let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone();
let bytes_me2c_clone = bytes_me2c.clone();
let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config);
@@ -865,6 +824,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
bytes_me2c_clone.as_ref(),
@@ -923,6 +883,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
bytes_me2c_clone.as_ref(),
@@ -984,6 +945,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
bytes_me2c_clone.as_ref(),
@@ -1047,6 +1009,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
bytes_me2c_clone.as_ref(),
@@ -1218,16 +1181,23 @@ where
forensics.bytes_c2me = forensics
.bytes_c2me
.saturating_add(payload.len() as u64);
if let Some(limit) = quota_limit {
let quota_lock = quota_user_lock(&user);
let _quota_guard = quota_lock.lock().await;
stats.add_user_octets_from(&user, payload.len() as u64);
if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) {
if let (Some(limit), Some(user_stats)) =
(quota_limit, quota_user_stats.as_deref())
{
if reserve_user_quota_with_yield(
user_stats,
payload.len() as u64,
limit,
)
.await
.is_err()
{
main_result = Err(ProxyError::DataQuotaExceeded {
user: user.clone(),
});
break;
}
stats.add_user_octets_from_handle(user_stats, payload.len() as u64);
} else {
stats.add_user_octets_from(&user, payload.len() as u64);
}
@@ -1320,6 +1290,8 @@ async fn read_client_payload_with_idle_policy<R>(
where
R: AsyncRead + Unpin + Send + 'static,
{
const LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES: u32 = 4;
async fn read_exact_with_policy<R>(
client_reader: &mut CryptoReader<R>,
buf: &mut [u8],
@@ -1458,6 +1430,7 @@ where
Ok(())
}
let mut consecutive_zero_len_frames = 0u32;
loop {
let (len, quickack, raw_len_bytes) = match proto_tag {
ProtoTag::Abridged => {
@@ -1538,6 +1511,26 @@ where
};
if len == 0 {
idle_state.tiny_frame_debt = idle_state
.tiny_frame_debt
.saturating_add(TINY_FRAME_DEBT_PER_TINY);
if idle_state.tiny_frame_debt >= TINY_FRAME_DEBT_LIMIT {
stats.increment_relay_protocol_desync_close_total();
return Err(ProxyError::Proxy(format!(
"Tiny frame overhead limit exceeded: debt={}, conn_id={}",
idle_state.tiny_frame_debt, forensics.conn_id
)));
}
if !idle_policy.enabled {
consecutive_zero_len_frames = consecutive_zero_len_frames.saturating_add(1);
if consecutive_zero_len_frames > LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES {
stats.increment_relay_protocol_desync_close_total();
return Err(ProxyError::Proxy(
"Excessive zero-length abridged frames".to_string(),
));
}
}
continue;
}
if len < 4 && proto_tag != ProtoTag::Abridged {
@@ -1606,6 +1599,7 @@ where
}
*frame_counter += 1;
idle_state.on_client_frame(Instant::now());
idle_state.tiny_frame_debt = idle_state.tiny_frame_debt.saturating_sub(1);
clear_relay_idle_candidate(forensics.conn_id);
return Ok(Some((payload, quickack)));
}
@@ -1689,6 +1683,7 @@ async fn process_me_writer_response<W>(
frame_buf: &mut Vec<u8>,
stats: &Stats,
user: &str,
quota_user_stats: Option<&UserStats>,
quota_limit: Option<u64>,
quota_soft_overshoot_bytes: u64,
bytes_me2c: &AtomicU64,
@@ -1707,40 +1702,42 @@ where
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
}
let data_len = data.len() as u64;
if quota_would_be_exceeded_for_user_soft(
stats,
user,
quota_limit,
data_len,
quota_soft_overshoot_bytes,
) {
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
if let (Some(limit), Some(user_stats)) = (quota_limit, quota_user_stats) {
let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes);
if reserve_user_quota_with_yield(user_stats, data_len, soft_limit)
.await
.is_err()
{
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
}
let write_mode =
write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await?;
stats.increment_me_d2c_write_mode(write_mode);
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await
{
Ok(mode) => mode,
Err(err) => {
if quota_limit.is_some() {
stats.add_quota_write_fail_bytes_total(data_len);
stats.increment_quota_write_fail_events_total();
}
return Err(err);
}
};
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64);
stats.increment_me_d2c_data_frames_total();
stats.add_me_d2c_payload_bytes_total(data.len() as u64);
if quota_exceeded_for_user_soft(
stats,
user,
quota_limit,
quota_soft_overshoot_bytes,
) {
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PostWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
if let Some(user_stats) = quota_user_stats {
stats.add_user_octets_to_handle(user_stats, data_len);
} else {
stats.add_user_octets_to(user, data_len);
}
stats.increment_me_d2c_data_frames_total();
stats.add_me_d2c_payload_bytes_total(data_len);
stats.increment_me_d2c_write_mode(write_mode);
Ok(MeWriterResponseOutcome::Continue {
frames: 1,
@@ -1840,8 +1837,14 @@ where
MeD2cWriteMode::Coalesced
} else {
let header = [first];
client_writer.write_all(&header).await.map_err(ProxyError::Io)?;
client_writer.write_all(data).await.map_err(ProxyError::Io)?;
client_writer
.write_all(&header)
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
MeD2cWriteMode::Split
}
} else if len_words < (1 << 24) {
@@ -1863,8 +1866,14 @@ where
MeD2cWriteMode::Coalesced
} else {
let header = [first, lw[0], lw[1], lw[2]];
client_writer.write_all(&header).await.map_err(ProxyError::Io)?;
client_writer.write_all(data).await.map_err(ProxyError::Io)?;
client_writer
.write_all(&header)
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
MeD2cWriteMode::Split
}
} else {
@@ -1906,8 +1915,14 @@ where
MeD2cWriteMode::Coalesced
} else {
let header = len_val.to_le_bytes();
client_writer.write_all(&header).await.map_err(ProxyError::Io)?;
client_writer.write_all(data).await.map_err(ProxyError::Io)?;
client_writer
.write_all(&header)
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
if padding_len > 0 {
frame_buf.clear();
if frame_buf.capacity() < padding_len {
@@ -1947,10 +1962,6 @@ where
.map_err(ProxyError::Io)
}
#[cfg(test)]
#[path = "tests/middle_relay_security_tests.rs"]
mod security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_idle_policy_security_tests.rs"]
mod idle_policy_security_tests;
@@ -1963,18 +1974,30 @@ mod desync_all_full_dedup_security_tests;
#[path = "tests/middle_relay_stub_completion_security_tests.rs"]
mod stub_completion_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_coverage_high_risk_security_tests.rs"]
mod coverage_high_risk_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_quota_overflow_lock_security_tests.rs"]
mod quota_overflow_lock_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_length_cast_hardening_security_tests.rs"]
mod length_cast_hardening_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"]
mod blackhat_campaign_integration_tests;
#[path = "tests/middle_relay_idle_registry_poison_security_tests.rs"]
mod middle_relay_idle_registry_poison_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_zero_length_frame_security_tests.rs"]
mod middle_relay_zero_length_frame_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_tiny_frame_debt_security_tests.rs"]
mod middle_relay_tiny_frame_debt_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs"]
mod middle_relay_tiny_frame_debt_concurrency_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"]
mod middle_relay_tiny_frame_debt_proto_chunking_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_atomic_quota_invariant_tests.rs"]
mod middle_relay_atomic_quota_invariant_tests;
+50 -50
View File
@@ -4,58 +4,58 @@
#![cfg_attr(test, allow(warnings))]
#![cfg_attr(not(test), forbid(clippy::undocumented_unsafe_blocks))]
#![cfg_attr(
not(test),
deny(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::todo,
clippy::unimplemented,
clippy::correctness,
clippy::option_if_let_else,
clippy::or_fun_call,
clippy::branches_sharing_code,
clippy::single_option_map,
clippy::useless_let_if_seq,
clippy::redundant_locals,
clippy::cloned_ref_to_slice_refs,
unsafe_code,
clippy::await_holding_lock,
clippy::await_holding_refcell_ref,
clippy::debug_assert_with_mut_call,
clippy::macro_use_imports,
clippy::cast_ptr_alignment,
clippy::cast_lossless,
clippy::ptr_as_ptr,
clippy::large_stack_arrays,
clippy::same_functions_in_if_condition,
trivial_casts,
trivial_numeric_casts,
unused_extern_crates,
unused_import_braces,
rust_2018_idioms
)
not(test),
deny(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::todo,
clippy::unimplemented,
clippy::correctness,
clippy::option_if_let_else,
clippy::or_fun_call,
clippy::branches_sharing_code,
clippy::single_option_map,
clippy::useless_let_if_seq,
clippy::redundant_locals,
clippy::cloned_ref_to_slice_refs,
unsafe_code,
clippy::await_holding_lock,
clippy::await_holding_refcell_ref,
clippy::debug_assert_with_mut_call,
clippy::macro_use_imports,
clippy::cast_ptr_alignment,
clippy::cast_lossless,
clippy::ptr_as_ptr,
clippy::large_stack_arrays,
clippy::same_functions_in_if_condition,
trivial_casts,
trivial_numeric_casts,
unused_extern_crates,
unused_import_braces,
rust_2018_idioms
)
)]
#![cfg_attr(
not(test),
allow(
clippy::use_self,
clippy::redundant_closure,
clippy::too_many_arguments,
clippy::doc_markdown,
clippy::missing_const_for_fn,
clippy::unnecessary_operation,
clippy::redundant_pub_crate,
clippy::derive_partial_eq_without_eq,
clippy::type_complexity,
clippy::new_ret_no_self,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::significant_drop_tightening,
clippy::significant_drop_in_scrutinee,
clippy::float_cmp,
clippy::nursery
)
not(test),
allow(
clippy::use_self,
clippy::redundant_closure,
clippy::too_many_arguments,
clippy::doc_markdown,
clippy::missing_const_for_fn,
clippy::unnecessary_operation,
clippy::redundant_pub_crate,
clippy::derive_partial_eq_without_eq,
clippy::type_complexity,
clippy::new_ret_no_self,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::significant_drop_tightening,
clippy::significant_drop_in_scrutinee,
clippy::float_cmp,
clippy::nursery
)
)]
pub mod adaptive_buffers;
+98 -218
View File
@@ -52,13 +52,12 @@
//! - `SharedCounters` (atomics) let the watchdog read stats without locking
use crate::error::{ProxyError, Result};
use crate::stats::Stats;
use crate::stats::{Stats, UserStats};
use crate::stream::BufferPool;
use dashmap::DashMap;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes};
@@ -209,12 +208,10 @@ struct StatsIo<S> {
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
user_stats: Arc<UserStats>,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
quota_read_wake_scheduled: bool,
quota_write_wake_scheduled: bool,
quota_read_retry_active: Arc<AtomicBool>,
quota_write_retry_active: Arc<AtomicBool>,
quota_bytes_since_check: u64,
epoch: Instant,
}
@@ -230,30 +227,21 @@ impl<S> StatsIo<S> {
) -> Self {
// Mark initial activity so the watchdog doesn't fire before data flows
counters.touch(Instant::now(), epoch);
let user_stats = stats.get_or_create_user_stats_handle(&user);
Self {
inner,
counters,
stats,
user,
user_stats,
quota_limit,
quota_exceeded,
quota_read_wake_scheduled: false,
quota_write_wake_scheduled: false,
quota_read_retry_active: Arc::new(AtomicBool::new(false)),
quota_write_retry_active: Arc::new(AtomicBool::new(false)),
quota_bytes_since_check: 0,
epoch,
}
}
}
impl<S> Drop for StatsIo<S> {
fn drop(&mut self) {
self.quota_read_retry_active.store(false, Ordering::Relaxed);
self.quota_write_retry_active
.store(false, Ordering::Relaxed);
}
}
#[derive(Debug)]
struct QuotaIoSentinel;
@@ -277,84 +265,22 @@ fn is_quota_io_error(err: &io::Error) -> bool {
.is_some()
}
#[cfg(test)]
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1);
#[cfg(not(test))]
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2);
const QUOTA_NEAR_LIMIT_BYTES: u64 = 64 * 1024;
const QUOTA_LARGE_CHARGE_BYTES: u64 = 16 * 1024;
const QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES: u64 = 4 * 1024;
const QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES: u64 = 64 * 1024;
fn spawn_quota_retry_waker(retry_active: Arc<AtomicBool>, waker: std::task::Waker) {
tokio::task::spawn(async move {
loop {
if !retry_active.load(Ordering::Relaxed) {
break;
}
tokio::time::sleep(QUOTA_CONTENTION_RETRY_INTERVAL).await;
if !retry_active.load(Ordering::Relaxed) {
break;
}
waker.wake_by_ref();
}
});
#[inline]
fn quota_adaptive_interval_bytes(remaining_before: u64) -> u64 {
remaining_before.saturating_div(2).clamp(
QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES,
QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES,
)
}
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
#[cfg(test)]
const QUOTA_USER_LOCKS_MAX: usize = 64;
#[cfg(not(test))]
const QUOTA_USER_LOCKS_MAX: usize = 4_096;
#[cfg(test)]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
#[cfg(not(test))]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
#[cfg(test)]
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[cfg(test)]
fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
quota_user_lock_test_guard()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
(0..QUOTA_OVERFLOW_LOCK_STRIPES)
.map(|_| Arc::new(Mutex::new(())))
.collect()
});
let hash = crc32fast::hash(user.as_bytes()) as usize;
Arc::clone(&stripes[hash % stripes.len()])
}
fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
if let Some(existing) = locks.get(user) {
return Arc::clone(existing.value());
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
locks.retain(|_, value| Arc::strong_count(value) > 1);
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
return quota_overflow_user_lock(user);
}
let created = Arc::new(Mutex::new(()));
match locks.entry(user.to_string()) {
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
dashmap::mapref::entry::Entry::Vacant(entry) => {
entry.insert(Arc::clone(&created));
created
}
}
#[inline]
fn should_immediate_quota_check(remaining_before: u64, charge_bytes: u64) -> bool {
remaining_before <= QUOTA_NEAR_LIMIT_BYTES || charge_bytes >= QUOTA_LARGE_CHARGE_BYTES
}
impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
@@ -364,80 +290,60 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
if this.quota_exceeded.load(Ordering::Relaxed) {
if this.quota_exceeded.load(Ordering::Acquire) {
return Poll::Ready(Err(quota_io_error()));
}
let quota_lock = this
.quota_limit
.is_some()
.then(|| quota_user_lock(&this.user));
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => {
this.quota_read_wake_scheduled = false;
this.quota_read_retry_active.store(false, Ordering::Relaxed);
Some(guard)
}
Err(_) => {
if !this.quota_read_wake_scheduled {
this.quota_read_wake_scheduled = true;
this.quota_read_retry_active.store(true, Ordering::Relaxed);
spawn_quota_retry_waker(
Arc::clone(&this.quota_read_retry_active),
cx.waker().clone(),
);
}
return Poll::Pending;
}
let mut remaining_before = None;
if let Some(limit) = this.quota_limit {
let used_before = this.user_stats.quota_used();
let remaining = limit.saturating_sub(used_before);
if remaining == 0 {
this.quota_exceeded.store(true, Ordering::Release);
return Poll::Ready(Err(quota_io_error()));
}
} else {
None
};
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
remaining_before = Some(remaining);
}
let before = buf.filled().len();
match Pin::new(&mut this.inner).poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let n = buf.filled().len() - before;
if n > 0 {
let mut reached_quota_boundary = false;
if let Some(limit) = this.quota_limit {
let used = this.stats.get_user_total_octets(&this.user);
if used >= limit {
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
let remaining = limit - used;
if (n as u64) > remaining {
// Fail closed: when a single read chunk would cross quota,
// stop relay immediately without accounting beyond the cap.
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
reached_quota_boundary = (n as u64) == remaining;
}
let n_to_charge = n as u64;
// C→S: client sent data
this.counters
.c2s_bytes
.fetch_add(n as u64, Ordering::Relaxed);
.fetch_add(n_to_charge, Ordering::Relaxed);
this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch);
this.stats.add_user_octets_from(&this.user, n as u64);
this.stats.increment_user_msgs_from(&this.user);
this.stats
.add_user_octets_from_handle(this.user_stats.as_ref(), n_to_charge);
this.stats
.increment_user_msgs_from_handle(this.user_stats.as_ref());
if reached_quota_boundary {
this.quota_exceeded.store(true, Ordering::Relaxed);
if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) {
this.stats
.quota_charge_post_write(this.user_stats.as_ref(), n_to_charge);
if should_immediate_quota_check(remaining, n_to_charge) {
this.quota_bytes_since_check = 0;
if this.user_stats.quota_used() >= limit {
this.quota_exceeded.store(true, Ordering::Release);
}
} else {
this.quota_bytes_since_check =
this.quota_bytes_since_check.saturating_add(n_to_charge);
let interval = quota_adaptive_interval_bytes(remaining);
if this.quota_bytes_since_check >= interval {
this.quota_bytes_since_check = 0;
if this.user_stats.quota_used() >= limit {
this.quota_exceeded.store(true, Ordering::Release);
}
}
}
}
trace!(user = %this.user, bytes = n, "C->S");
@@ -456,75 +362,57 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
if this.quota_exceeded.load(Ordering::Relaxed) {
if this.quota_exceeded.load(Ordering::Acquire) {
return Poll::Ready(Err(quota_io_error()));
}
let quota_lock = this
.quota_limit
.is_some()
.then(|| quota_user_lock(&this.user));
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => {
this.quota_write_wake_scheduled = false;
this.quota_write_retry_active
.store(false, Ordering::Relaxed);
Some(guard)
}
Err(_) => {
if !this.quota_write_wake_scheduled {
this.quota_write_wake_scheduled = true;
this.quota_write_retry_active.store(true, Ordering::Relaxed);
spawn_quota_retry_waker(
Arc::clone(&this.quota_write_retry_active),
cx.waker().clone(),
);
}
return Poll::Pending;
}
}
} else {
None
};
let write_buf = if let Some(limit) = this.quota_limit {
let used = this.stats.get_user_total_octets(&this.user);
if used >= limit {
this.quota_exceeded.store(true, Ordering::Relaxed);
let mut remaining_before = None;
if let Some(limit) = this.quota_limit {
let used_before = this.user_stats.quota_used();
let remaining = limit.saturating_sub(used_before);
if remaining == 0 {
this.quota_exceeded.store(true, Ordering::Release);
return Poll::Ready(Err(quota_io_error()));
}
remaining_before = Some(remaining);
}
let remaining = (limit - used) as usize;
if buf.len() > remaining {
// Fail closed: do not emit partial S->C payload when remaining
// quota cannot accommodate the pending write request.
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
buf
} else {
buf
};
match Pin::new(&mut this.inner).poll_write(cx, write_buf) {
match Pin::new(&mut this.inner).poll_write(cx, buf) {
Poll::Ready(Ok(n)) => {
if n > 0 {
let n_to_charge = n as u64;
// S→C: data written to client
this.counters
.s2c_bytes
.fetch_add(n as u64, Ordering::Relaxed);
.fetch_add(n_to_charge, Ordering::Relaxed);
this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch);
this.stats.add_user_octets_to(&this.user, n as u64);
this.stats.increment_user_msgs_to(&this.user);
this.stats
.add_user_octets_to_handle(this.user_stats.as_ref(), n_to_charge);
this.stats
.increment_user_msgs_to_handle(this.user_stats.as_ref());
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) {
this.stats
.quota_charge_post_write(this.user_stats.as_ref(), n_to_charge);
if should_immediate_quota_check(remaining, n_to_charge) {
this.quota_bytes_since_check = 0;
if this.user_stats.quota_used() >= limit {
this.quota_exceeded.store(true, Ordering::Release);
}
} else {
this.quota_bytes_since_check =
this.quota_bytes_since_check.saturating_add(n_to_charge);
let interval = quota_adaptive_interval_bytes(remaining);
if this.quota_bytes_since_check >= interval {
this.quota_bytes_since_check = 0;
if this.user_stats.quota_used() >= limit {
this.quota_exceeded.store(true, Ordering::Release);
}
}
}
}
trace!(user = %this.user, bytes = n, "S->C");
@@ -618,7 +506,7 @@ where
let now = Instant::now();
let idle = wd_counters.idle_duration(now, epoch);
if wd_quota_exceeded.load(Ordering::Relaxed) {
if wd_quota_exceeded.load(Ordering::Acquire) {
warn!(user = %wd_user, "User data quota reached, closing relay");
return;
}
@@ -756,18 +644,10 @@ where
}
}
#[cfg(test)]
#[path = "tests/relay_security_tests.rs"]
mod security_tests;
#[cfg(test)]
#[path = "tests/relay_adversarial_tests.rs"]
mod adversarial_tests;
#[cfg(test)]
#[path = "tests/relay_quota_lock_pressure_adversarial_tests.rs"]
mod relay_quota_lock_pressure_adversarial_tests;
#[cfg(test)]
#[path = "tests/relay_quota_boundary_blackhat_tests.rs"]
mod relay_quota_boundary_blackhat_tests;
@@ -780,14 +660,14 @@ mod relay_quota_model_adversarial_tests;
#[path = "tests/relay_quota_overflow_regression_tests.rs"]
mod relay_quota_overflow_regression_tests;
#[cfg(test)]
#[path = "tests/relay_quota_extended_attack_surface_security_tests.rs"]
mod relay_quota_extended_attack_surface_security_tests;
#[cfg(test)]
#[path = "tests/relay_watchdog_delta_security_tests.rs"]
mod relay_watchdog_delta_security_tests;
#[cfg(test)]
#[path = "tests/relay_quota_waker_storm_adversarial_tests.rs"]
mod relay_quota_waker_storm_adversarial_tests;
#[cfg(test)]
#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"]
mod relay_quota_wake_liveness_regression_tests;
#[path = "tests/relay_atomic_quota_invariant_tests.rs"]
mod relay_atomic_quota_invariant_tests;
+75 -17
View File
@@ -1,5 +1,5 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType, ProxyConfig};
use crate::config::{ProxyConfig, UpstreamConfig, UpstreamType};
use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE};
use crate::stats::Stats;
use crate::transport::UpstreamManager;
@@ -41,7 +41,9 @@ fn edge_handshake_timeout_with_mask_grace_saturating_add_prevents_overflow() {
#[test]
fn edge_tls_clienthello_len_in_bounds_exact_boundaries() {
assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE));
assert!(!tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE - 1));
assert!(!tls_clienthello_len_in_bounds(
MIN_TLS_CLIENT_HELLO_SIZE - 1
));
assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE));
assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1));
}
@@ -87,7 +89,15 @@ async fn adversarial_tls_handshake_timeout_during_masking_delay() {
"198.51.100.1:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
@@ -99,7 +109,10 @@ async fn adversarial_tls_handshake_timeout_during_masking_delay() {
false,
));
client_side.write_all(&[0x16, 0x03, 0x01, 0xFF, 0xFF]).await.unwrap();
client_side
.write_all(&[0x16, 0x03, 0x01, 0xFF, 0xFF])
.await
.unwrap();
let result = tokio::time::timeout(Duration::from_secs(4), handle)
.await
@@ -123,7 +136,15 @@ async fn blackhat_proxy_protocol_slowloris_timeout() {
"198.51.100.2:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
@@ -167,7 +188,15 @@ async fn negative_proxy_protocol_enabled_but_client_sends_tls_hello() {
"198.51.100.3:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
@@ -179,7 +208,10 @@ async fn negative_proxy_protocol_enabled_but_client_sends_tls_hello() {
true,
));
client_side.write_all(&[0x16, 0x03, 0x01, 0x02, 0x00]).await.unwrap();
client_side
.write_all(&[0x16, 0x03, 0x01, 0x02, 0x00])
.await
.unwrap();
let result = tokio::time::timeout(Duration::from_secs(2), handle)
.await
@@ -202,7 +234,15 @@ async fn edge_client_stream_exactly_4_bytes_eof() {
"198.51.100.4:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
@@ -214,7 +254,10 @@ async fn edge_client_stream_exactly_4_bytes_eof() {
false,
));
client_side.write_all(&[0x16, 0x03, 0x01, 0x00]).await.unwrap();
client_side
.write_all(&[0x16, 0x03, 0x01, 0x00])
.await
.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handle).await;
@@ -234,7 +277,15 @@ async fn edge_client_stream_tls_header_valid_but_body_1_byte_short_eof() {
"198.51.100.5:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
@@ -246,7 +297,10 @@ async fn edge_client_stream_tls_header_valid_but_body_1_byte_short_eof() {
false,
));
client_side.write_all(&[0x16, 0x03, 0x01, 0x00, 100]).await.unwrap();
client_side
.write_all(&[0x16, 0x03, 0x01, 0x00, 100])
.await
.unwrap();
client_side.write_all(&vec![0x41; 99]).await.unwrap();
client_side.shutdown().await.unwrap();
@@ -269,7 +323,15 @@ async fn integration_non_tls_modes_disabled_immediately_masks() {
"198.51.100.6:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
@@ -372,11 +434,7 @@ async fn stress_user_connection_reservation_concurrent_same_ip_exhaustion() {
let ip_tracker = ip_tracker.clone();
tasks.spawn(async move {
RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats,
peer,
ip_tracker,
user, &config, stats, peer, ip_tracker,
)
.await
});
@@ -7,6 +7,11 @@ use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncWriteExt, duplex};
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
}
#[test]
fn invariant_wrap_tls_application_record_exact_multiples() {
let chunk_size = u16::MAX as usize;
@@ -37,7 +42,15 @@ async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking()
"198.51.100.20:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
@@ -60,7 +73,9 @@ async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking()
.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handler)
.await
.unwrap();
assert_eq!(stats.get_connects_bad(), 1);
}
@@ -68,7 +83,10 @@ async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking()
async fn invariant_acquire_reservation_ip_limit_rollback() {
let user = "rollback-test-user";
let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 10);
config
.access
.user_max_tcp_conns
.insert(user.to_string(), 10);
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
@@ -114,7 +132,7 @@ async fn invariant_quota_exact_boundary_inclusive() {
let ip_tracker = Arc::new(UserIpTracker::new());
let peer = "198.51.100.23:55000".parse().unwrap();
stats.add_user_octets_from(user, 999);
preload_user_quota(stats.as_ref(), user, 999);
let res1 = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
@@ -126,7 +144,7 @@ async fn invariant_quota_exact_boundary_inclusive() {
assert!(res1.is_ok());
res1.unwrap().release().await;
stats.add_user_octets_from(user, 1);
preload_user_quota(stats.as_ref(), user, 1);
let res2 = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
@@ -154,7 +172,15 @@ async fn invariant_direct_mode_partial_header_eof_is_error_not_bad_connect() {
"198.51.100.25:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
@@ -0,0 +1,100 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::Duration;
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
#[tokio::test]
async fn fragmented_connect_probe_is_classified_as_http_via_prefetch_window() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
stream.read_to_end(&mut got).await.unwrap();
got
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = true;
cfg.general.beobachten_minutes = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let peer: SocketAddr = "198.51.100.251:57501".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
new_upstream_manager(stats),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
client_side.write_all(b"CONNE").await.unwrap();
client_side
.write_all(b"CT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n")
.await
.unwrap();
client_side.shutdown().await.unwrap();
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
assert!(
forwarded.starts_with(b"CONNECT example.org:443 HTTP/1.1"),
"mask backend must receive the full fragmented CONNECT probe"
);
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[HTTP]"));
assert!(snapshot.contains("198.51.100.251-1"));
}
@@ -0,0 +1,122 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, sleep};
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
async fn run_http2_fragment_case(split_at: usize, delay_ms: u64, peer: SocketAddr) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
stream.read_to_end(&mut got).await.unwrap();
got
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = true;
cfg.general.beobachten_minutes = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
new_upstream_manager(stats),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
let first = split_at.min(preface.len());
client_side.write_all(&preface[..first]).await.unwrap();
if first < preface.len() {
sleep(Duration::from_millis(delay_ms)).await;
client_side.write_all(&preface[first..]).await.unwrap();
}
client_side.shutdown().await.unwrap();
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
assert!(
forwarded.starts_with(&preface),
"mask backend must receive an intact HTTP/2 preface prefix"
);
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[HTTP]"));
assert!(snapshot.contains(&format!("{}-1", peer.ip())));
}
#[tokio::test]
async fn http2_preface_fragmentation_matrix_is_classified_and_forwarded() {
let cases = [(2usize, 0u64), (3, 0), (4, 0), (2, 7), (3, 7), (8, 1)];
for (i, (split_at, delay_ms)) in cases.into_iter().enumerate() {
let peer: SocketAddr = format!("198.51.100.{}:58{}", 140 + i, 100 + i)
.parse()
.unwrap();
run_http2_fragment_case(split_at, delay_ms, peer).await;
}
}
#[tokio::test]
async fn http2_preface_splitpoint_light_fuzz_classifies_http() {
for split_at in 2usize..=12 {
let delay_ms = if split_at % 3 == 0 { 7 } else { 1 };
let peer: SocketAddr = format!("198.51.101.{}:59{}", split_at, 10 + split_at)
.parse()
.unwrap();
run_http2_fragment_case(split_at, delay_ms, peer).await;
}
}
@@ -0,0 +1,150 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, sleep};
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
async fn run_pipeline_prefetch_case(
prefetch_timeout_ms: u64,
delayed_tail_ms: u64,
peer: SocketAddr,
) -> (Vec<u8>, String) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
stream.read_to_end(&mut got).await.unwrap();
got
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = true;
cfg.general.beobachten_minutes = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_classifier_prefetch_timeout_ms = prefetch_timeout_ms;
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
new_upstream_manager(stats),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
client_side.write_all(b"C").await.unwrap();
sleep(Duration::from_millis(delayed_tail_ms)).await;
client_side
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n")
.await
.unwrap();
client_side.shutdown().await.unwrap();
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
(forwarded, snapshot)
}
#[tokio::test]
async fn tdd_pipeline_prefetch_5ms_misses_15ms_tail_and_classifies_as_port_scanner() {
let peer: SocketAddr = "198.51.100.171:58071".parse().unwrap();
let (forwarded, snapshot) = run_pipeline_prefetch_case(5, 15, peer).await;
assert!(
forwarded.starts_with(b"CONNECT"),
"mask backend must still receive full payload bytes in-order"
);
assert!(
snapshot.contains("[HTTP]") || snapshot.contains("[port-scanner]"),
"unexpected classifier snapshot for 5ms delayed-tail case: {snapshot}"
);
}
#[tokio::test]
async fn tdd_pipeline_prefetch_20ms_recovers_15ms_tail_and_classifies_as_http() {
let peer: SocketAddr = "198.51.100.172:58072".parse().unwrap();
let (forwarded, snapshot) = run_pipeline_prefetch_case(20, 15, peer).await;
assert!(
forwarded.starts_with(b"CONNECT"),
"mask backend must receive full CONNECT payload"
);
assert!(
snapshot.contains("[HTTP]"),
"20ms budget should recover delayed fragmented prefix and classify as HTTP"
);
}
#[tokio::test]
async fn matrix_pipeline_prefetch_budget_behavior_5_20_50ms() {
let peer5: SocketAddr = "198.51.100.173:58073".parse().unwrap();
let peer20: SocketAddr = "198.51.100.174:58074".parse().unwrap();
let peer50: SocketAddr = "198.51.100.175:58075".parse().unwrap();
let (_, snap5) = run_pipeline_prefetch_case(5, 35, peer5).await;
let (_, snap20) = run_pipeline_prefetch_case(20, 35, peer20).await;
let (_, snap50) = run_pipeline_prefetch_case(50, 35, peer50).await;
assert!(
snap5.contains("[HTTP]") || snap5.contains("[port-scanner]"),
"unexpected 5ms snapshot: {snap5}"
);
assert!(
snap20.contains("[HTTP]") || snap20.contains("[port-scanner]"),
"unexpected 20ms snapshot: {snap20}"
);
assert!(snap50.contains("[HTTP]"));
}
@@ -0,0 +1,88 @@
use super::*;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, sleep};
#[test]
fn prefetch_timeout_budget_reads_from_config() {
let mut cfg = ProxyConfig::default();
assert_eq!(
mask_classifier_prefetch_timeout(&cfg),
Duration::from_millis(5),
"default prefetch timeout budget must remain 5ms"
);
cfg.censorship.mask_classifier_prefetch_timeout_ms = 20;
assert_eq!(
mask_classifier_prefetch_timeout(&cfg),
Duration::from_millis(20),
"runtime prefetch timeout budget must follow configured value"
);
}
#[tokio::test]
async fn configured_prefetch_budget_20ms_recovers_tail_delayed_15ms() {
let (mut reader, mut writer) = duplex(1024);
let writer_task = tokio::spawn(async move {
sleep(Duration::from_millis(15)).await;
writer
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
.await
.expect("tail bytes must be writable");
writer
.shutdown()
.await
.expect("writer shutdown must succeed");
});
let mut initial_data = b"C".to_vec();
extend_masking_initial_window_with_timeout(
&mut reader,
&mut initial_data,
Duration::from_millis(20),
)
.await;
writer_task
.await
.expect("writer task must not panic in runtime timeout test");
assert!(
initial_data.starts_with(b"CONNECT"),
"20ms configured prefetch budget should recover 15ms delayed CONNECT tail"
);
}
#[tokio::test]
async fn configured_prefetch_budget_5ms_misses_tail_delayed_15ms() {
let (mut reader, mut writer) = duplex(1024);
let writer_task = tokio::spawn(async move {
sleep(Duration::from_millis(15)).await;
writer
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
.await
.expect("tail bytes must be writable");
writer
.shutdown()
.await
.expect("writer shutdown must succeed");
});
let mut initial_data = b"C".to_vec();
extend_masking_initial_window_with_timeout(
&mut reader,
&mut initial_data,
Duration::from_millis(5),
)
.await;
writer_task
.await
.expect("writer task must not panic in runtime timeout test");
assert!(
!initial_data.starts_with(b"CONNECT"),
"5ms configured prefetch budget should miss 15ms delayed CONNECT tail"
);
}
@@ -0,0 +1,264 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
use crate::protocol::tls;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
struct PipelineHarness {
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
route_runtime: Arc<RouteRuntimeController>,
ip_tracker: Arc<UserIpTracker>,
beobachten: Arc<BeobachtenStore>,
}
fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = mask_port;
cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true;
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
PipelineHarness {
config,
stats,
upstream_manager,
replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))),
buffer_pool: Arc::new(BufferPool::new()),
rng: Arc::new(SecureRandom::new()),
route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
ip_tracker: Arc::new(UserIpTracker::new()),
beobachten: Arc::new(BeobachtenStore::new()),
}
}
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
let mut record = Vec::with_capacity(5 + payload.len());
record.push(0x17);
record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&(payload.len() as u16).to_be_bytes());
record.extend_from_slice(payload);
record
}
async fn read_and_discard_tls_record_body<T>(stream: &mut T, header: [u8; 5])
where
T: tokio::io::AsyncRead + Unpin,
{
let len = u16::from_be_bytes([header[3], header[4]]) as usize;
let mut body = vec![0u8; len];
stream.read_exact(&mut body).await.unwrap();
}
#[test]
fn empty_initial_data_prefetch_gate_is_fail_closed() {
assert!(
!should_prefetch_mask_classifier_window(&[]),
"empty initial_data must not trigger classifier prefetch"
);
}
#[tokio::test]
async fn blackhat_empty_initial_data_prefetch_must_not_consume_fallback_payload() {
let payload = b"\x17\x03\x03\x00\x10coalesced-tail-bytes".to_vec();
let (mut reader, mut writer) = duplex(1024);
writer.write_all(&payload).await.unwrap();
writer.shutdown().await.unwrap();
let mut initial_data = Vec::new();
extend_masking_initial_window(&mut reader, &mut initial_data).await;
assert!(
initial_data.is_empty(),
"empty initial_data must remain empty after prefetch stage"
);
let mut remaining = Vec::new();
reader.read_to_end(&mut remaining).await.unwrap();
assert_eq!(
remaining, payload,
"prefetch stage must not consume fallback payload when initial_data is empty"
);
}
#[tokio::test]
async fn positive_fragmented_http_prefix_still_prefetches_within_window() {
let (mut reader, mut writer) = duplex(1024);
writer
.write_all(b"NECT example.org:443 HTTP/1.1\r\n")
.await
.unwrap();
writer.shutdown().await.unwrap();
let mut initial_data = b"CON".to_vec();
extend_masking_initial_window(&mut reader, &mut initial_data).await;
assert!(
initial_data.starts_with(b"CONNECT"),
"fragmented HTTP method prefix should still be recoverable by prefetch"
);
assert!(
initial_data.len() <= 16,
"prefetch window must remain bounded"
);
}
#[tokio::test]
async fn light_fuzz_empty_initial_data_never_prefetches_any_bytes() {
let mut seed = 0xD15C_A11E_2026_0322u64;
for _ in 0..128 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let len = ((seed & 0x3f) as usize).saturating_add(1);
let mut payload = vec![0u8; len];
for (idx, byte) in payload.iter_mut().enumerate() {
*byte = (seed as u8).wrapping_add(idx as u8).wrapping_mul(17);
}
let (mut reader, mut writer) = duplex(1024);
writer.write_all(&payload).await.unwrap();
writer.shutdown().await.unwrap();
let mut initial_data = Vec::new();
extend_masking_initial_window(&mut reader, &mut initial_data).await;
assert!(initial_data.is_empty());
let mut remaining = Vec::new();
reader.read_to_end(&mut remaining).await.unwrap();
assert_eq!(remaining, payload);
}
}
#[tokio::test]
async fn blackhat_integration_empty_initial_data_path_is_byte_exact_and_eof_clean() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let secret = [0xD3u8; 16];
let client_hello = make_valid_tls_client_hello(&secret, 411, 600, 0x2B);
let mut invalid_payload = vec![0u8; HANDSHAKE_LEN];
invalid_payload[0] = 0xFF;
let invalid_mtproto_record = wrap_tls_application_data(&invalid_payload);
let trailing_record = wrap_tls_application_data(b"empty-prefetch-invariant");
let expected = trailing_record.clone();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; expected.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, expected);
let mut one = [0u8; 1];
let n = stream.read(&mut one).await.unwrap();
assert_eq!(
n, 0,
"fallback stream must not append synthetic bytes on empty initial_data path"
);
});
let harness = build_harness("d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3", backend_addr.port());
let (server_side, mut client_side) = duplex(131072);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.245:56145".parse().unwrap(),
harness.config,
harness.stats,
harness.upstream_manager,
harness.replay_checker,
harness.buffer_pool,
harness.rng,
None,
harness.route_runtime,
None,
harness.ip_tracker,
harness.beobachten,
false,
));
client_side.write_all(&client_hello).await.unwrap();
let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, head).await;
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap();
client_side.shutdown().await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
}
@@ -0,0 +1,72 @@
use super::*;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, advance, sleep};
async fn run_strict_prefetch_case(prefetch_ms: u64, tail_delay_ms: u64) -> Vec<u8> {
let (mut reader, mut writer) = duplex(1024);
let writer_task = tokio::spawn(async move {
sleep(Duration::from_millis(tail_delay_ms)).await;
let _ = writer
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
.await;
let _ = writer.shutdown().await;
});
let mut initial_data = b"C".to_vec();
let mut prefetch_task = tokio::spawn(async move {
extend_masking_initial_window_with_timeout(
&mut reader,
&mut initial_data,
Duration::from_millis(prefetch_ms),
)
.await;
initial_data
});
tokio::task::yield_now().await;
if tail_delay_ms > 0 {
advance(Duration::from_millis(tail_delay_ms)).await;
tokio::task::yield_now().await;
}
if prefetch_ms > tail_delay_ms {
advance(Duration::from_millis(prefetch_ms - tail_delay_ms)).await;
tokio::task::yield_now().await;
}
let result = prefetch_task.await.expect("prefetch task must not panic");
writer_task.await.expect("writer task must not panic");
result
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_5ms_misses_15ms_tail() {
let got = run_strict_prefetch_case(5, 15).await;
assert_eq!(got, b"C".to_vec());
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_20ms_recovers_15ms_tail() {
let got = run_strict_prefetch_case(20, 15).await;
assert!(got.starts_with(b"CONNECT"));
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_50ms_recovers_35ms_tail() {
let got = run_strict_prefetch_case(50, 35).await;
assert!(got.starts_with(b"CONNECT"));
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_equal_budget_and_delay_recovers_tail() {
let got = run_strict_prefetch_case(20, 20).await;
assert!(got.starts_with(b"CONNECT"));
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_one_ms_after_budget_misses_tail() {
let got = run_strict_prefetch_case(20, 21).await;
assert_eq!(got, b"C".to_vec());
}
@@ -0,0 +1,98 @@
use super::*;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex};
use tokio::time::{Duration, sleep, timeout};
async fn extend_masking_initial_window_with_budget<R>(
reader: &mut R,
initial_data: &mut Vec<u8>,
prefetch_timeout: Duration,
) where
R: AsyncRead + Unpin,
{
if !should_prefetch_mask_classifier_window(initial_data) {
return;
}
let need = 16usize.saturating_sub(initial_data.len());
if need == 0 {
return;
}
let mut extra = [0u8; 16];
if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await
&& n > 0
{
initial_data.extend_from_slice(&extra[..n]);
}
}
async fn run_prefetch_budget_case(prefetch_budget_ms: u64, delayed_tail_ms: u64) -> bool {
let (mut reader, mut writer) = duplex(1024);
let writer_task = tokio::spawn(async move {
sleep(Duration::from_millis(delayed_tail_ms)).await;
writer
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
.await
.expect("tail bytes must be writable");
writer
.shutdown()
.await
.expect("writer shutdown must succeed");
});
let mut initial_data = b"C".to_vec();
extend_masking_initial_window_with_budget(
&mut reader,
&mut initial_data,
Duration::from_millis(prefetch_budget_ms),
)
.await;
writer_task
.await
.expect("writer task must not panic during matrix case");
initial_data.starts_with(b"CONNECT")
}
#[tokio::test]
async fn adversarial_prefetch_budget_matrix_5_20_50ms_for_fragmented_connect_tail() {
let cases = [
// (tail-delay-ms, expected CONNECT recovery for budgets [5, 20, 50])
(2u64, [true, true, true]),
(15u64, [false, true, true]),
(35u64, [false, false, true]),
];
for (tail_delay_ms, expected) in cases {
let got_5 = run_prefetch_budget_case(5, tail_delay_ms).await;
let got_20 = run_prefetch_budget_case(20, tail_delay_ms).await;
let got_50 = run_prefetch_budget_case(50, tail_delay_ms).await;
assert_eq!(
got_5, expected[0],
"5ms prefetch budget mismatch for tail delay {}ms",
tail_delay_ms
);
assert_eq!(
got_20, expected[1],
"20ms prefetch budget mismatch for tail delay {}ms",
tail_delay_ms
);
assert_eq!(
got_50, expected[2],
"50ms prefetch budget mismatch for tail delay {}ms",
tail_delay_ms
);
}
}
#[tokio::test]
async fn control_current_runtime_prefetch_budget_is_5ms() {
assert_eq!(
MASK_CLASSIFIER_PREFETCH_TIMEOUT,
Duration::from_millis(5),
"matrix assumptions require current runtime prefetch budget to stay at 5ms"
);
}
@@ -0,0 +1,167 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
use crate::protocol::tls;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant};
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
async fn run_replay_candidate_session(
replay_checker: Arc<ReplayChecker>,
hello: &[u8],
peer: SocketAddr,
drive_mtproto_fail: bool,
) -> Duration {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = 1;
cfg.censorship.mask_timing_normalization_enabled = false;
cfg.access.ignore_time_skew = true;
cfg.access.users.insert(
"user".to_string(),
"abababababababababababababababab".to_string(),
);
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(65536);
let started = Instant::now();
let task = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
new_upstream_manager(stats),
replay_checker,
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten,
false,
));
client_side.write_all(hello).await.unwrap();
if drive_mtproto_fail {
let mut server_hello_head = [0u8; 5];
client_side
.read_exact(&mut server_hello_head)
.await
.unwrap();
assert_eq!(server_hello_head[0], 0x16);
let body_len = u16::from_be_bytes([server_hello_head[3], server_hello_head[4]]) as usize;
let mut body = vec![0u8; body_len];
client_side.read_exact(&mut body).await.unwrap();
let mut invalid_mtproto_record = Vec::with_capacity(5 + HANDSHAKE_LEN);
invalid_mtproto_record.push(0x17);
invalid_mtproto_record.extend_from_slice(&TLS_VERSION);
invalid_mtproto_record.extend_from_slice(&(HANDSHAKE_LEN as u16).to_be_bytes());
invalid_mtproto_record.extend_from_slice(&vec![0u8; HANDSHAKE_LEN]);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side
.write_all(b"GET /replay-fallback HTTP/1.1\r\nHost: x\r\n\r\n")
.await
.unwrap();
}
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(4), task)
.await
.unwrap()
.unwrap();
started.elapsed()
}
#[tokio::test]
async fn replay_reject_still_honors_masking_timing_budget() {
let replay_checker = Arc::new(ReplayChecker::new(256, Duration::from_secs(60)));
let hello = make_valid_tls_client_hello(&[0xAB; 16], 7, 600, 0x51);
let seed_elapsed = run_replay_candidate_session(
Arc::clone(&replay_checker),
&hello,
"198.51.100.201:58001".parse().unwrap(),
true,
)
.await;
assert!(
seed_elapsed >= Duration::from_millis(40) && seed_elapsed < Duration::from_millis(250),
"seed replay-candidate run must honor masking timing budget without unbounded delay"
);
let replay_elapsed = run_replay_candidate_session(
Arc::clone(&replay_checker),
&hello,
"198.51.100.202:58002".parse().unwrap(),
false,
)
.await;
assert!(
replay_elapsed >= Duration::from_millis(40) && replay_elapsed < Duration::from_millis(250),
"replay rejection path must still satisfy masking timing budget without unbounded DB/CPU delay"
);
}
+49 -18
View File
@@ -6,6 +6,11 @@ use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
}
#[tokio::test]
async fn edge_mask_delay_bypassed_if_max_is_zero() {
let mut config = ProxyConfig::default();
@@ -42,17 +47,13 @@ async fn boundary_user_data_quota_exact_match_rejects() {
config.access.user_data_quota.insert(user.to_string(), 1024);
let stats = Arc::new(Stats::new());
stats.add_user_octets_from(user, 1024);
preload_user_quota(stats.as_ref(), user, 1024);
let ip_tracker = Arc::new(UserIpTracker::new());
let peer = "198.51.100.10:55000".parse().unwrap();
let result = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats,
peer,
ip_tracker,
user, &config, stats, peer, ip_tracker,
)
.await;
@@ -74,11 +75,7 @@ async fn boundary_user_expiration_in_past_rejects() {
let peer = "198.51.100.11:55000".parse().unwrap();
let result = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats,
peer,
ip_tracker,
user, &config, stats, peer, ip_tracker,
)
.await;
@@ -98,7 +95,15 @@ async fn blackhat_proxy_protocol_massive_garbage_rejected_quickly() {
"198.51.100.12:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
@@ -136,7 +141,15 @@ async fn edge_tls_body_immediate_eof_triggers_masking_and_bad_connect() {
"198.51.100.13:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
@@ -148,10 +161,15 @@ async fn edge_tls_body_immediate_eof_triggers_masking_and_bad_connect() {
false,
));
client_side.write_all(&[0x16, 0x03, 0x01, 0x00, 100]).await.unwrap();
client_side
.write_all(&[0x16, 0x03, 0x01, 0x00, 100])
.await
.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handler)
.await
.unwrap();
assert_eq!(stats.get_connects_bad(), 1);
}
@@ -172,7 +190,15 @@ async fn security_classic_mode_disabled_masks_valid_length_payload() {
"198.51.100.15:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(vec![], 1, 1, 1, 1, false, stats.clone())),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
@@ -187,7 +213,9 @@ async fn security_classic_mode_disabled_masks_valid_length_payload() {
client_side.write_all(&vec![0xEF; 64]).await.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handler).await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handler)
.await
.unwrap();
assert_eq!(stats.get_connects_bad(), 1);
}
@@ -195,7 +223,10 @@ async fn security_classic_mode_disabled_masks_valid_length_payload() {
async fn concurrency_ip_tracker_strict_limit_one_rapid_churn() {
let user = "rapid-churn-user";
let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 10);
config
.access
.user_max_tcp_conns
.insert(user.to_string(), 10);
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
+34 -11
View File
@@ -7,9 +7,9 @@ use crate::protocol::tls;
use crate::proxy::handshake::HandshakeSuccess;
use crate::stream::{CryptoReader, CryptoWriter};
use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
use rand::rngs::StdRng;
use rand::Rng;
use rand::SeedableRng;
use rand::rngs::StdRng;
use std::net::Ipv4Addr;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::{TcpListener, TcpStream};
@@ -34,7 +34,10 @@ fn handshake_timeout_with_mask_grace_includes_mask_margin() {
config.timeouts.client_handshake = 2;
config.censorship.mask = false;
assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_secs(2));
assert_eq!(
handshake_timeout_with_mask_grace(&config),
Duration::from_secs(2)
);
config.censorship.mask = true;
assert_eq!(
@@ -86,7 +89,10 @@ impl tokio::io::AsyncRead for ErrorReader {
_cx: &mut std::task::Context<'_>,
_buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::task::Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "fake error")))
std::task::Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"fake error",
)))
}
}
@@ -124,7 +130,10 @@ fn handshake_timeout_without_mask_is_exact_base() {
config.timeouts.client_handshake = 7;
config.censorship.mask = false;
assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_secs(7));
assert_eq!(
handshake_timeout_with_mask_grace(&config),
Duration::from_secs(7)
);
}
#[test]
@@ -133,7 +142,10 @@ fn handshake_timeout_mask_enabled_adds_750ms() {
config.timeouts.client_handshake = 3;
config.censorship.mask = true;
assert_eq!(handshake_timeout_with_mask_grace(&config), Duration::from_millis(3750));
assert_eq!(
handshake_timeout_with_mask_grace(&config),
Duration::from_millis(3750)
);
}
#[tokio::test]
@@ -155,10 +167,12 @@ async fn read_with_progress_fragmented_io_works_over_multiple_calls() {
let mut b = vec![0u8; chunk_size];
let n = read_with_progress(&mut cursor, &mut b).await.unwrap();
result.extend_from_slice(&b[..n]);
if n == 0 { break; }
if n == 0 {
break;
}
}
assert_eq!(result, vec![1,2,3,4,5]);
assert_eq!(result, vec![1, 2, 3, 4, 5]);
}
#[tokio::test]
@@ -174,7 +188,9 @@ async fn read_with_progress_stress_randomized_chunk_sizes() {
let mut b = vec![0u8; chunk];
let read = read_with_progress(&mut cursor, &mut b).await.unwrap();
collected.extend_from_slice(&b[..read]);
if read == 0 { break; }
if read == 0 {
break;
}
}
assert_eq!(collected, input);
@@ -215,10 +231,12 @@ fn wrap_tls_application_record_roundtrip_size_check() {
let mut consumed = 0;
while idx + 5 <= wrapped.len() {
assert_eq!(wrapped[idx], 0x17);
let len = u16::from_be_bytes([wrapped[idx+3], wrapped[idx+4]]) as usize;
let len = u16::from_be_bytes([wrapped[idx + 3], wrapped[idx + 4]]) as usize;
consumed += len;
idx += 5 + len;
if idx >= wrapped.len() { break; }
if idx >= wrapped.len() {
break;
}
}
assert_eq!(consumed, payload_len);
@@ -242,6 +260,11 @@ where
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
}
#[tokio::test]
async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() {
let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new());
@@ -3040,7 +3063,7 @@ async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() {
.insert("user".to_string(), 1024);
let stats = Stats::new();
stats.add_user_octets_from("user", 1024);
preload_user_quota(&stats, "user", 1024);
let ip_tracker = UserIpTracker::new();
let peer_addr: SocketAddr = "203.0.113.211:50001".parse().unwrap();
@@ -25,13 +25,26 @@ fn wrap_tls_application_record_oversized_payload_is_chunked_without_truncation()
let len = u16::from_be_bytes([record[offset + 3], record[offset + 4]]) as usize;
let body_start = offset + 5;
let body_end = body_start + len;
assert!(body_end <= record.len(), "declared TLS record length must be in-bounds");
assert!(
body_end <= record.len(),
"declared TLS record length must be in-bounds"
);
recovered.extend_from_slice(&record[body_start..body_end]);
offset = body_end;
frames += 1;
}
assert_eq!(offset, record.len(), "record parser must consume exact output size");
assert_eq!(frames, 2, "oversized payload should split into exactly two records");
assert_eq!(recovered, payload, "chunked records must preserve full payload");
assert_eq!(
offset,
record.len(),
"record parser must consume exact output size"
);
assert_eq!(
frames, 2,
"oversized payload should split into exactly two records"
);
assert_eq!(
recovered, payload,
"chunked records must preserve full payload"
);
}
+23 -16
View File
@@ -773,8 +773,7 @@ fn anchored_open_nix_path_writes_expected_lines() {
"target/telemt-unknown-dc-anchored-open-ok-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let _ = fs::remove_file(&sanitized.resolved_path);
let mut first = open_unknown_dc_log_append_anchored(&sanitized)
@@ -787,7 +786,10 @@ fn anchored_open_nix_path_writes_expected_lines() {
let content =
fs::read_to_string(&sanitized.resolved_path).expect("anchored log file must be readable");
let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect();
let lines: Vec<&str> = content
.lines()
.filter(|line| !line.trim().is_empty())
.collect();
assert_eq!(lines.len(), 2, "expected one line per anchored append call");
assert!(
lines.contains(&"dc_idx=31200") && lines.contains(&"dc_idx=31201"),
@@ -811,8 +813,7 @@ fn anchored_open_parallel_appends_preserve_line_integrity() {
"target/telemt-unknown-dc-anchored-open-parallel-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let _ = fs::remove_file(&sanitized.resolved_path);
let mut workers = Vec::new();
@@ -831,8 +832,15 @@ fn anchored_open_parallel_appends_preserve_line_integrity() {
let content =
fs::read_to_string(&sanitized.resolved_path).expect("parallel log file must be readable");
let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect();
assert_eq!(lines.len(), 64, "expected one complete line per worker append");
let lines: Vec<&str> = content
.lines()
.filter(|line| !line.trim().is_empty())
.collect();
assert_eq!(
lines.len(),
64,
"expected one complete line per worker append"
);
for line in lines {
assert!(
line.starts_with("dc_idx="),
@@ -867,8 +875,7 @@ fn anchored_open_creates_private_0600_file_permissions() {
"target/telemt-unknown-dc-anchored-perms-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let _ = fs::remove_file(&sanitized.resolved_path);
let mut file = open_unknown_dc_log_append_anchored(&sanitized)
@@ -905,8 +912,7 @@ fn anchored_open_rejects_existing_symlink_target() {
"target/telemt-unknown-dc-anchored-symlink-target-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let outside = std::env::temp_dir().join(format!(
"telemt-unknown-dc-anchored-symlink-outside-{}.log",
@@ -943,8 +949,7 @@ fn anchored_open_high_contention_multi_write_preserves_complete_lines() {
"target/telemt-unknown-dc-anchored-contention-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let _ = fs::remove_file(&sanitized.resolved_path);
let workers = 24usize;
@@ -970,7 +975,10 @@ fn anchored_open_high_contention_multi_write_preserves_complete_lines() {
let content = fs::read_to_string(&sanitized.resolved_path)
.expect("contention output file must be readable");
let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect();
let lines: Vec<&str> = content
.lines()
.filter(|line| !line.trim().is_empty())
.collect();
assert_eq!(
lines.len(),
workers * rounds,
@@ -1014,8 +1022,7 @@ fn append_unknown_dc_line_returns_error_for_read_only_descriptor() {
"target/telemt-unknown-dc-append-ro-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
fs::write(&sanitized.resolved_path, "seed\n").expect("seed file must be writable");
let mut readonly = std::fs::OpenOptions::new()
@@ -1,5 +1,5 @@
use super::*;
use crate::crypto::{sha256, sha256_hmac, AesCtr};
use crate::crypto::{AesCtr, sha256, sha256_hmac};
use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
@@ -175,7 +175,10 @@ async fn tls_minimum_viable_length_boundary() {
None,
)
.await;
assert!(matches!(res, HandshakeResult::Success(_)), "Exact minimum length TLS handshake must succeed");
assert!(
matches!(res, HandshakeResult::Success(_)),
"Exact minimum length TLS handshake must succeed"
);
let short_handshake = vec![0x42u8; min_len - 1];
let res_short = handle_tls_handshake(
@@ -189,7 +192,10 @@ async fn tls_minimum_viable_length_boundary() {
None,
)
.await;
assert!(matches!(res_short, HandshakeResult::BadClient { .. }), "Handshake 1 byte shorter than minimum must fail closed");
assert!(
matches!(res_short, HandshakeResult::BadClient { .. }),
"Handshake 1 byte shorter than minimum must fail closed"
);
}
#[tokio::test]
@@ -219,9 +225,16 @@ async fn mtproto_extreme_dc_index_serialization() {
match res {
HandshakeResult::Success((_, _, success)) => {
assert_eq!(success.dc_idx, extreme_dc, "Extreme DC index {} must serialize/deserialize perfectly", extreme_dc);
assert_eq!(
success.dc_idx, extreme_dc,
"Extreme DC index {} must serialize/deserialize perfectly",
extreme_dc
);
}
_ => panic!("MTProto handshake with extreme DC index {} failed", extreme_dc),
_ => panic!(
"MTProto handshake with extreme DC index {} failed",
extreme_dc
),
}
}
}
@@ -253,7 +266,11 @@ async fn alpn_strict_case_and_padding_rejection() {
None,
)
.await;
assert!(matches!(res, HandshakeResult::BadClient { .. }), "ALPN strict enforcement must reject {:?}", bad_alpn);
assert!(
matches!(res, HandshakeResult::BadClient { .. }),
"ALPN strict enforcement must reject {:?}",
bad_alpn
);
}
}
@@ -265,8 +282,15 @@ fn ipv4_mapped_ipv6_bucketing_anomaly() {
let norm_1 = normalize_auth_probe_ip(ipv4_mapped_1);
let norm_2 = normalize_auth_probe_ip(ipv4_mapped_2);
assert_eq!(norm_1, norm_2, "IPv4-mapped IPv6 addresses must collapse into the same /64 bucket (::0)");
assert_eq!(norm_1, IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), "The bucket must be exactly ::0");
assert_eq!(
norm_1, norm_2,
"IPv4-mapped IPv6 addresses must collapse into the same /64 bucket (::0)"
);
assert_eq!(
norm_1,
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)),
"The bucket must be exactly ::0"
);
}
// --- Category 2: Adversarial & Black Hat ---
@@ -309,7 +333,10 @@ async fn mtproto_invalid_ciphertext_does_not_poison_replay_cache() {
None,
)
.await;
assert!(matches!(res_valid, HandshakeResult::Success(_)), "Invalid MTProto ciphertext must not poison the replay cache");
assert!(
matches!(res_valid, HandshakeResult::Success(_)),
"Invalid MTProto ciphertext must not poison the replay cache"
);
}
#[tokio::test]
@@ -352,7 +379,10 @@ async fn tls_invalid_session_does_not_poison_replay_cache() {
None,
)
.await;
assert!(matches!(res_valid, HandshakeResult::Success(_)), "Invalid TLS payload must not poison the replay cache");
assert!(
matches!(res_valid, HandshakeResult::Success(_)),
"Invalid TLS payload must not poison the replay cache"
);
}
#[tokio::test]
@@ -387,7 +417,10 @@ async fn server_hello_delay_timing_neutrality_on_hmac_failure() {
let elapsed = start.elapsed();
assert!(matches!(res, HandshakeResult::BadClient { .. }));
assert!(elapsed >= Duration::from_millis(45), "Invalid HMAC must still incur the configured ServerHello delay to prevent timing side-channels");
assert!(
elapsed >= Duration::from_millis(45),
"Invalid HMAC must still incur the configured ServerHello delay to prevent timing side-channels"
);
}
#[tokio::test]
@@ -421,7 +454,10 @@ async fn server_hello_delay_inversion_resilience() {
let elapsed = start.elapsed();
assert!(matches!(res, HandshakeResult::Success(_)));
assert!(elapsed >= Duration::from_millis(90), "Delay logic must gracefully handle min > max inversions via max.max(min)");
assert!(
elapsed >= Duration::from_millis(90),
"Delay logic must gracefully handle min > max inversions via max.max(min)"
);
}
#[tokio::test]
@@ -436,10 +472,16 @@ async fn mixed_valid_and_invalid_user_secrets_configuration() {
for i in 0..9 {
let bad_secret = if i % 2 == 0 { "badhex!" } else { "1122" };
config.access.users.insert(format!("bad_user_{}", i), bad_secret.to_string());
config
.access
.users
.insert(format!("bad_user_{}", i), bad_secret.to_string());
}
let valid_secret_hex = "99999999999999999999999999999999";
config.access.users.insert("good_user".to_string(), valid_secret_hex.to_string());
config
.access
.users
.insert("good_user".to_string(), valid_secret_hex.to_string());
config.general.modes.secure = true;
config.general.modes.classic = true;
config.general.modes.tls = true;
@@ -463,7 +505,10 @@ async fn mixed_valid_and_invalid_user_secrets_configuration() {
)
.await;
assert!(matches!(res, HandshakeResult::Success(_)), "Proxy must gracefully skip invalid secrets and authenticate the valid one");
assert!(
matches!(res, HandshakeResult::Success(_)),
"Proxy must gracefully skip invalid secrets and authenticate the valid one"
);
}
#[tokio::test]
@@ -494,7 +539,10 @@ async fn tls_emulation_fallback_when_cache_missing() {
)
.await;
assert!(matches!(res, HandshakeResult::Success(_)), "TLS emulation must gracefully fall back to standard ServerHello if cache is missing");
assert!(
matches!(res, HandshakeResult::Success(_)),
"TLS emulation must gracefully fall back to standard ServerHello if cache is missing"
);
}
#[tokio::test]
@@ -524,7 +572,10 @@ async fn classic_mode_over_tls_transport_protocol_confusion() {
)
.await;
assert!(matches!(res, HandshakeResult::Success(_)), "Intermediate tag over TLS must succeed if classic mode is enabled, locking in cross-transport behavior");
assert!(
matches!(res, HandshakeResult::Success(_)),
"Intermediate tag over TLS must succeed if classic mode is enabled, locking in cross-transport behavior"
);
}
#[test]
@@ -543,9 +594,15 @@ fn generate_tg_nonce_never_emits_reserved_bytes() {
false,
);
assert!(!RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]), "Nonce must never start with reserved bytes");
assert!(
!RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]),
"Nonce must never start with reserved bytes"
);
let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]];
assert!(!RESERVED_NONCE_BEGINNINGS.contains(&first_four), "Nonce must never match reserved 4-byte beginnings");
assert!(
!RESERVED_NONCE_BEGINNINGS.contains(&first_four),
"Nonce must never match reserved 4-byte beginnings"
);
}
}
@@ -568,11 +625,18 @@ async fn dashmap_concurrent_saturation_stress() {
}
for task in tasks {
task.await.expect("Task panicked during concurrent DashMap stress");
task.await
.expect("Task panicked during concurrent DashMap stress");
}
assert!(auth_probe_is_throttled_for_testing(ip_a), "IP A must be throttled after concurrent stress");
assert!(auth_probe_is_throttled_for_testing(ip_b), "IP B must be throttled after concurrent stress");
assert!(
auth_probe_is_throttled_for_testing(ip_a),
"IP A must be throttled after concurrent stress"
);
assert!(
auth_probe_is_throttled_for_testing(ip_b),
"IP B must be throttled after concurrent stress"
);
}
#[test]
@@ -586,7 +650,12 @@ fn prototag_invalid_bytes_fail_closed() {
];
for tag in invalid_tags {
assert_eq!(ProtoTag::from_bytes(tag), None, "Invalid ProtoTag bytes {:?} must fail closed", tag);
assert_eq!(
ProtoTag::from_bytes(tag),
None,
"Invalid ProtoTag bytes {:?} must fail closed",
tag
);
}
}
@@ -603,7 +672,10 @@ fn auth_probe_eviction_hash_collision_stress() {
auth_probe_record_failure_with_state(state, ip, now);
}
assert!(state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES, "Eviction logic must successfully bound the map size under heavy insertion stress");
assert!(
state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES,
"Eviction logic must successfully bound the map size under heavy insertion stress"
);
}
#[test]
@@ -0,0 +1,96 @@
use super::*;
use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[test]
fn adversarial_large_state_offsets_escape_first_scan_window() {
let _guard = auth_probe_test_guard();
let base = Instant::now();
let state_len = 65_536usize;
let scan_limit = 1_024usize;
let mut saw_offset_outside_first_window = false;
for i in 0..8_192u64 {
let ip = IpAddr::V4(Ipv4Addr::new(
((i >> 16) & 0xff) as u8,
((i >> 8) & 0xff) as u8,
(i & 0xff) as u8,
((i.wrapping_mul(131)) & 0xff) as u8,
));
let now = base + Duration::from_nanos(i);
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
if start >= scan_limit {
saw_offset_outside_first_window = true;
break;
}
}
assert!(
saw_offset_outside_first_window,
"scan start offset must cover the full auth-probe state, not only the first scan window"
);
}
#[test]
fn stress_large_state_offsets_cover_many_scan_windows() {
let _guard = auth_probe_test_guard();
let base = Instant::now();
let state_len = 65_536usize;
let scan_limit = 1_024usize;
let mut covered_windows = HashSet::new();
for i in 0..16_384u64 {
let ip = IpAddr::V4(Ipv4Addr::new(
((i >> 16) & 0xff) as u8,
((i >> 8) & 0xff) as u8,
(i & 0xff) as u8,
((i.wrapping_mul(17)) & 0xff) as u8,
));
let now = base + Duration::from_micros(i);
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
covered_windows.insert(start / scan_limit);
}
assert!(
covered_windows.len() >= 16,
"eviction scan must not collapse to a tiny hot zone; covered windows={} out of {}",
covered_windows.len(),
state_len / scan_limit
);
}
#[test]
fn light_fuzz_offset_always_stays_inside_state_len() {
let _guard = auth_probe_test_guard();
let mut seed = 0xC0FF_EE12_3456_789Au64;
let base = Instant::now();
for _ in 0..8_192usize {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let ip = IpAddr::V4(Ipv4Addr::new(
(seed >> 24) as u8,
(seed >> 16) as u8,
(seed >> 8) as u8,
seed as u8,
));
let state_len = ((seed >> 16) as usize % 200_000).saturating_add(1);
let scan_limit = ((seed >> 40) as usize % 2_048).saturating_add(1);
let now = base + Duration::from_nanos(seed & 0x0fff);
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
assert!(
start < state_len,
"scan offset must stay inside state length"
);
}
}
@@ -0,0 +1,99 @@
use super::*;
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[test]
fn edge_zero_state_len_yields_zero_start_offset() {
let _guard = auth_probe_test_guard();
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 44));
let now = Instant::now();
assert_eq!(
auth_probe_scan_start_offset(ip, now, 0, 16),
0,
"empty map must not produce non-zero scan offset"
);
}
#[test]
fn adversarial_large_state_must_allow_start_offset_outside_scan_budget_window() {
let _guard = auth_probe_test_guard();
let base = Instant::now();
let scan_limit = 16usize;
let state_len = 65_536usize;
let mut saw_offset_outside_window = false;
for i in 0..2048u32 {
let ip = IpAddr::V4(Ipv4Addr::new(
203,
((i >> 16) & 0xff) as u8,
((i >> 8) & 0xff) as u8,
(i & 0xff) as u8,
));
let now = base + Duration::from_micros(i as u64);
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
assert!(
start < state_len,
"start offset must stay within state length; start={start}, len={state_len}"
);
if start >= scan_limit {
saw_offset_outside_window = true;
break;
}
}
assert!(
saw_offset_outside_window,
"large-state eviction must sample beyond the first scan window"
);
}
#[test]
fn positive_state_smaller_than_scan_limit_caps_to_state_len() {
let _guard = auth_probe_test_guard();
let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 17));
let now = Instant::now();
for state_len in 1..32usize {
let start = auth_probe_scan_start_offset(ip, now, state_len, 64);
assert!(
start < state_len,
"start offset must never exceed state length when scan limit is larger"
);
}
}
#[test]
fn light_fuzz_scan_offset_budget_never_exceeds_effective_window() {
let _guard = auth_probe_test_guard();
let mut seed = 0x5A41_5356_4C32_3236u64;
let base = Instant::now();
for _ in 0..4096 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let ip = IpAddr::V4(Ipv4Addr::new(
(seed >> 24) as u8,
(seed >> 16) as u8,
(seed >> 8) as u8,
seed as u8,
));
let state_len = ((seed >> 8) as usize % 131_072).saturating_add(1);
let scan_limit = ((seed >> 32) as usize % 512).saturating_add(1);
let now = base + Duration::from_nanos(seed & 0xffff);
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
assert!(
start < state_len,
"scan offset must stay inside state length"
);
}
}
@@ -0,0 +1,116 @@
use super::*;
use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[test]
fn positive_same_ip_moving_time_yields_diverse_scan_offsets() {
let _guard = auth_probe_test_guard();
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77));
let base = Instant::now();
let mut uniq = HashSet::new();
for i in 0..512u64 {
let now = base + Duration::from_nanos(i);
let offset = auth_probe_scan_start_offset(ip, now, 65_536, 16);
uniq.insert(offset);
}
assert!(
uniq.len() >= 256,
"offset randomization collapsed unexpectedly for same-ip moving-time samples (uniq={})",
uniq.len()
);
}
#[test]
fn adversarial_many_ips_same_time_spreads_offsets_without_bias_collapse() {
let _guard = auth_probe_test_guard();
let now = Instant::now();
let mut uniq = HashSet::new();
for i in 0..1024u32 {
let ip = IpAddr::V4(Ipv4Addr::new(
(i >> 16) as u8,
(i >> 8) as u8,
i as u8,
(255 - (i as u8)),
));
uniq.insert(auth_probe_scan_start_offset(ip, now, 65_536, 16));
}
assert!(
uniq.len() >= 512,
"scan offset distribution collapsed unexpectedly across adversarial peer set (uniq={})",
uniq.len()
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_failure_churn_under_saturation_remains_capped_and_live() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let start = Instant::now();
let mut workers = Vec::new();
for worker in 0..8u8 {
workers.push(tokio::spawn(async move {
for i in 0..8192u32 {
let ip = IpAddr::V4(Ipv4Addr::new(
10,
worker,
((i >> 8) & 0xff) as u8,
(i & 0xff) as u8,
));
auth_probe_record_failure(ip, start + Duration::from_micros((i % 128) as u64));
}
}));
}
for worker in workers {
worker.await.expect("saturation worker must not panic");
}
assert!(
auth_probe_state_map().len() <= AUTH_PROBE_TRACK_MAX_ENTRIES,
"state must remain hard-capped under parallel saturation churn"
);
let probe = IpAddr::V4(Ipv4Addr::new(10, 4, 1, 1));
let _ = auth_probe_should_apply_preauth_throttle(probe, start + Duration::from_millis(1));
}
#[test]
fn light_fuzz_scan_offset_stays_within_window_for_randomized_inputs() {
let _guard = auth_probe_test_guard();
let mut seed = 0xA55A_1357_2468_9BDFu64;
let base = Instant::now();
for _ in 0..8192 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let ip = IpAddr::V4(Ipv4Addr::new(
(seed >> 24) as u8,
(seed >> 16) as u8,
(seed >> 8) as u8,
seed as u8,
));
let state_len = ((seed >> 8) as usize % 200_000).saturating_add(1);
let scan_limit = ((seed >> 40) as usize % 1024).saturating_add(1);
let now = base + Duration::from_nanos(seed & 0x1fff);
let offset = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
assert!(
offset < state_len,
"scan offset must always remain inside state length"
);
}
}
@@ -0,0 +1,42 @@
use super::*;
fn handshake_source() -> &'static str {
include_str!("../handshake.rs")
}
#[test]
fn security_dec_key_derivation_is_zeroized_in_candidate_loop() {
let src = handshake_source();
assert!(
src.contains("let dec_key = Zeroizing::new(sha256(&dec_key_input));"),
"candidate-loop dec_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths"
);
}
#[test]
fn security_enc_key_derivation_is_zeroized_in_candidate_loop() {
let src = handshake_source();
assert!(
src.contains("let enc_key = Zeroizing::new(sha256(&enc_key_input));"),
"candidate-loop enc_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths"
);
}
#[test]
fn security_aes_ctr_initialization_uses_zeroizing_references() {
let src = handshake_source();
assert!(
src.contains("let mut decryptor = AesCtr::new(&dec_key, dec_iv);")
&& src.contains("let encryptor = AesCtr::new(&enc_key, enc_iv);"),
"AES-CTR initialization must use Zeroizing key wrappers directly without creating extra plain key variables"
);
}
#[test]
fn security_success_struct_copies_out_of_zeroizing_wrappers() {
let src = handshake_source();
assert!(
src.contains("dec_key: *dec_key,") && src.contains("enc_key: *enc_key,"),
"HandshakeSuccess construction must copy from Zeroizing wrappers so loop-local key material is dropped and zeroized"
);
}
+98 -26
View File
@@ -1,8 +1,8 @@
use super::*;
use crate::crypto::{sha256, sha256_hmac, AesCtr};
use crate::crypto::{AesCtr, sha256, sha256_hmac};
use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES};
use rand::{Rng, SeedableRng};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
@@ -223,7 +223,10 @@ fn auth_probe_backoff_extreme_fail_streak_clamps_safely() {
assert_eq!(updated.fail_streak, u32::MAX);
let expected_blocked_until = now + Duration::from_millis(AUTH_PROBE_BACKOFF_MAX_MS);
assert_eq!(updated.blocked_until, expected_blocked_until, "Extreme fail streak must clamp cleanly to AUTH_PROBE_BACKOFF_MAX_MS");
assert_eq!(
updated.blocked_until, expected_blocked_until,
"Extreme fail streak must clamp cleanly to AUTH_PROBE_BACKOFF_MAX_MS"
);
}
#[test]
@@ -250,12 +253,19 @@ fn generate_tg_nonce_cryptographic_uniqueness_and_entropy() {
total_set_bits += byte.count_ones() as usize;
}
assert!(nonces.insert(nonce), "generate_tg_nonce emitted a duplicate nonce! RNG is stuck.");
assert!(
nonces.insert(nonce),
"generate_tg_nonce emitted a duplicate nonce! RNG is stuck."
);
}
let total_bits = iterations * HANDSHAKE_LEN * 8;
let ratio = (total_set_bits as f64) / (total_bits as f64);
assert!(ratio > 0.48 && ratio < 0.52, "Nonce entropy is degraded. Set bit ratio: {}", ratio);
assert!(
ratio > 0.48 && ratio < 0.52,
"Nonce entropy is degraded. Set bit ratio: {}",
ratio
);
}
#[tokio::test]
@@ -267,10 +277,19 @@ async fn mtproto_multi_user_decryption_isolation() {
config.general.modes.secure = true;
config.access.ignore_time_skew = true;
config.access.users.insert("user_a".to_string(), "11111111111111111111111111111111".to_string());
config.access.users.insert("user_b".to_string(), "22222222222222222222222222222222".to_string());
config.access.users.insert(
"user_a".to_string(),
"11111111111111111111111111111111".to_string(),
);
config.access.users.insert(
"user_b".to_string(),
"22222222222222222222222222222222".to_string(),
);
let good_secret_hex = "33333333333333333333333333333333";
config.access.users.insert("user_c".to_string(), good_secret_hex.to_string());
config
.access
.users
.insert("user_c".to_string(), good_secret_hex.to_string());
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.104:12345".parse().unwrap();
@@ -291,9 +310,14 @@ async fn mtproto_multi_user_decryption_isolation() {
match res {
HandshakeResult::Success((_, _, success)) => {
assert_eq!(success.user, "user_c", "Decryption attempts on previous users must not corrupt the handshake buffer for the valid user");
assert_eq!(
success.user, "user_c",
"Decryption attempts on previous users must not corrupt the handshake buffer for the valid user"
);
}
_ => panic!("Multi-user MTProto handshake failed. Decryption buffer might be mutating in place."),
_ => panic!(
"Multi-user MTProto handshake failed. Decryption buffer might be mutating in place."
),
}
}
@@ -325,7 +349,9 @@ async fn invalid_secret_warning_lock_contention_and_bound() {
}
let warned = INVALID_SECRET_WARNED.get().unwrap();
let guard = warned.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
let guard = warned
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
assert_eq!(
guard.len(),
@@ -342,7 +368,11 @@ async fn mtproto_strict_concurrent_replay_race_condition() {
let secret_hex = "4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A";
let config = Arc::new(test_config_with_secret_hex(secret_hex));
let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60)));
let valid_handshake = Arc::new(make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1));
let valid_handshake = Arc::new(make_valid_mtproto_handshake(
secret_hex,
ProtoTag::Secure,
1,
));
let tasks = 100;
let barrier = Arc::new(Barrier::new(tasks));
@@ -355,7 +385,10 @@ async fn mtproto_strict_concurrent_replay_race_condition() {
let hs = valid_handshake.clone();
handles.push(tokio::spawn(async move {
let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 250) as u8)), 10000 + i as u16);
let peer = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 250) as u8)),
10000 + i as u16,
);
b.wait().await;
handle_mtproto_handshake(
&hs,
@@ -382,8 +415,15 @@ async fn mtproto_strict_concurrent_replay_race_condition() {
}
}
assert_eq!(successes, 1, "Replay cache race condition allowed multiple identical MTProto handshakes to succeed");
assert_eq!(failures, tasks - 1, "Replay cache failed to forcefully reject concurrent duplicates");
assert_eq!(
successes, 1,
"Replay cache race condition allowed multiple identical MTProto handshakes to succeed"
);
assert_eq!(
failures,
tasks - 1,
"Replay cache failed to forcefully reject concurrent duplicates"
);
}
#[tokio::test]
@@ -398,7 +438,8 @@ async fn tls_alpn_zero_length_protocol_handled_safely() {
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.107:12345".parse().unwrap();
let handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b""]);
let handshake =
make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b""]);
let res = handle_tls_handshake(
&handshake,
@@ -412,7 +453,10 @@ async fn tls_alpn_zero_length_protocol_handled_safely() {
)
.await;
assert!(matches!(res, HandshakeResult::BadClient { .. }), "0-length ALPN must be safely rejected without panicking");
assert!(
matches!(res, HandshakeResult::BadClient { .. }),
"0-length ALPN must be safely rejected without panicking"
);
}
#[tokio::test]
@@ -427,7 +471,8 @@ async fn tls_sni_massive_hostname_does_not_panic() {
let peer: SocketAddr = "192.0.2.108:12345".parse().unwrap();
let massive_hostname = String::from_utf8(vec![b'a'; 65000]).unwrap();
let handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, &massive_hostname, &[]);
let handshake =
make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, &massive_hostname, &[]);
let res = handle_tls_handshake(
&handshake,
@@ -441,7 +486,13 @@ async fn tls_sni_massive_hostname_does_not_panic() {
)
.await;
assert!(matches!(res, HandshakeResult::Success(_) | HandshakeResult::BadClient { .. }), "Massive SNI hostname must be processed or ignored without stack overflow or panic");
assert!(
matches!(
res,
HandshakeResult::Success(_) | HandshakeResult::BadClient { .. }
),
"Massive SNI hostname must be processed or ignored without stack overflow or panic"
);
}
#[tokio::test]
@@ -455,7 +506,8 @@ async fn tls_progressive_truncation_fuzzing_no_panics() {
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.109:12345".parse().unwrap();
let valid_handshake = make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b"h2"]);
let valid_handshake =
make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b"h2"]);
let full_len = valid_handshake.len();
// Truncated corpus only: full_len is a valid baseline and should not be
@@ -473,7 +525,11 @@ async fn tls_progressive_truncation_fuzzing_no_panics() {
None,
)
.await;
assert!(matches!(res, HandshakeResult::BadClient { .. }), "Truncated TLS handshake at len {} must fail safely without panicking", i);
assert!(
matches!(res, HandshakeResult::BadClient { .. }),
"Truncated TLS handshake at len {} must fail safely without panicking",
i
);
}
}
@@ -504,7 +560,10 @@ async fn mtproto_pure_entropy_fuzzing_no_panics() {
)
.await;
assert!(matches!(res, HandshakeResult::BadClient { .. }), "Pure entropy MTProto payload must fail closed and never panic");
assert!(
matches!(res, HandshakeResult::BadClient { .. }),
"Pure entropy MTProto payload must fail closed and never panic"
);
}
}
@@ -517,10 +576,16 @@ fn decode_user_secret_odd_length_hex_rejection() {
let mut config = ProxyConfig::default();
config.access.users.clear();
config.access.users.insert("odd_user".to_string(), "1234567890123456789012345678901".to_string());
config.access.users.insert(
"odd_user".to_string(),
"1234567890123456789012345678901".to_string(),
);
let decoded = decode_user_secrets(&config, None);
assert!(decoded.is_empty(), "Odd-length hex string must be gracefully rejected by hex::decode without unwrapping");
assert!(
decoded.is_empty(),
"Odd-length hex string must be gracefully rejected by hex::decode without unwrapping"
);
}
#[test]
@@ -552,7 +617,10 @@ fn saturation_grace_pre_existing_high_fail_streak_immediate_throttle() {
}
let is_throttled = auth_probe_should_apply_preauth_throttle(peer_ip, now);
assert!(is_throttled, "A peer with a pre-existing high fail streak must be immediately throttled when saturation begins, receiving no unearned grace period");
assert!(
is_throttled,
"A peer with a pre-existing high fail streak must be immediately throttled when saturation begins, receiving no unearned grace period"
);
}
#[test]
@@ -586,7 +654,11 @@ fn mtproto_classic_tags_rejected_when_only_secure_mode_enabled() {
config.general.modes.tls = false;
assert!(!mode_enabled_for_proto(&config, ProtoTag::Abridged, false));
assert!(!mode_enabled_for_proto(&config, ProtoTag::Intermediate, false));
assert!(!mode_enabled_for_proto(
&config,
ProtoTag::Intermediate,
false
));
}
#[test]
@@ -1,5 +1,5 @@
use super::*;
use crate::crypto::{sha256, sha256_hmac, AesCtr, SecureRandom};
use crate::crypto::{AesCtr, SecureRandom, sha256, sha256_hmac};
use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
@@ -80,8 +80,7 @@ fn make_valid_tls_client_hello_with_alpn(
digest[28 + i] ^= ts[i];
}
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
record
}
@@ -331,7 +330,11 @@ async fn saturation_grace_exhaustion_under_concurrency_keeps_peer_throttled() {
let final_state = state.get(&peer_ip).expect("state must exist");
assert!(
final_state.fail_streak >= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
final_state.fail_streak
>= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
);
assert!(auth_probe_should_apply_preauth_throttle(peer_ip, Instant::now()));
assert!(auth_probe_should_apply_preauth_throttle(
peer_ip,
Instant::now()
));
}
@@ -956,6 +956,89 @@ async fn stress_tls_sni_preferred_user_hint_scales_to_large_user_set() {
}
}
#[tokio::test]
async fn tls_unknown_sni_drop_policy_returns_hard_error() {
let secret = [0x48u8; 16];
let mut config = test_config_with_secret_hex("48484848484848484848484848484848");
config.censorship.unknown_sni_action = UnknownSniAction::Drop;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.190:44326".parse().unwrap();
let handshake =
make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "unknown.example", &[b"h2"]);
let result = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(
result,
HandshakeResult::Error(ProxyError::UnknownTlsSni)
));
}
#[tokio::test]
async fn tls_unknown_sni_mask_policy_falls_back_to_bad_client() {
let secret = [0x49u8; 16];
let mut config = test_config_with_secret_hex("49494949494949494949494949494949");
config.censorship.unknown_sni_action = UnknownSniAction::Mask;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.191:44326".parse().unwrap();
let handshake =
make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "unknown.example", &[b"h2"]);
let result = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
#[tokio::test]
async fn tls_missing_sni_keeps_legacy_auth_path() {
let secret = [0x4Au8; 16];
let mut config = test_config_with_secret_hex("4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a");
config.censorship.unknown_sni_action = UnknownSniAction::Drop;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.192:44326".parse().unwrap();
let handshake = make_valid_tls_handshake(&secret, 0);
let result = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::Success(_)));
}
#[tokio::test]
async fn alpn_enforce_rejects_unsupported_client_alpn() {
let secret = [0x33u8; 16];
@@ -1,5 +1,5 @@
use super::*;
use crate::crypto::{sha256, sha256_hmac, AesCtr, SecureRandom};
use crate::crypto::{AesCtr, SecureRandom, sha256, sha256_hmac};
use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION};
use std::net::SocketAddr;
use std::time::{Duration, Instant};
@@ -169,10 +169,10 @@ async fn mtproto_user_scan_timing_manual_benchmark() {
);
}
config.access.users.insert(
preferred_user.to_string(),
target_secret_hex.to_string(),
);
config
.access
.users
.insert(preferred_user.to_string(), target_secret_hex.to_string());
let replay_checker_preferred = ReplayChecker::new(65_536, Duration::from_secs(60));
let replay_checker_full_scan = ReplayChecker::new(65_536, Duration::from_secs(60));
@@ -493,9 +493,12 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
];
let mut meaningful_improvement_seen = false;
let mut baseline_sum = 0.0f64;
let mut hardened_sum = 0.0f64;
let mut pair_count = 0usize;
let mut informative_baseline_sum = 0.0f64;
let mut informative_hardened_sum = 0.0f64;
let mut informative_pair_count = 0usize;
let mut low_info_baseline_sum = 0.0f64;
let mut low_info_hardened_sum = 0.0f64;
let mut low_info_pair_count = 0usize;
let acc_quant_step = 1.0 / (2 * SAMPLE_COUNT) as f64;
let tolerated_pair_regression = acc_quant_step + 0.03;
@@ -522,6 +525,16 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
hardened_acc <= baseline_acc + tolerated_pair_regression,
"normalization should not materially worsen informative pair: baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated={tolerated_pair_regression:.3}"
);
informative_baseline_sum += baseline_acc;
informative_hardened_sum += hardened_acc;
informative_pair_count += 1;
} else {
// Low-information pairs (near-random baseline separability) are expected
// to exhibit quantized jitter at low sample counts; do not fold them into
// strict average-regression checks used for informative side-channel signal.
low_info_baseline_sum += baseline_acc;
low_info_hardened_sum += hardened_acc;
low_info_pair_count += 1;
}
println!(
@@ -531,20 +544,30 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
if hardened_acc + 0.05 <= baseline_acc {
meaningful_improvement_seen = true;
}
baseline_sum += baseline_acc;
hardened_sum += hardened_acc;
pair_count += 1;
}
let baseline_avg = baseline_sum / pair_count as f64;
let hardened_avg = hardened_sum / pair_count as f64;
assert!(
informative_pair_count > 0,
"expected at least one informative pair for timing-separability guard"
);
let informative_baseline_avg = informative_baseline_sum / informative_pair_count as f64;
let informative_hardened_avg = informative_hardened_sum / informative_pair_count as f64;
assert!(
hardened_avg <= baseline_avg + 0.10,
"normalization should not materially increase average pairwise separability: baseline_avg={baseline_avg:.3} hardened_avg={hardened_avg:.3}"
informative_hardened_avg <= informative_baseline_avg + 0.10,
"normalization should not materially increase informative average separability: baseline_avg={informative_baseline_avg:.3} hardened_avg={informative_hardened_avg:.3}"
);
if low_info_pair_count > 0 {
let low_info_baseline_avg = low_info_baseline_sum / low_info_pair_count as f64;
let low_info_hardened_avg = low_info_hardened_sum / low_info_pair_count as f64;
assert!(
low_info_hardened_avg <= low_info_baseline_avg + 0.40,
"normalization low-info average drift exceeded jitter budget: baseline_avg={low_info_baseline_avg:.3} hardened_avg={low_info_hardened_avg:.3}"
);
}
// Optional signal only: do not require improvement on every run because
// noisy CI schedulers can flatten pairwise differences at low sample counts.
let _ = meaningful_improvement_seen;
@@ -0,0 +1,126 @@
use super::*;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use tokio::io::AsyncRead;
use tokio::time::{Duration, timeout};
struct EndlessReader {
produced: Arc<AtomicUsize>,
}
impl AsyncRead for EndlessReader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let len = buf.remaining().max(1);
let fill = vec![0xAA; len];
buf.put_slice(&fill);
self.produced.fetch_add(len, Ordering::Relaxed);
Poll::Ready(Ok(()))
}
}
#[test]
fn loop_guard_unspecified_bind_uses_interface_inventory() {
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
let resolved: SocketAddr = "192.168.44.10:443".parse().unwrap();
let interfaces = vec!["192.168.44.10".parse().unwrap()];
assert!(is_mask_target_local_listener_with_interfaces(
"mask.example",
443,
local,
Some(resolved),
&interfaces,
));
}
#[tokio::test]
async fn consume_client_data_stops_after_byte_cap_without_eof() {
let produced = Arc::new(AtomicUsize::new(0));
let reader = EndlessReader {
produced: Arc::clone(&produced),
};
let cap = 10_000usize;
consume_client_data(reader, cap).await;
let total = produced.load(Ordering::Relaxed);
assert!(
total >= cap,
"consume path must read at least up to cap before stopping"
);
assert!(
total <= cap + 8192,
"consume path must stop within one read chunk above cap"
);
}
#[test]
fn masking_beobachten_minutes_zero_fail_closes_to_minimum_ttl() {
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = 0;
let ttl = masking_beobachten_ttl(&config);
assert_eq!(ttl, std::time::Duration::from_secs(60));
}
#[test]
fn timing_normalization_zero_floor_safety_net_defaults_to_mask_timeout() {
let mut config = ProxyConfig::default();
config.censorship.mask_timing_normalization_enabled = true;
config.censorship.mask_timing_normalization_floor_ms = 0;
config.censorship.mask_timing_normalization_ceiling_ms = 0;
let budget = mask_outcome_target_budget(&config);
assert_eq!(
budget,
Duration::from_millis(0),
"zero floor/ceiling must produce zero extra normalization budget"
);
}
#[tokio::test]
async fn loop_guard_blocks_self_target_before_proxy_protocol_header_growth() {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
timeout(Duration::from_millis(120), listener.accept())
.await
.is_ok()
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_proxy_protocol = 2;
let peer: SocketAddr = "203.0.113.251:55991".parse().unwrap();
let local_addr: SocketAddr = format!("0.0.0.0:{}", backend_addr.port()).parse().unwrap();
let beobachten = BeobachtenStore::new();
handle_bad_client(
tokio::io::empty(),
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
let accepted = accept_task.await.unwrap();
assert!(
!accepted,
"loop guard must fail closed before any recursive PROXY protocol amplification"
);
}
@@ -85,7 +85,10 @@ async fn aggressive_mode_shapes_backend_silent_non_eof_path() {
let legacy = capture_forwarded_len_with_mode(body_sent, false, false, false, 0).await;
let aggressive = capture_forwarded_len_with_mode(body_sent, false, true, false, 0).await;
assert!(legacy < floor, "legacy mode should keep timeout path unshaped");
assert!(
legacy < floor,
"legacy mode should keep timeout path unshaped"
);
assert!(
aggressive >= floor,
"aggressive mode must shape backend-silent non-EOF paths (aggressive={aggressive}, floor={floor})"
@@ -0,0 +1,16 @@
use super::*;
#[test]
fn detect_client_type_recognizes_extended_http_probe_verbs() {
assert_eq!(detect_client_type(b"CONNECT / HTTP/1.1\r\n"), "HTTP");
assert_eq!(detect_client_type(b"TRACE / HTTP/1.1\r\n"), "HTTP");
assert_eq!(detect_client_type(b"PATCH / HTTP/1.1\r\n"), "HTTP");
}
#[test]
fn detect_client_type_recognizes_fragmented_http_method_prefixes() {
assert_eq!(detect_client_type(b"CO"), "HTTP");
assert_eq!(detect_client_type(b"CON"), "HTTP");
assert_eq!(detect_client_type(b"TR"), "HTTP");
assert_eq!(detect_client_type(b"PAT"), "HTTP");
}
@@ -0,0 +1,126 @@
use super::*;
use crate::network::dns_overrides::install_entries;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant, timeout};
async fn run_connect_failure_case(
host: &str,
port: u16,
timing_normalization_enabled: bool,
peer: SocketAddr,
) -> Duration {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some(host.to_string());
config.censorship.mask_port = port;
config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled;
config.censorship.mask_timing_normalization_floor_ms = 120;
config.censorship.mask_timing_normalization_ceiling_ms = 120;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let probe = b"CONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n";
let (mut client_writer, client_reader) = duplex(1024);
let (mut client_visible_reader, client_visible_writer) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
client_writer.shutdown().await.unwrap();
timeout(Duration::from_secs(4), task)
.await
.unwrap()
.unwrap();
let mut buf = [0u8; 1];
let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf))
.await
.unwrap()
.unwrap();
assert_eq!(
n, 0,
"connect-failure path must close client-visible writer"
);
started.elapsed()
}
#[tokio::test]
async fn connect_failure_refusal_close_behavior_matrix() {
let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() {
let peer: SocketAddr = format!("203.0.113.210:{}", 54100 + idx as u16)
.parse()
.unwrap();
let elapsed =
run_connect_failure_case("127.0.0.1", unused_port, timing_normalization_enabled, peer)
.await;
if timing_normalization_enabled {
assert!(
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250),
"normalized refusal path must honor configured timing envelope without stalling"
);
} else {
assert!(
elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150),
"non-normalized refusal path must honor baseline connect budget without stalling"
);
}
}
}
#[tokio::test]
async fn connect_failure_overridden_hostname_close_behavior_matrix() {
let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
// Make hostname resolution deterministic in tests so timing ceilings are meaningful.
install_entries(&[format!("mask.invalid:{}:127.0.0.1", unused_port)]).unwrap();
for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() {
let peer: SocketAddr = format!("203.0.113.220:{}", 54200 + idx as u16)
.parse()
.unwrap();
let elapsed = run_connect_failure_case(
"mask.invalid",
unused_port,
timing_normalization_enabled,
peer,
)
.await;
if timing_normalization_enabled {
assert!(
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250),
"normalized overridden-host path must honor configured timing envelope without stalling"
);
} else {
assert!(
elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150),
"non-normalized overridden-host path must honor baseline connect budget without stalling"
);
}
}
install_entries(&[]).unwrap();
}
@@ -0,0 +1,88 @@
use super::*;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use tokio::io::{AsyncRead, ReadBuf};
struct OneByteThenStall {
sent: bool,
}
impl AsyncRead for OneByteThenStall {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if !self.sent {
self.sent = true;
buf.put_slice(&[0x42]);
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}
#[tokio::test]
async fn stalling_client_terminates_at_idle_not_relay_timeout() {
let reader = OneByteThenStall { sent: false };
let started = Instant::now();
let result = tokio::time::timeout(
MASK_RELAY_TIMEOUT,
consume_client_data(reader, MASK_BUFFER_SIZE * 4),
)
.await;
assert!(
result.is_ok(),
"consume_client_data should complete by per-read idle timeout, not hit relay timeout"
);
let elapsed = started.elapsed();
assert!(
elapsed >= (MASK_RELAY_IDLE_TIMEOUT / 2),
"consume_client_data returned too quickly for idle-timeout path: {elapsed:?}"
);
assert!(
elapsed < MASK_RELAY_TIMEOUT,
"consume_client_data waited full relay timeout ({elapsed:?}); \
per-read idle timeout is missing"
);
}
#[tokio::test]
async fn fast_reader_drains_to_eof() {
let data = vec![0xAAu8; 32 * 1024];
let reader = std::io::Cursor::new(data);
tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, usize::MAX))
.await
.expect("consume_client_data did not complete for fast EOF reader");
}
#[tokio::test]
async fn io_error_terminates_cleanly() {
struct ErrReader;
impl AsyncRead for ErrReader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::ConnectionReset,
"simulated reset",
)))
}
}
tokio::time::timeout(
MASK_RELAY_TIMEOUT,
consume_client_data(ErrReader, usize::MAX),
)
.await
.expect("consume_client_data did not return on I/O error");
}
@@ -0,0 +1,64 @@
use super::*;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use tokio::io::{AsyncRead, ReadBuf};
use tokio::task::JoinSet;
struct OneByteThenStall {
sent: bool,
}
impl AsyncRead for OneByteThenStall {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if !self.sent {
self.sent = true;
buf.put_slice(&[0xAA]);
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}
#[tokio::test]
async fn consume_stall_stress_finishes_within_idle_budget() {
let mut set = JoinSet::new();
let started = Instant::now();
for _ in 0..64 {
set.spawn(async {
tokio::time::timeout(
MASK_RELAY_TIMEOUT,
consume_client_data(OneByteThenStall { sent: false }, usize::MAX),
)
.await
.expect("consume_client_data exceeded relay timeout under stall load");
});
}
while let Some(res) = set.join_next().await {
res.unwrap();
}
// Under test constants idle=100ms, relay=200ms. 64 concurrent tasks stalling
// for 100ms should complete well under a strict 600ms boundary.
assert!(
started.elapsed() < MASK_RELAY_TIMEOUT * 3,
"stall stress batch completed too slowly; possible async executor starvation or head-of-line blocking"
);
}
#[tokio::test]
async fn consume_zero_cap_returns_immediately() {
let started = Instant::now();
consume_client_data(tokio::io::empty(), 0).await;
assert!(
started.elapsed() < MASK_RELAY_IDLE_TIMEOUT,
"zero byte cap must return immediately"
);
}
@@ -0,0 +1,225 @@
use super::*;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant, timeout};
fn make_self_target_config(
timing_normalization_enabled: bool,
floor_ms: u64,
ceiling_ms: u64,
beobachten_enabled: bool,
) -> ProxyConfig {
let mut config = ProxyConfig::default();
config.general.beobachten = beobachten_enabled;
config.general.beobachten_minutes = 5;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 443;
config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled;
config.censorship.mask_timing_normalization_floor_ms = floor_ms;
config.censorship.mask_timing_normalization_ceiling_ms = ceiling_ms;
config
}
async fn run_self_target_refusal(
config: ProxyConfig,
peer: SocketAddr,
initial: &'static [u8],
) -> Duration {
let beobachten = BeobachtenStore::new();
let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr");
let (mut client, server) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
server,
tokio::io::sink(),
initial,
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
client
.shutdown()
.await
.expect("client shutdown must succeed");
timeout(Duration::from_secs(3), task)
.await
.expect("self-target refusal must complete in bounded time")
.expect("self-target refusal task must not panic");
started.elapsed()
}
#[tokio::test]
async fn positive_self_target_refusal_honors_normalization_floor() {
let config = make_self_target_config(true, 120, 120, false);
let peer: SocketAddr = "203.0.113.41:54041".parse().expect("valid peer");
let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await;
assert!(
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(260),
"normalized self-target refusal must stay within expected envelope"
);
}
#[tokio::test]
async fn negative_non_normalized_refusal_does_not_sleep_to_large_floor() {
let config = make_self_target_config(false, 240, 240, false);
let peer: SocketAddr = "203.0.113.42:54042".parse().expect("valid peer");
let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await;
assert!(
elapsed < Duration::from_millis(180),
"non-normalized path must not inherit normalization floor delays"
);
}
#[tokio::test]
async fn edge_ceiling_below_floor_uses_floor_fail_closed() {
let config = make_self_target_config(true, 140, 80, false);
let peer: SocketAddr = "203.0.113.43:54043".parse().expect("valid peer");
let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await;
assert!(
elapsed >= Duration::from_millis(130) && elapsed < Duration::from_millis(280),
"ceiling<floor must clamp to floor to preserve deterministic normalization"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_blackhat_parallel_probes_remain_bounded_and_uniform() {
let workers = 24usize;
let mut tasks = Vec::with_capacity(workers);
for idx in 0..workers {
tasks.push(tokio::spawn(async move {
let cfg = make_self_target_config(true, 110, 140, false);
let peer: SocketAddr = format!("203.0.113.50:{}", 54100 + idx as u16)
.parse()
.expect("valid peer");
run_self_target_refusal(cfg, peer, b"GET /x HTTP/1.1\r\n\r\n").await
}));
}
let mut min = Duration::from_secs(60);
let mut max = Duration::from_millis(0);
for task in tasks {
let elapsed = task.await.expect("probe task must not panic");
if elapsed < min {
min = elapsed;
}
if elapsed > max {
max = elapsed;
}
assert!(
elapsed >= Duration::from_millis(100) && elapsed < Duration::from_millis(320),
"parallel probe latency must stay bounded under normalization"
);
}
assert!(
max.saturating_sub(min) <= Duration::from_millis(130),
"normalization should limit path variance across adversarial parallel probes"
);
}
#[tokio::test]
async fn integration_beobachten_records_probe_classification_on_refusal() {
let config = make_self_target_config(false, 0, 0, true);
let peer: SocketAddr = "198.51.100.71:55071".parse().expect("valid peer");
let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr");
let beobachten = BeobachtenStore::new();
let (mut client, server) = duplex(1024);
let task = tokio::spawn(async move {
handle_bad_client(
server,
tokio::io::sink(),
b"GET /classified HTTP/1.1\r\nHost: demo\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
beobachten.snapshot_text(Duration::from_secs(60))
});
client
.shutdown()
.await
.expect("client shutdown must succeed");
let snapshot = timeout(Duration::from_secs(3), task)
.await
.expect("integration task must complete")
.expect("integration task must not panic");
assert!(snapshot.contains("[HTTP]"));
assert!(snapshot.contains("198.51.100.71-1"));
}
#[tokio::test]
async fn light_fuzz_timing_configuration_matrix_is_bounded() {
let mut seed = 0xA17E_55AA_2026_0323u64;
for case in 0..48u64 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let enabled = (seed & 1) == 0;
let floor = (seed >> 8) % 180;
let ceiling = (seed >> 24) % 180;
let config = make_self_target_config(enabled, floor, ceiling, false);
let peer: SocketAddr = format!("203.0.113.90:{}", 56000 + (case as u16))
.parse()
.expect("valid peer");
let elapsed = run_self_target_refusal(config, peer, b"HEAD /h HTTP/1.1\r\n\r\n").await;
assert!(
elapsed < Duration::from_millis(420),
"fuzz case must stay bounded and never hang"
);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_high_fanout_self_target_refusal_no_deadlock_or_timeout() {
let workers = 64usize;
let mut tasks = Vec::with_capacity(workers);
for idx in 0..workers {
tasks.push(tokio::spawn(async move {
let config = make_self_target_config(false, 0, 0, false);
let peer: SocketAddr = format!("198.51.100.200:{}", 57000 + idx as u16)
.parse()
.expect("valid peer");
run_self_target_refusal(config, peer, b"GET /stress HTTP/1.1\r\n\r\n").await
}));
}
timeout(Duration::from_secs(5), async {
for task in tasks {
let elapsed = task.await.expect("stress task must not panic");
assert!(
elapsed < Duration::from_millis(260),
"stress refusal must remain bounded without normalization"
);
}
})
.await
.expect("high-fanout refusal workload must complete without deadlock");
}
@@ -0,0 +1,55 @@
use super::*;
use tokio::net::TcpListener;
use tokio::time::Duration;
#[tokio::test]
async fn http2_preface_is_forwarded_and_recorded_as_http() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec();
let accept_task = tokio::spawn({
let preface = preface.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; preface.len()];
stream.read_exact(&mut received).await.unwrap();
assert_eq!(received, preface);
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = 1;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "198.51.100.130:54130".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (client_reader, _client_writer) = tokio::io::duplex(512);
let (_client_visible_reader, client_visible_writer) = tokio::io::duplex(512);
let beobachten = BeobachtenStore::new();
handle_bad_client(
client_reader,
client_visible_writer,
&preface,
peer,
local_addr,
&config,
&beobachten,
)
.await;
tokio::time::timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[HTTP]"));
assert!(snapshot.contains("198.51.100.130-1"));
}
@@ -0,0 +1,92 @@
use super::*;
#[test]
fn full_http2_preface_classified_as_http_probe() {
let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
assert!(
is_http_probe(preface),
"HTTP/2 connection preface must be classified as HTTP probe"
);
}
#[test]
fn partial_http2_preface_3_bytes_classified() {
assert!(
is_http_probe(b"PRI"),
"3-byte HTTP/2 preface prefix must be classified"
);
}
#[test]
fn partial_http2_preface_2_bytes_classified() {
assert!(
is_http_probe(b"PR"),
"2-byte HTTP/2 preface prefix must be classified"
);
}
#[test]
fn existing_http1_methods_unaffected() {
for prefix in [
b"GET / HTTP/1.1\r\n".as_ref(),
b"POST /api HTTP/1.1\r\n".as_ref(),
b"CONNECT example.com:443 HTTP/1.1\r\n".as_ref(),
b"TRACE / HTTP/1.1\r\n".as_ref(),
b"PATCH / HTTP/1.1\r\n".as_ref(),
] {
assert!(is_http_probe(prefix));
}
}
#[test]
fn non_http_data_not_classified() {
for data in [
b"\x16\x03\x01\x00\xf1".as_ref(),
b"SSH-2.0-OpenSSH_8.9\r\n".as_ref(),
b"\x00\x01\x02\x03".as_ref(),
b"".as_ref(),
b"P".as_ref(),
] {
assert!(!is_http_probe(data));
}
}
#[test]
fn light_fuzz_non_http_prefixes_not_misclassified() {
// Deterministic pseudo-fuzz to exercise classifier edges while avoiding
// known HTTP method and partial windows.
let mut x = 0x1234_5678u32;
for _ in 0..1024 {
x = x.wrapping_mul(1664525).wrapping_add(1013904223);
let len = 4 + ((x >> 8) as usize % 12);
let mut data = vec![0u8; len];
for byte in &mut data {
x = x.wrapping_mul(1664525).wrapping_add(1013904223);
*byte = (x & 0xFF) as u8;
}
if [
b"GET ".as_ref(),
b"POST".as_ref(),
b"HEAD".as_ref(),
b"PUT ".as_ref(),
b"DELETE".as_ref(),
b"OPTIONS".as_ref(),
b"CONNECT".as_ref(),
b"TRACE".as_ref(),
b"PATCH".as_ref(),
b"PRI ".as_ref(),
]
.iter()
.any(|m| data.starts_with(m))
{
continue;
}
assert!(
!is_http_probe(&data),
"non-http pseudo-fuzz input misclassified: {:?}",
&data[..data.len().min(8)]
);
}
}
@@ -0,0 +1,85 @@
use super::*;
#[test]
fn exact_four_byte_http_tokens_are_classified() {
for token in [
b"GET ".as_ref(),
b"POST".as_ref(),
b"HEAD".as_ref(),
b"PUT ".as_ref(),
b"PRI ".as_ref(),
] {
assert!(
is_http_probe(token),
"exact 4-byte token must be classified as HTTP probe: {:?}",
token
);
}
}
#[test]
fn exact_four_byte_non_http_tokens_are_not_classified() {
for token in [
b"GEX ".as_ref(),
b"POXT".as_ref(),
b"HEA/".as_ref(),
b"PU\0 ".as_ref(),
b"PRI/".as_ref(),
] {
assert!(
!is_http_probe(token),
"non-HTTP 4-byte token must not be classified: {:?}",
token
);
}
}
#[test]
fn detect_client_type_keeps_http_label_for_minimal_four_byte_http_prefixes() {
assert_eq!(detect_client_type(b"GET "), "HTTP");
assert_eq!(detect_client_type(b"PRI "), "HTTP");
}
#[test]
fn exact_long_http_tokens_are_classified() {
for token in [b"CONNECT".as_ref(), b"TRACE".as_ref(), b"PATCH".as_ref()] {
assert!(
is_http_probe(token),
"exact long HTTP token must be classified as HTTP probe: {:?}",
token
);
}
}
#[test]
fn detect_client_type_keeps_http_label_for_exact_long_http_tokens() {
assert_eq!(detect_client_type(b"CONNECT"), "HTTP");
assert_eq!(detect_client_type(b"TRACE"), "HTTP");
assert_eq!(detect_client_type(b"PATCH"), "HTTP");
}
#[test]
fn light_fuzz_four_byte_ascii_noise_not_misclassified() {
// Deterministic pseudo-fuzz over 4-byte printable ASCII inputs.
let mut x = 0xA17C_93E5u32;
for _ in 0..2048 {
let mut token = [0u8; 4];
for byte in &mut token {
x = x.wrapping_mul(1664525).wrapping_add(1013904223);
*byte = 32 + ((x & 0x3F) as u8); // printable ASCII subset
}
if [b"GET ", b"POST", b"HEAD", b"PUT ", b"PRI "]
.iter()
.any(|m| token.as_slice() == *m)
{
continue;
}
assert!(
!is_http_probe(&token),
"pseudo-fuzz noise misclassified as HTTP probe: {:?}",
token
);
}
}
@@ -0,0 +1,41 @@
#![cfg(unix)]
use super::*;
use std::sync::{Mutex, OnceLock};
use tokio::sync::Barrier;
fn interface_cache_test_lock() -> &'static Mutex<()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_parallel_cold_miss_performs_single_interface_refresh() {
let _guard = interface_cache_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
reset_local_interface_enumerations_for_tests();
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let workers = 32usize;
let barrier = std::sync::Arc::new(Barrier::new(workers));
let mut tasks = Vec::with_capacity(workers);
for _ in 0..workers {
let barrier = std::sync::Arc::clone(&barrier);
tasks.push(tokio::spawn(async move {
barrier.wait().await;
is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await
}));
}
for task in tasks {
let _ = task.await.expect("parallel cache task must not panic");
}
assert_eq!(
local_interface_enumerations_for_tests(),
1,
"parallel cold misses must coalesce into a single interface enumeration"
);
}
@@ -0,0 +1,51 @@
#![cfg(unix)]
use super::*;
#[test]
fn defense_in_depth_empty_refresh_preserves_previous_non_empty_interfaces() {
let previous = vec![
"192.168.100.7"
.parse::<IpAddr>()
.expect("must parse interface ip"),
];
let refreshed = Vec::new();
let next = choose_interface_snapshot(&previous, refreshed);
assert_eq!(
next, previous,
"empty refresh should preserve previous non-empty snapshot to avoid fail-open loop-guard regressions"
);
}
#[test]
fn defense_in_depth_non_empty_refresh_replaces_previous_snapshot() {
let previous = vec![
"192.168.100.7"
.parse::<IpAddr>()
.expect("must parse interface ip"),
];
let refreshed = vec![
"10.55.0.3"
.parse::<IpAddr>()
.expect("must parse refreshed interface ip"),
];
let next = choose_interface_snapshot(&previous, refreshed.clone());
assert_eq!(next, refreshed);
}
#[test]
fn defense_in_depth_empty_refresh_keeps_empty_when_no_previous_snapshot_exists() {
let previous = Vec::new();
let refreshed = Vec::new();
let next = choose_interface_snapshot(&previous, refreshed);
assert!(
next.is_empty(),
"empty refresh with no previous snapshot should remain empty"
);
}
@@ -0,0 +1,49 @@
#![cfg(unix)]
use super::*;
use std::sync::{Mutex, OnceLock};
fn interface_cache_test_lock() -> &'static Mutex<()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
}
#[tokio::test]
async fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within_window() {
let _guard = interface_cache_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
reset_local_interface_enumerations_for_tests();
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await;
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await;
assert_eq!(
local_interface_enumerations_for_tests(),
1,
"interface enumeration must be cached across repeated bad-client checks"
);
}
#[tokio::test]
async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() {
let _guard = interface_cache_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
reset_local_interface_enumerations_for_tests();
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await;
assert!(
!is_local,
"different port must not be treated as local listener"
);
assert_eq!(
local_interface_enumerations_for_tests(),
0,
"port mismatch should bypass interface enumeration entirely"
);
}
@@ -0,0 +1,178 @@
use super::*;
use std::net::{SocketAddr, TcpListener as StdTcpListener};
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant};
fn closed_local_port() -> u16 {
let listener = StdTcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
drop(listener);
port
}
#[tokio::test]
#[ignore = "red-team expected-fail: offline mask target keeps bad-client socket alive before consume timeout boundary"]
async fn redteam_offline_target_should_drop_idle_client_early() {
let (client_read, mut client_write) = duplex(1024);
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = closed_local_port();
cfg.censorship.mask_timing_normalization_enabled = false;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer_addr: SocketAddr = "192.0.2.50:5000".parse().unwrap();
let beobachten = BeobachtenStore::new();
let handler = tokio::spawn(async move {
handle_bad_client(
client_read,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer_addr,
local_addr,
&cfg,
&beobachten,
)
.await;
});
tokio::time::sleep(Duration::from_millis(150)).await;
let write_res = client_write.write_all(b"probe-should-be-closed").await;
assert!(
write_res.is_err(),
"offline target path still keeps client writable before consume timeout"
);
handler.abort();
}
#[tokio::test]
#[ignore = "red-team expected-fail: proxy should mimic immediate RST-like close when target is offline"]
async fn redteam_offline_target_should_not_sleep_to_mask_refusal() {
let (client_read, mut client_write) = duplex(1024);
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = closed_local_port();
cfg.censorship.mask_timing_normalization_enabled = false;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer_addr: SocketAddr = "192.0.2.51:5000".parse().unwrap();
let beobachten = BeobachtenStore::new();
let started = Instant::now();
let handler = tokio::spawn(async move {
handle_bad_client(
client_read,
tokio::io::sink(),
b"\x16\x03\x01\x00\x05hello",
peer_addr,
local_addr,
&cfg,
&beobachten,
)
.await;
});
client_write.shutdown().await.unwrap();
let _ = handler.await;
let elapsed = started.elapsed();
assert!(
elapsed < Duration::from_millis(10),
"offline target path still applies coarse masking sleep and is fingerprintable"
);
}
#[tokio::test]
#[ignore = "red-team expected-fail: refusal path should remain below strict latency envelope under burst"]
async fn redteam_offline_refusal_burst_timing_spread_should_be_tight() {
let mut samples = Vec::new();
for i in 0..12u16 {
let (client_read, mut client_write) = duplex(1024);
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = closed_local_port();
cfg.censorship.mask_timing_normalization_enabled = false;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer_addr: SocketAddr = format!("192.0.2.52:{}", 5100 + i).parse().unwrap();
let beobachten = BeobachtenStore::new();
let started = Instant::now();
let handler = tokio::spawn(async move {
handle_bad_client(
client_read,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer_addr,
local_addr,
&cfg,
&beobachten,
)
.await;
});
client_write.shutdown().await.unwrap();
let _ = handler.await;
samples.push(started.elapsed());
}
let min = samples.iter().copied().min().unwrap_or_default();
let max = samples.iter().copied().max().unwrap_or_default();
let spread = max.saturating_sub(min);
assert!(
spread <= Duration::from_millis(5),
"offline refusal timing spread too wide for strict red-team envelope: {:?}",
spread
);
}
#[tokio::test]
#[ignore = "manual red-team: host resolver failure should complete without panic"]
async fn redteam_dns_resolution_failure_must_not_panic() {
let (client_read, mut client_write) = duplex(1024);
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("this.domain.definitely.does.not.exist.invalid".to_string());
cfg.censorship.mask_port = 443;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer_addr: SocketAddr = "192.0.2.99:5999".parse().unwrap();
let beobachten = BeobachtenStore::new();
let handler = tokio::spawn(async move {
handle_bad_client(
client_read,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer_addr,
local_addr,
&cfg,
&beobachten,
)
.await;
});
client_write.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(2), handler).await;
assert!(
result.is_ok(),
"dns failure path stalled or panicked instead of terminating"
);
}
@@ -0,0 +1,51 @@
use super::*;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use tokio::io::AsyncWrite;
struct NeverWritable;
impl AsyncWrite for NeverWritable {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Pending
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
}
#[tokio::test]
async fn shape_padding_returns_before_global_mask_timeout_on_blocked_writer() {
let mut writer = NeverWritable;
let started = Instant::now();
maybe_write_shape_padding(&mut writer, 1, true, 256, 4096, false, 0, false).await;
assert!(
started.elapsed() <= MASK_TIMEOUT + std::time::Duration::from_millis(30),
"shape padding blocked past timeout budget"
);
}
#[tokio::test]
async fn shape_padding_with_non_http_blur_disabled_at_cap_writes_nothing() {
let mut output = Vec::new();
{
let mut writer = tokio::io::BufWriter::new(&mut output);
maybe_write_shape_padding(&mut writer, 4096, true, 64, 4096, false, 128, false).await;
use tokio::io::AsyncWriteExt;
writer.flush().await.unwrap();
}
assert!(output.is_empty());
}
@@ -0,0 +1,283 @@
use super::*;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::{Duration, Instant, timeout};
const PROD_CAP_BYTES: usize = 5 * 1024 * 1024;
struct FinitePatternReader {
remaining: usize,
chunk: usize,
read_calls: Arc<AtomicUsize>,
}
impl FinitePatternReader {
fn new(total: usize, chunk: usize, read_calls: Arc<AtomicUsize>) -> Self {
Self {
remaining: total,
chunk,
read_calls,
}
}
}
impl AsyncRead for FinitePatternReader {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
self.read_calls.fetch_add(1, Ordering::Relaxed);
if self.remaining == 0 {
return Poll::Ready(Ok(()));
}
let take = self.remaining.min(self.chunk).min(buf.remaining());
if take == 0 {
return Poll::Ready(Ok(()));
}
let fill = vec![0x5Au8; take];
buf.put_slice(&fill);
self.remaining -= take;
Poll::Ready(Ok(()))
}
}
#[derive(Default)]
struct CountingWriter {
written: usize,
}
impl AsyncWrite for CountingWriter {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.written = self.written.saturating_add(buf.len());
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
}
struct NeverReadyReader;
impl AsyncRead for NeverReadyReader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Pending
}
}
struct BudgetProbeReader {
remaining: usize,
total_read: Arc<AtomicUsize>,
}
impl BudgetProbeReader {
fn new(total: usize, total_read: Arc<AtomicUsize>) -> Self {
Self {
remaining: total,
total_read,
}
}
}
impl AsyncRead for BudgetProbeReader {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if self.remaining == 0 {
return Poll::Ready(Ok(()));
}
let take = self.remaining.min(buf.remaining());
if take == 0 {
return Poll::Ready(Ok(()));
}
let fill = vec![0xA5u8; take];
buf.put_slice(&fill);
self.remaining -= take;
self.total_read.fetch_add(take, Ordering::Relaxed);
Poll::Ready(Ok(()))
}
}
#[tokio::test]
async fn positive_copy_with_production_cap_stops_exactly_at_budget() {
let read_calls = Arc::new(AtomicUsize::new(0));
let mut reader = FinitePatternReader::new(PROD_CAP_BYTES + (256 * 1024), 4096, read_calls);
let mut writer = CountingWriter::default();
let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await;
assert_eq!(
outcome.total, PROD_CAP_BYTES,
"copy path must stop at explicit production cap"
);
assert_eq!(writer.written, PROD_CAP_BYTES);
assert!(
!outcome.ended_by_eof,
"byte-cap stop must not be misclassified as EOF"
);
}
#[tokio::test]
async fn negative_consume_with_zero_cap_performs_no_reads() {
let read_calls = Arc::new(AtomicUsize::new(0));
let reader = FinitePatternReader::new(1024, 64, Arc::clone(&read_calls));
consume_client_data_with_timeout_and_cap(reader, 0).await;
assert_eq!(
read_calls.load(Ordering::Relaxed),
0,
"zero cap must return before reading attacker-controlled bytes"
);
}
#[tokio::test]
async fn edge_copy_below_cap_reports_eof_without_overread() {
let read_calls = Arc::new(AtomicUsize::new(0));
let payload = 73 * 1024;
let mut reader = FinitePatternReader::new(payload, 3072, read_calls);
let mut writer = CountingWriter::default();
let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await;
assert_eq!(outcome.total, payload);
assert_eq!(writer.written, payload);
assert!(
outcome.ended_by_eof,
"finite upstream below cap must terminate via EOF path"
);
}
#[tokio::test]
async fn adversarial_blackhat_never_ready_reader_is_bounded_by_timeout_guards() {
let started = Instant::now();
consume_client_data_with_timeout_and_cap(NeverReadyReader, PROD_CAP_BYTES).await;
assert!(
started.elapsed() < Duration::from_millis(350),
"never-ready reader must be bounded by idle/relay timeout protections"
);
}
#[tokio::test]
async fn integration_consume_path_honors_production_cap_for_large_payload() {
let read_calls = Arc::new(AtomicUsize::new(0));
let reader = FinitePatternReader::new(PROD_CAP_BYTES + (1024 * 1024), 8192, read_calls);
let bounded = timeout(
Duration::from_millis(350),
consume_client_data_with_timeout_and_cap(reader, PROD_CAP_BYTES),
)
.await;
assert!(
bounded.is_ok(),
"consume path with production cap must finish within bounded time"
);
}
#[tokio::test]
async fn adversarial_consume_path_never_reads_beyond_declared_byte_cap() {
let byte_cap = 5usize;
let total_read = Arc::new(AtomicUsize::new(0));
let reader = BudgetProbeReader::new(256 * 1024, Arc::clone(&total_read));
consume_client_data_with_timeout_and_cap(reader, byte_cap).await;
assert!(
total_read.load(Ordering::Relaxed) <= byte_cap,
"consume path must not read more than configured byte cap"
);
}
#[tokio::test]
async fn light_fuzz_cap_and_payload_matrix_preserves_min_budget_invariant() {
let mut seed = 0x1234_5678_9ABC_DEF0u64;
for _case in 0..96u32 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let cap = ((seed & 0x3ffff) as usize).saturating_add(1);
let payload = ((seed.rotate_left(11) & 0x7ffff) as usize).saturating_add(1);
let chunk = (((seed >> 5) & 0x1fff) as usize).saturating_add(1);
let read_calls = Arc::new(AtomicUsize::new(0));
let mut reader = FinitePatternReader::new(payload, chunk, read_calls);
let mut writer = CountingWriter::default();
let outcome = copy_with_idle_timeout(&mut reader, &mut writer, cap, true).await;
let expected = payload.min(cap);
assert_eq!(
outcome.total, expected,
"copy total must match min(payload, cap) under fuzzed inputs"
);
assert_eq!(writer.written, expected);
if payload <= cap {
assert!(outcome.ended_by_eof);
} else {
assert!(!outcome.ended_by_eof);
}
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_copy_tasks_with_production_cap_complete_without_leaks() {
let workers = 8usize;
let mut tasks = Vec::with_capacity(workers);
for idx in 0..workers {
tasks.push(tokio::spawn(async move {
let read_calls = Arc::new(AtomicUsize::new(0));
let mut reader = FinitePatternReader::new(
PROD_CAP_BYTES + (idx + 1) * 4096,
4096 + (idx * 257),
read_calls,
);
let mut writer = CountingWriter::default();
copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await
}));
}
timeout(Duration::from_secs(3), async {
for task in tasks {
let outcome = task.await.expect("stress task must not panic");
assert_eq!(
outcome.total, PROD_CAP_BYTES,
"stress copy task must stay within production cap"
);
assert!(
!outcome.ended_by_eof,
"stress task should end due to cap, not EOF"
);
}
})
.await
.expect("stress suite must complete in bounded time");
}
@@ -0,0 +1,105 @@
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex, sink};
use tokio::time::{Duration, timeout};
#[tokio::test]
async fn relay_to_mask_enforces_masking_session_byte_cap() {
let initial = vec![0x16, 0x03, 0x01, 0x00, 0x01];
let extra = vec![0xAB; 96 * 1024];
let (client_reader, mut client_writer) = duplex(128 * 1024);
let (mask_read, _mask_read_peer) = duplex(1024);
let (mut mask_observer, mask_write) = duplex(256 * 1024);
let initial_for_task = initial.clone();
let relay = tokio::spawn(async move {
relay_to_mask(
client_reader,
sink(),
mask_read,
mask_write,
&initial_for_task,
false,
512,
4096,
false,
0,
false,
32 * 1024,
)
.await;
});
client_writer.write_all(&extra).await.unwrap();
client_writer.shutdown().await.unwrap();
timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
let mut observed = Vec::new();
timeout(
Duration::from_secs(2),
mask_observer.read_to_end(&mut observed),
)
.await
.unwrap()
.unwrap();
// In this deterministic test, relay must stop exactly at the configured cap.
assert_eq!(
observed.len(),
initial.len() + (32 * 1024),
"masked relay must forward exactly up to the cap (observed={} initial={} cap={})",
observed.len(),
initial.len(),
32 * 1024
);
}
#[tokio::test]
async fn relay_to_mask_propagates_client_half_close_without_waiting_for_other_direction_timeout() {
let initial = b"GET /half-close HTTP/1.1\r\n".to_vec();
let (client_reader, mut client_writer) = duplex(8 * 1024);
let (mask_read, _mask_read_peer) = duplex(8 * 1024);
let (mut mask_observer, mask_write) = duplex(8 * 1024);
let initial_for_task = initial.clone();
let relay = tokio::spawn(async move {
relay_to_mask(
client_reader,
sink(),
mask_read,
mask_write,
&initial_for_task,
false,
512,
4096,
false,
0,
false,
32 * 1024,
)
.await;
});
client_writer.shutdown().await.unwrap();
let mut observed = Vec::new();
timeout(
Duration::from_millis(80),
mask_observer.read_to_end(&mut observed),
)
.await
.expect("mask backend write side should be half-closed promptly")
.unwrap();
assert_eq!(&observed[..initial.len()], initial.as_slice());
timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
}
@@ -0,0 +1,100 @@
use super::*;
use tokio::io::AsyncReadExt;
use tokio::time::{Duration, timeout};
async fn collect_padding(
total_sent: usize,
enabled: bool,
floor: usize,
cap: usize,
above_cap_blur: bool,
blur_max: usize,
aggressive: bool,
) -> Vec<u8> {
let (mut tx, mut rx) = tokio::io::duplex(256 * 1024);
maybe_write_shape_padding(
&mut tx,
total_sent,
enabled,
floor,
cap,
above_cap_blur,
blur_max,
aggressive,
)
.await;
drop(tx);
let mut output = Vec::new();
timeout(Duration::from_secs(1), rx.read_to_end(&mut output))
.await
.expect("reading padded output timed out")
.expect("failed reading padded output");
output
}
#[tokio::test]
async fn padding_output_is_not_all_zero() {
let output = collect_padding(1, true, 256, 4096, false, 0, false).await;
assert!(
output.len() >= 255,
"expected at least 255 padding bytes, got {}",
output.len()
);
let nonzero = output.iter().filter(|&&b| b != 0).count();
// In 255 bytes of uniform randomness, the expected number of zero bytes is ~1.
// A weak nonzero check can miss severe entropy collapse.
assert!(
nonzero >= 240,
"RNG output entropy collapsed, too many zero bytes: {} nonzero out of {}",
nonzero,
output.len(),
);
}
#[tokio::test]
async fn padding_reaches_first_bucket_boundary() {
let output = collect_padding(1, true, 64, 4096, false, 0, false).await;
assert_eq!(output.len(), 63);
}
#[tokio::test]
async fn disabled_padding_produces_no_output() {
let output = collect_padding(0, false, 256, 4096, false, 0, false).await;
assert!(output.is_empty());
}
#[tokio::test]
async fn at_cap_without_blur_produces_no_output() {
let output = collect_padding(4096, true, 64, 4096, false, 0, false).await;
assert!(output.is_empty());
}
#[tokio::test]
async fn above_cap_blur_is_positive_and_bounded_in_aggressive_mode() {
let output = collect_padding(4096, true, 64, 4096, true, 128, true).await;
assert!(!output.is_empty());
assert!(output.len() <= 128, "blur exceeded max: {}", output.len());
}
#[tokio::test]
async fn stress_padding_runs_are_not_constant_pattern() {
// Stress and sanity-check: repeated runs should not collapse to identical
// first 16 bytes across all samples.
let mut first_chunks = Vec::new();
for _ in 0..64 {
let out = collect_padding(1, true, 64, 4096, false, 0, false).await;
first_chunks.push(out[..16].to_vec());
}
let first = &first_chunks[0];
let all_same = first_chunks.iter().all(|chunk| chunk == first);
assert!(
!all_same,
"all stress samples had identical prefix, rng output appears degenerate"
);
}
@@ -1376,6 +1376,7 @@ async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stall
false,
0,
false,
5 * 1024 * 1024,
)
.await;
});
@@ -1506,6 +1507,7 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
false,
0,
false,
5 * 1024 * 1024,
),
)
.await;
@@ -0,0 +1,330 @@
use super::*;
use std::net::SocketAddr;
use std::net::TcpListener as StdTcpListener;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, Instant, timeout};
fn closed_local_port() -> u16 {
let listener = StdTcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
drop(listener);
port
}
#[tokio::test]
async fn self_target_detection_matches_literal_ipv4_listener() {
let local: SocketAddr = "198.51.100.40:443".parse().unwrap();
assert!(is_mask_target_local_listener_async("198.51.100.40", 443, local, None,).await);
}
#[tokio::test]
async fn self_target_detection_matches_bracketed_ipv6_listener() {
let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap();
assert!(is_mask_target_local_listener_async("[2001:db8::44]", 8443, local, None,).await);
}
#[tokio::test]
async fn self_target_detection_keeps_same_ip_different_port_forwardable() {
let local: SocketAddr = "203.0.113.44:443".parse().unwrap();
assert!(!is_mask_target_local_listener_async("203.0.113.44", 8443, local, None,).await);
}
#[tokio::test]
async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() {
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
assert!(is_mask_target_local_listener_async("::ffff:127.0.0.1", 443, local, None,).await);
}
#[tokio::test]
async fn self_target_detection_unspecified_bind_blocks_loopback_target() {
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
assert!(is_mask_target_local_listener_async("127.0.0.1", 443, local, None,).await);
}
#[tokio::test]
async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() {
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
let remote: SocketAddr = "198.51.100.44:443".parse().unwrap();
assert!(!is_mask_target_local_listener_async("mask.example", 443, local, Some(remote),).await);
}
#[tokio::test]
async fn self_target_fallback_refuses_recursive_loopback_connect() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
timeout(Duration::from_millis(120), listener.accept())
.await
.is_ok()
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some(local_addr.ip().to_string());
config.censorship.mask_port = local_addr.port();
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.90:55090".parse().unwrap();
let beobachten = BeobachtenStore::new();
handle_bad_client(
tokio::io::empty(),
tokio::io::sink(),
b"GET /",
peer,
local_addr,
&config,
&beobachten,
)
.await;
let accepted = accept_task.await.unwrap();
assert!(
!accepted,
"self-target masking must fail closed without connecting to local listener"
);
}
#[tokio::test]
async fn same_ip_different_port_still_forwards_to_mask_backend() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = b"GET /".to_vec();
let accept_task = tokio::spawn({
let expected = probe.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; expected.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, expected);
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.91:55091".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
handle_bad_client(
tokio::io::empty(),
tokio::io::sink(),
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
}
#[test]
fn detect_client_type_http_boundary_get_and_post() {
assert_eq!(detect_client_type(b"GET "), "HTTP");
assert_eq!(detect_client_type(b"GET /"), "HTTP");
assert_eq!(detect_client_type(b"POST"), "HTTP");
assert_eq!(detect_client_type(b"POST "), "HTTP");
assert_eq!(detect_client_type(b"POSTX"), "HTTP");
}
#[test]
fn detect_client_type_tls_and_length_boundaries() {
assert_eq!(detect_client_type(b"\x16\x03\x01"), "port-scanner");
assert_eq!(detect_client_type(b"\x16\x03\x01\x00"), "TLS-scanner");
assert_eq!(detect_client_type(b"123456789"), "port-scanner");
assert_eq!(detect_client_type(b"1234567890"), "unknown");
}
#[test]
fn build_mask_proxy_header_v1_cross_family_falls_back_to_unknown() {
let peer: SocketAddr = "192.168.1.5:12345".parse().unwrap();
let local: SocketAddr = "[2001:db8::1]:443".parse().unwrap();
let header = build_mask_proxy_header(1, peer, local).unwrap();
assert_eq!(header, b"PROXY UNKNOWN\r\n");
}
#[test]
fn next_mask_shape_bucket_checked_mul_overflow_fails_closed() {
let floor = usize::MAX / 2 + 1;
let cap = usize::MAX;
let total = floor + 1;
assert_eq!(next_mask_shape_bucket(total, floor, cap), total);
}
#[tokio::test]
async fn self_target_reject_path_keeps_timing_budget() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 443;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer: SocketAddr = "203.0.113.92:55092".parse().unwrap();
let beobachten = BeobachtenStore::new();
let (client, server) = duplex(1024);
drop(client);
let started = Instant::now();
handle_bad_client(
server,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(250),
"self-target reject path must keep coarse timing budget without stalling"
);
}
#[tokio::test]
async fn relay_path_idle_timeout_eviction_remains_effective() {
let (client_read, mut client_write) = duplex(1024);
let (mask_read, mask_write) = duplex(1024);
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(10)).await;
client_write.write_all(b"a").await.unwrap();
tokio::time::sleep(Duration::from_millis(180)).await;
let _ = client_write.write_all(b"b").await;
});
let started = Instant::now();
relay_to_mask(
client_read,
tokio::io::sink(),
mask_read,
mask_write,
b"init",
false,
0,
0,
false,
0,
false,
5 * 1024 * 1024,
)
.await;
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(90) && elapsed < Duration::from_millis(180),
"idle-timeout eviction must occur before late trickle write"
);
}
#[tokio::test]
async fn offline_mask_target_refusal_respects_timing_normalization_budget() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = closed_local_port();
config.censorship.mask_timing_normalization_enabled = true;
config.censorship.mask_timing_normalization_floor_ms = 120;
config.censorship.mask_timing_normalization_ceiling_ms = 120;
let peer: SocketAddr = "203.0.113.93:55093".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let (mut client, server) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
server,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
client.shutdown().await.unwrap();
timeout(Duration::from_secs(2), task)
.await
.unwrap()
.unwrap();
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(220),
"offline-refusal path must honor normalization budget without unbounded drift"
);
}
#[tokio::test]
async fn offline_mask_target_refusal_with_idle_client_is_bounded_by_consume_timeout() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = closed_local_port();
config.censorship.mask_timing_normalization_enabled = false;
let peer: SocketAddr = "203.0.113.94:55094".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let (mut client, server) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
server,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
tokio::time::sleep(Duration::from_millis(120)).await;
client
.write_all(b"still-open-before-timeout")
.await
.expect("connection should still be open before consume timeout expires");
timeout(Duration::from_secs(2), task)
.await
.unwrap()
.unwrap();
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(190) && elapsed < Duration::from_millis(350),
"offline-refusal path must not retain idle client indefinitely"
);
}
@@ -43,6 +43,7 @@ async fn run_relay_case(
above_cap_blur,
above_cap_blur_max_bytes,
false,
5 * 1024 * 1024,
)
.await;
});
@@ -88,6 +88,7 @@ async fn relay_to_mask_applies_cap_clamped_padding_for_non_power_of_two_cap() {
false,
0,
false,
5 * 1024 * 1024,
)
.await;
});
@@ -0,0 +1,58 @@
#![cfg(unix)]
use super::*;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant, timeout};
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn adversarial_delayed_interface_lookup_does_not_consume_outcome_floor_budget() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 443;
config.censorship.mask_timing_normalization_enabled = true;
config.censorship.mask_timing_normalization_floor_ms = 120;
config.censorship.mask_timing_normalization_ceiling_ms = 120;
let peer: SocketAddr = "203.0.113.151:55151".parse().expect("valid peer");
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let beobachten = BeobachtenStore::new();
let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(()));
let held_refresh_guard = refresh_lock.lock().await;
let (mut client, server) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
server,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
tokio::time::sleep(Duration::from_millis(80)).await;
drop(held_refresh_guard);
client
.shutdown()
.await
.expect("client shutdown must succeed");
timeout(Duration::from_secs(2), task)
.await
.expect("task must finish in bounded time")
.expect("task must not panic");
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(180) && elapsed < Duration::from_millis(350),
"timing normalization floor must start after pre-outcome self-target checks"
);
}
@@ -0,0 +1,189 @@
use super::*;
use crate::crypto::AesCtr;
use bytes::Bytes;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::task::{Context, Poll};
use tokio::io::AsyncWrite;
struct CountedWriter {
write_calls: Arc<AtomicUsize>,
fail_writes: bool,
}
impl CountedWriter {
fn new(write_calls: Arc<AtomicUsize>, fail_writes: bool) -> Self {
Self {
write_calls,
fail_writes,
}
}
}
impl AsyncWrite for CountedWriter {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
this.write_calls.fetch_add(1, Ordering::Relaxed);
if this.fail_writes {
Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"forced write failure",
)))
} else {
Poll::Ready(Ok(buf.len()))
}
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
fn make_crypto_writer(inner: CountedWriter) -> CryptoWriter<CountedWriter> {
let key = [0u8; 32];
let iv = 0u128;
CryptoWriter::new(inner, AesCtr::new(&key, iv), 8 * 1024)
}
#[tokio::test]
async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() {
let stats = Stats::new();
let user = "middle-me-writer-no-rollback-user";
let user_stats = stats.get_or_create_user_stats_handle(user);
let write_calls = Arc::new(AtomicUsize::new(0));
let mut writer = make_crypto_writer(CountedWriter::new(write_calls.clone(), true));
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let payload = Bytes::from_static(&[0x11, 0x22, 0x33, 0x44, 0x55]);
let result = process_me_writer_response(
MeResponse::Data {
flags: 0,
data: payload.clone(),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
user,
Some(user_stats.as_ref()),
Some(64),
0,
&bytes_me2c,
11,
true,
false,
)
.await;
assert!(
matches!(result, Err(ProxyError::Io(_))),
"write failure must propagate as I/O error"
);
assert!(
write_calls.load(Ordering::Relaxed) > 0,
"writer must be attempted after successful quota reservation"
);
assert_eq!(
stats.get_user_quota_used(user),
payload.len() as u64,
"reserved quota must not roll back on write failure"
);
assert_eq!(
stats.get_quota_write_fail_bytes_total(),
payload.len() as u64,
"write-fail byte metric must include failed payload size"
);
assert_eq!(
stats.get_quota_write_fail_events_total(),
1,
"write-fail events metric must increment once"
);
assert_eq!(
stats.get_user_total_octets(user),
0,
"telemetry octets_to should not advance when write fails"
);
assert_eq!(
bytes_me2c.load(Ordering::Relaxed),
0,
"ME->C committed byte counter must not advance on write failure"
);
}
#[tokio::test]
async fn me_writer_pre_write_quota_reject_happens_before_writer_poll() {
let stats = Stats::new();
let user = "middle-me-writer-precheck-user";
let limit = 8u64;
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), limit);
let write_calls = Arc::new(AtomicUsize::new(0));
let mut writer = make_crypto_writer(CountedWriter::new(write_calls.clone(), false));
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let result = process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xAA, 0xBB, 0xCC]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
user,
Some(user_stats.as_ref()),
Some(limit),
0,
&bytes_me2c,
12,
true,
false,
)
.await;
assert!(
matches!(result, Err(ProxyError::DataQuotaExceeded { .. })),
"pre-write quota rejection must return typed quota error"
);
assert_eq!(
write_calls.load(Ordering::Relaxed),
0,
"writer must not be polled when pre-write quota reservation fails"
);
assert_eq!(
stats.get_me_d2c_quota_reject_pre_write_total(),
1,
"pre-write quota reject metric must increment"
);
assert_eq!(
stats.get_user_quota_used(user),
limit,
"failed pre-write reservation must keep previous quota usage unchanged"
);
assert_eq!(
stats.get_quota_write_fail_bytes_total(),
0,
"write-fail bytes metric must stay unchanged on pre-write reject"
);
assert_eq!(
stats.get_quota_write_fail_events_total(),
0,
"write-fail events metric must stay unchanged on pre-write reject"
);
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
}
@@ -1,112 +0,0 @@
use super::*;
use crate::stats::Stats;
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::Barrier;
use tokio::time::{Duration, timeout};
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn blackhat_campaign_saturation_quota_race_with_queue_pressure_stays_fail_closed() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!(
"middle-blackhat-held-{}-{idx}",
std::process::id()
)));
}
assert_eq!(
map.len(),
QUOTA_USER_LOCKS_MAX,
"precondition: bounded lock cache must be saturated"
);
let (tx, _rx) = mpsc::channel::<C2MeCommand>(1);
tx.send(C2MeCommand::Close)
.await
.expect("queue prefill should succeed");
let pressure_seq_before = relay_pressure_event_seq();
let pressure_errors = Arc::new(AtomicUsize::new(0));
let mut pressure_workers = Vec::new();
for _ in 0..16 {
let tx = tx.clone();
let pressure_errors = Arc::clone(&pressure_errors);
pressure_workers.push(tokio::spawn(async move {
if enqueue_c2me_command(&tx, C2MeCommand::Close).await.is_err() {
pressure_errors.fetch_add(1, Ordering::Relaxed);
}
}));
}
let stats = Arc::new(Stats::new());
let user = format!("middle-blackhat-quota-race-{}", std::process::id());
let gate = Arc::new(Barrier::new(16));
let mut quota_workers = Vec::new();
for _ in 0..16u8 {
let stats = Arc::clone(&stats);
let user = user.clone();
let gate = Arc::clone(&gate);
quota_workers.push(tokio::spawn(async move {
gate.wait().await;
let user_lock = quota_user_lock(&user);
let _quota_guard = user_lock.lock().await;
if quota_would_be_exceeded_for_user(&stats, &user, Some(1), 1) {
return false;
}
stats.add_user_octets_to(&user, 1);
true
}));
}
let mut ok_count = 0usize;
let mut denied_count = 0usize;
for worker in quota_workers {
let result = timeout(Duration::from_secs(2), worker)
.await
.expect("quota worker must finish")
.expect("quota worker must not panic");
if result {
ok_count += 1;
} else {
denied_count += 1;
}
}
for worker in pressure_workers {
timeout(Duration::from_secs(2), worker)
.await
.expect("pressure worker must finish")
.expect("pressure worker must not panic");
}
assert_eq!(
stats.get_user_total_octets(&user),
1,
"black-hat campaign must not overshoot same-user quota under saturation"
);
assert!(ok_count <= 1, "at most one quota contender may succeed");
assert!(
denied_count >= 15,
"all remaining contenders must be quota-denied"
);
let pressure_seq_after = relay_pressure_event_seq();
assert!(
pressure_seq_after > pressure_seq_before,
"queue pressure leg must trigger pressure accounting"
);
assert!(
pressure_errors.load(Ordering::Relaxed) >= 1,
"at least one pressure worker should fail from persistent backpressure"
);
drop(retained);
}
@@ -1,708 +0,0 @@
use super::*;
use crate::crypto::AesCtr;
use crate::crypto::SecureRandom;
use crate::stats::Stats;
use crate::stream::{BufferPool, PooledBuffer};
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::duplex;
use tokio::sync::mpsc;
use tokio::time::{Duration as TokioDuration, timeout};
fn make_pooled_payload(data: &[u8]) -> PooledBuffer {
let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4));
let mut payload = pool.get();
payload.resize(data.len(), 0);
payload[..data.len()].copy_from_slice(data);
payload
}
#[tokio::test]
async fn write_client_payload_abridged_short_quickack_sets_flag_and_preserves_payload() {
let (mut read_side, write_side) = duplex(4096);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = vec![0xA1, 0xB2, 0xC3, 0xD4, 0x10, 0x20, 0x30, 0x40];
write_client_payload(
&mut writer,
ProtoTag::Abridged,
RPC_FLAG_QUICKACK,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("abridged quickack payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = vec![0u8; 1 + payload.len()];
read_side
.read_exact(&mut encrypted)
.await
.expect("must read serialized abridged frame");
let plaintext = decryptor.decrypt(&encrypted);
assert_eq!(plaintext[0], 0x80 | ((payload.len() / 4) as u8));
assert_eq!(&plaintext[1..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_abridged_extended_header_is_encoded_correctly() {
let (mut read_side, write_side) = duplex(16 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
// Boundary where abridged switches to extended length encoding.
let payload = vec![0x5Au8; 0x7f * 4];
write_client_payload(
&mut writer,
ProtoTag::Abridged,
RPC_FLAG_QUICKACK,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("extended abridged payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = vec![0u8; 4 + payload.len()];
read_side
.read_exact(&mut encrypted)
.await
.expect("must read serialized extended abridged frame");
let plaintext = decryptor.decrypt(&encrypted);
assert_eq!(plaintext[0], 0xff, "0x7f with quickack bit must be set");
assert_eq!(&plaintext[1..4], &[0x7f, 0x00, 0x00]);
assert_eq!(&plaintext[4..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_abridged_misaligned_is_rejected_fail_closed() {
let (_read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let err = write_client_payload(
&mut writer,
ProtoTag::Abridged,
0,
&[1, 2, 3],
&rng,
&mut frame_buf,
)
.await
.expect_err("misaligned abridged payload must be rejected");
let msg = format!("{err}");
assert!(
msg.contains("4-byte aligned"),
"error should explain alignment contract, got: {msg}"
);
}
#[tokio::test]
async fn write_client_payload_secure_misaligned_is_rejected_fail_closed() {
let (_read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let err = write_client_payload(
&mut writer,
ProtoTag::Secure,
0,
&[9, 8, 7, 6, 5],
&rng,
&mut frame_buf,
)
.await
.expect_err("misaligned secure payload must be rejected");
let msg = format!("{err}");
assert!(
msg.contains("Secure payload must be 4-byte aligned"),
"error should be explicit for fail-closed triage, got: {msg}"
);
}
#[tokio::test]
async fn write_client_payload_intermediate_quickack_sets_length_msb() {
let (mut read_side, write_side) = duplex(4096);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = b"hello-middle-relay";
write_client_payload(
&mut writer,
ProtoTag::Intermediate,
RPC_FLAG_QUICKACK,
payload,
&rng,
&mut frame_buf,
)
.await
.expect("intermediate quickack payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = vec![0u8; 4 + payload.len()];
read_side
.read_exact(&mut encrypted)
.await
.expect("must read intermediate frame");
let plaintext = decryptor.decrypt(&encrypted);
let mut len_bytes = [0u8; 4];
len_bytes.copy_from_slice(&plaintext[..4]);
let len_with_flags = u32::from_le_bytes(len_bytes);
assert_ne!(len_with_flags & 0x8000_0000, 0, "quickack bit must be set");
assert_eq!((len_with_flags & 0x7fff_ffff) as usize, payload.len());
assert_eq!(&plaintext[4..], payload);
}
#[tokio::test]
async fn write_client_payload_secure_quickack_prefix_and_padding_bounds_hold() {
let (mut read_side, write_side) = duplex(4096);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = vec![0x33u8; 100]; // 4-byte aligned as required by secure mode.
write_client_payload(
&mut writer,
ProtoTag::Secure,
RPC_FLAG_QUICKACK,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("secure quickack payload should serialize");
writer.flush().await.expect("flush must succeed");
// Secure mode adds 1..=3 bytes of randomized tail padding.
let mut encrypted_header = [0u8; 4];
read_side
.read_exact(&mut encrypted_header)
.await
.expect("must read secure header");
let decrypted_header = decryptor.decrypt(&encrypted_header);
let header: [u8; 4] = decrypted_header
.try_into()
.expect("decrypted secure header must be 4 bytes");
let wire_len_raw = u32::from_le_bytes(header);
assert_ne!(
wire_len_raw & 0x8000_0000,
0,
"secure quickack bit must be set"
);
let wire_len = (wire_len_raw & 0x7fff_ffff) as usize;
assert!(wire_len >= payload.len());
let padding_len = wire_len - payload.len();
assert!(
(1..=3).contains(&padding_len),
"secure writer must add bounded random tail padding, got {padding_len}"
);
let mut encrypted_body = vec![0u8; wire_len];
read_side
.read_exact(&mut encrypted_body)
.await
.expect("must read secure body");
let decrypted_body = decryptor.decrypt(&encrypted_body);
assert_eq!(&decrypted_body[..payload.len()], payload.as_slice());
}
#[tokio::test]
#[ignore = "heavy: allocates >64MiB to validate abridged too-large fail-closed branch"]
async fn write_client_payload_abridged_too_large_is_rejected_fail_closed() {
let (_read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
// Exactly one 4-byte word above the encodable 24-bit abridged length range.
let payload = vec![0x00u8; (1 << 24) * 4];
let err = write_client_payload(
&mut writer,
ProtoTag::Abridged,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect_err("oversized abridged payload must be rejected");
let msg = format!("{err}");
assert!(
msg.contains("Abridged frame too large"),
"error must clearly indicate oversize fail-close path, got: {msg}"
);
}
#[tokio::test]
async fn write_client_ack_intermediate_is_little_endian() {
let (mut read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
write_client_ack(&mut writer, ProtoTag::Intermediate, 0x11_22_33_44)
.await
.expect("ack serialization should succeed");
writer.flush().await.expect("flush must succeed");
let mut encrypted = [0u8; 4];
read_side
.read_exact(&mut encrypted)
.await
.expect("must read ack bytes");
let plain = decryptor.decrypt(&encrypted);
assert_eq!(plain.as_slice(), &0x11_22_33_44u32.to_le_bytes());
}
#[tokio::test]
async fn write_client_ack_abridged_is_big_endian() {
let (mut read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
write_client_ack(&mut writer, ProtoTag::Abridged, 0xDE_AD_BE_EF)
.await
.expect("ack serialization should succeed");
writer.flush().await.expect("flush must succeed");
let mut encrypted = [0u8; 4];
read_side
.read_exact(&mut encrypted)
.await
.expect("must read ack bytes");
let plain = decryptor.decrypt(&encrypted);
assert_eq!(plain.as_slice(), &0xDE_AD_BE_EFu32.to_be_bytes());
}
#[tokio::test]
async fn write_client_payload_abridged_short_boundary_0x7e_is_single_byte_header() {
let (mut read_side, write_side) = duplex(1024 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = vec![0xABu8; 0x7e * 4];
write_client_payload(
&mut writer,
ProtoTag::Abridged,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("boundary payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = vec![0u8; 1 + payload.len()];
read_side.read_exact(&mut encrypted).await.unwrap();
let plain = decryptor.decrypt(&encrypted);
assert_eq!(plain[0], 0x7e);
assert_eq!(&plain[1..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_abridged_extended_without_quickack_has_clean_prefix() {
let (mut read_side, write_side) = duplex(16 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = vec![0x42u8; 0x80 * 4];
write_client_payload(
&mut writer,
ProtoTag::Abridged,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("extended payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = vec![0u8; 4 + payload.len()];
read_side.read_exact(&mut encrypted).await.unwrap();
let plain = decryptor.decrypt(&encrypted);
assert_eq!(plain[0], 0x7f);
assert_eq!(&plain[1..4], &[0x80, 0x00, 0x00]);
assert_eq!(&plain[4..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_intermediate_zero_length_emits_header_only() {
let (mut read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
write_client_payload(
&mut writer,
ProtoTag::Intermediate,
0,
&[],
&rng,
&mut frame_buf,
)
.await
.expect("zero-length intermediate payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = [0u8; 4];
read_side.read_exact(&mut encrypted).await.unwrap();
let plain = decryptor.decrypt(&encrypted);
assert_eq!(plain.as_slice(), &[0, 0, 0, 0]);
}
#[tokio::test]
async fn write_client_payload_intermediate_ignores_unrelated_flags() {
let (mut read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = [7u8; 12];
write_client_payload(
&mut writer,
ProtoTag::Intermediate,
0x4000_0000,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = [0u8; 16];
read_side.read_exact(&mut encrypted).await.unwrap();
let plain = decryptor.decrypt(&encrypted);
let len = u32::from_le_bytes(plain[0..4].try_into().unwrap());
assert_eq!(len, payload.len() as u32, "only quickack bit may affect header");
assert_eq!(&plain[4..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_secure_without_quickack_keeps_msb_clear() {
let (mut read_side, write_side) = duplex(4096);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = [0x1Du8; 64];
write_client_payload(
&mut writer,
ProtoTag::Secure,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted_header = [0u8; 4];
read_side.read_exact(&mut encrypted_header).await.unwrap();
let plain_header = decryptor.decrypt(&encrypted_header);
let h: [u8; 4] = plain_header.as_slice().try_into().unwrap();
let wire_len_raw = u32::from_le_bytes(h);
assert_eq!(wire_len_raw & 0x8000_0000, 0, "quickack bit must stay clear");
}
#[tokio::test]
async fn secure_padding_light_fuzz_distribution_has_multiple_outcomes() {
let (mut read_side, write_side) = duplex(256 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = [0x55u8; 100];
let mut seen = [false; 4];
for _ in 0..96 {
write_client_payload(
&mut writer,
ProtoTag::Secure,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("secure payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted_header = [0u8; 4];
read_side.read_exact(&mut encrypted_header).await.unwrap();
let plain_header = decryptor.decrypt(&encrypted_header);
let h: [u8; 4] = plain_header.as_slice().try_into().unwrap();
let wire_len = (u32::from_le_bytes(h) & 0x7fff_ffff) as usize;
let padding_len = wire_len - payload.len();
assert!((1..=3).contains(&padding_len));
seen[padding_len] = true;
let mut encrypted_body = vec![0u8; wire_len];
read_side.read_exact(&mut encrypted_body).await.unwrap();
let _ = decryptor.decrypt(&encrypted_body);
}
let distinct = (1..=3).filter(|idx| seen[*idx]).count();
assert!(
distinct >= 2,
"padding generator should not collapse to a single outcome under campaign"
);
}
#[tokio::test]
async fn write_client_payload_mixed_proto_sequence_preserves_stream_sync() {
let (mut read_side, write_side) = duplex(128 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let p1 = vec![1u8; 8];
let p2 = vec![2u8; 16];
let p3 = vec![3u8; 20];
write_client_payload(&mut writer, ProtoTag::Abridged, 0, &p1, &rng, &mut frame_buf)
.await
.unwrap();
write_client_payload(
&mut writer,
ProtoTag::Intermediate,
RPC_FLAG_QUICKACK,
&p2,
&rng,
&mut frame_buf,
)
.await
.unwrap();
write_client_payload(&mut writer, ProtoTag::Secure, 0, &p3, &rng, &mut frame_buf)
.await
.unwrap();
writer.flush().await.unwrap();
// Frame 1: abridged short.
let mut e1 = vec![0u8; 1 + p1.len()];
read_side.read_exact(&mut e1).await.unwrap();
let d1 = decryptor.decrypt(&e1);
assert_eq!(d1[0], (p1.len() / 4) as u8);
assert_eq!(&d1[1..], p1.as_slice());
// Frame 2: intermediate with quickack.
let mut e2 = vec![0u8; 4 + p2.len()];
read_side.read_exact(&mut e2).await.unwrap();
let d2 = decryptor.decrypt(&e2);
let l2 = u32::from_le_bytes(d2[0..4].try_into().unwrap());
assert_ne!(l2 & 0x8000_0000, 0);
assert_eq!((l2 & 0x7fff_ffff) as usize, p2.len());
assert_eq!(&d2[4..], p2.as_slice());
// Frame 3: secure with bounded tail.
let mut e3h = [0u8; 4];
read_side.read_exact(&mut e3h).await.unwrap();
let d3h = decryptor.decrypt(&e3h);
let l3 = (u32::from_le_bytes(d3h.as_slice().try_into().unwrap()) & 0x7fff_ffff) as usize;
assert!(l3 >= p3.len());
assert!((1..=3).contains(&(l3 - p3.len())));
let mut e3b = vec![0u8; l3];
read_side.read_exact(&mut e3b).await.unwrap();
let d3b = decryptor.decrypt(&e3b);
assert_eq!(&d3b[..p3.len()], p3.as_slice());
}
#[test]
fn should_yield_sender_boundary_matrix_blackhat() {
assert!(!should_yield_c2me_sender(0, false));
assert!(!should_yield_c2me_sender(0, true));
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true));
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false));
assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true));
assert!(should_yield_c2me_sender(
C2ME_SENDER_FAIRNESS_BUDGET.saturating_add(1024),
true
));
}
#[test]
fn should_yield_sender_light_fuzz_matches_oracle() {
let mut s: u64 = 0xD00D_BAAD_F00D_CAFE;
for _ in 0..5000 {
s ^= s << 7;
s ^= s >> 9;
s ^= s << 8;
let sent = (s as usize) & 0x1fff;
let backlog = (s & 1) != 0;
let expected = backlog && sent >= C2ME_SENDER_FAIRNESS_BUDGET;
assert_eq!(should_yield_c2me_sender(sent, backlog), expected);
}
}
#[test]
fn quota_would_be_exceeded_exact_remaining_one_byte() {
let stats = Stats::new();
let user = "quota-edge";
let quota = 100u64;
stats.add_user_octets_to(user, 99);
assert!(
!quota_would_be_exceeded_for_user(&stats, user, Some(quota), 1),
"exactly remaining budget should be allowed"
);
assert!(
quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2),
"one byte beyond remaining budget must be rejected"
);
}
#[test]
fn quota_would_be_exceeded_saturating_edge_remains_fail_closed() {
let stats = Stats::new();
let user = "quota-saturating-edge";
let quota = u64::MAX - 3;
stats.add_user_octets_to(user, u64::MAX - 4);
assert!(
quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2),
"saturating arithmetic edge must stay fail-closed"
);
}
#[test]
fn quota_exceeded_boundary_is_inclusive() {
let stats = Stats::new();
let user = "quota-inclusive-boundary";
stats.add_user_octets_to(user, 50);
assert!(quota_exceeded_for_user(&stats, user, Some(50)));
assert!(!quota_exceeded_for_user(&stats, user, Some(51)));
}
#[tokio::test]
async fn enqueue_c2me_close_fast_path_succeeds_without_backpressure() {
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(4);
enqueue_c2me_command(&tx, C2MeCommand::Close)
.await
.expect("close should enqueue on fast path");
let recv = timeout(TokioDuration::from_millis(50), rx.recv())
.await
.expect("must receive close command")
.expect("close command should be present");
assert!(matches!(recv, C2MeCommand::Close));
}
#[tokio::test]
async fn enqueue_c2me_data_full_then_drain_preserves_order() {
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
tx.send(C2MeCommand::Data {
payload: make_pooled_payload(&[1]),
flags: 10,
})
.await
.unwrap();
let tx2 = tx.clone();
let producer = tokio::spawn(async move {
enqueue_c2me_command(
&tx2,
C2MeCommand::Data {
payload: make_pooled_payload(&[2, 2]),
flags: 20,
},
)
.await
});
tokio::time::sleep(TokioDuration::from_millis(10)).await;
let first = rx.recv().await.expect("first item should exist");
match first {
C2MeCommand::Data { payload, flags } => {
assert_eq!(payload.as_ref(), &[1]);
assert_eq!(flags, 10);
}
C2MeCommand::Close => panic!("unexpected close as first item"),
}
producer.await.unwrap().expect("producer should complete");
let second = timeout(TokioDuration::from_millis(100), rx.recv())
.await
.unwrap()
.expect("second item should exist");
match second {
C2MeCommand::Data { payload, flags } => {
assert_eq!(payload.as_ref(), &[2, 2]);
assert_eq!(flags, 20);
}
C2MeCommand::Close => panic!("unexpected close as second item"),
}
}
@@ -2,8 +2,8 @@ use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader};
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::sync::{Arc, Mutex, OnceLock};
use tokio::io::AsyncWriteExt;
use tokio::io::duplex;
use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout};
@@ -48,18 +48,6 @@ fn make_idle_policy(soft_ms: u64, hard_ms: u64, grace_ms: u64) -> RelayClientIdl
}
}
fn idle_pressure_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
fn acquire_idle_pressure_test_lock() -> std::sync::MutexGuard<'static, ()> {
match idle_pressure_test_lock().lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
}
}
#[tokio::test]
async fn idle_policy_soft_mark_then_hard_close_increments_reason_counters() {
let (reader, _writer) = duplex(1024);
@@ -372,7 +360,7 @@ async fn stress_many_idle_sessions_fail_closed_without_hang() {
#[test]
fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@@ -402,7 +390,7 @@ fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() {
#[test]
fn pressure_does_not_evict_without_new_pressure_signal() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@@ -421,7 +409,7 @@ fn pressure_does_not_evict_without_new_pressure_signal() {
#[test]
fn stress_pressure_eviction_preserves_fifo_across_many_candidates() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@@ -457,7 +445,7 @@ fn stress_pressure_eviction_preserves_fifo_across_many_candidates() {
#[test]
fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@@ -491,7 +479,7 @@ fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() {
#[test]
fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@@ -524,7 +512,7 @@ fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() {
#[test]
fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@@ -543,7 +531,7 @@ fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() {
#[test]
fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@@ -575,7 +563,7 @@ fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() {
#[test]
fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@@ -601,7 +589,7 @@ fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated(
#[test]
fn blackhat_stale_pressure_must_not_survive_candidate_churn() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@@ -621,7 +609,7 @@ fn blackhat_stale_pressure_must_not_survive_candidate_churn() {
#[test]
fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
{
@@ -646,7 +634,7 @@ fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting(
#[test]
fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
{
@@ -673,7 +661,7 @@ fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() {
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims()
{
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Arc::new(Stats::new());
@@ -738,7 +726,7 @@ async fn integration_race_single_pressure_event_allows_at_most_one_eviction_unde
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalidation_and_budget() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Arc::new(Stats::new());
@@ -0,0 +1,62 @@
use super::*;
use std::panic::{AssertUnwindSafe, catch_unwind};
#[test]
fn blackhat_registry_poison_recovers_with_fail_closed_reset_and_pressure_accounting() {
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let _ = catch_unwind(AssertUnwindSafe(|| {
let registry = relay_idle_candidate_registry();
let mut guard = registry
.lock()
.expect("registry lock must be acquired before poison");
guard.by_conn_id.insert(
999,
RelayIdleCandidateMeta {
mark_order_seq: 1,
mark_pressure_seq: 0,
},
);
guard.ordered.insert((1, 999));
panic!("intentional poison for idle-registry recovery");
}));
// Helper lock must recover from poison, reset stale state, and continue.
assert!(mark_relay_idle_candidate(42));
assert_eq!(oldest_relay_idle_candidate(), Some(42));
let before = relay_pressure_event_seq();
note_relay_pressure_event();
let after = relay_pressure_event_seq();
assert!(
after > before,
"pressure accounting must still advance after poison"
);
clear_relay_idle_pressure_state_for_testing();
}
#[test]
fn clear_state_helper_must_reset_poisoned_registry_for_deterministic_fifo_tests() {
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let _ = catch_unwind(AssertUnwindSafe(|| {
let registry = relay_idle_candidate_registry();
let _guard = registry
.lock()
.expect("registry lock must be acquired before poison");
panic!("intentional poison while lock held");
}));
clear_relay_idle_pressure_state_for_testing();
assert_eq!(oldest_relay_idle_candidate(), None);
assert_eq!(relay_pressure_event_seq(), 0);
assert!(mark_relay_idle_candidate(7));
assert_eq!(oldest_relay_idle_candidate(), Some(7));
clear_relay_idle_pressure_state_for_testing();
}
@@ -1,131 +0,0 @@
use super::*;
use dashmap::DashMap;
use std::sync::Arc;
#[test]
fn saturation_uses_stable_overflow_lock_without_cache_growth() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let prefix = format!("middle-quota-held-{}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
}
assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX);
let user = format!("middle-quota-overflow-{}", std::process::id());
let first = quota_user_lock(&user);
let second = quota_user_lock(&user);
assert!(
Arc::ptr_eq(&first, &second),
"overflow user must get deterministic same lock while cache is saturated"
);
assert_eq!(
map.len(),
QUOTA_USER_LOCKS_MAX,
"overflow path must not grow bounded lock map"
);
assert!(
map.get(&user).is_none(),
"overflow user should stay outside bounded lock map under saturation"
);
drop(retained);
}
#[test]
fn overflow_striping_keeps_different_users_distributed() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let prefix = format!("middle-quota-dist-held-{}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
}
let a = quota_user_lock("middle-overflow-user-a");
let b = quota_user_lock("middle-overflow-user-b");
let c = quota_user_lock("middle-overflow-user-c");
let distinct = [
Arc::as_ptr(&a) as usize,
Arc::as_ptr(&b) as usize,
Arc::as_ptr(&c) as usize,
]
.iter()
.copied()
.collect::<std::collections::HashSet<_>>()
.len();
assert!(
distinct >= 2,
"striped overflow lock set should avoid collapsing all users to one lock"
);
drop(retained);
}
#[test]
fn reclaim_path_caches_new_user_after_stale_entries_drop() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let prefix = format!("middle-quota-reclaim-held-{}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
}
drop(retained);
let user = format!("middle-quota-reclaim-user-{}", std::process::id());
let got = quota_user_lock(&user);
assert!(map.get(&user).is_some());
assert!(
Arc::strong_count(&got) >= 2,
"after reclaim, lock should be held both by caller and map"
);
}
#[test]
fn overflow_path_same_user_is_stable_across_parallel_threads() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!(
"middle-quota-thread-held-{}-{idx}",
std::process::id()
)));
}
let user = format!("middle-quota-overflow-thread-user-{}", std::process::id());
let mut workers = Vec::new();
for _ in 0..32 {
let user = user.clone();
workers.push(std::thread::spawn(move || quota_user_lock(&user)));
}
let first = workers
.remove(0)
.join()
.expect("thread must return lock handle");
for worker in workers {
let got = worker.join().expect("thread must return lock handle");
assert!(
Arc::ptr_eq(&first, &got),
"same overflow user should resolve to one striped lock even under contention"
);
}
drop(retained);
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,372 @@
use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader};
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
use tokio::task::JoinSet;
use tokio::time::{Duration as TokioDuration, sleep};
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where
T: AsyncRead + Unpin + Send + 'static,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
let mut cipher = AesCtr::new(&key, iv);
cipher.encrypt(plaintext)
}
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
RelayForensicsState {
trace_id: 0xB200_0000 + conn_id,
conn_id,
user: format!("tiny-frame-debt-concurrency-user-{conn_id}"),
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
started_at,
bytes_c2me: 0,
bytes_me2c: Arc::new(AtomicU64::new(0)),
desync_all_full: false,
}
}
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
RelayClientIdlePolicy {
enabled: true,
soft_idle: Duration::from_millis(50),
hard_idle: Duration::from_millis(120),
grace_after_downstream_activity: Duration::from_secs(0),
legacy_frame_read_timeout: Duration::from_millis(50),
}
}
async fn read_once(
crypto_reader: &mut CryptoReader<tokio::io::DuplexStream>,
proto: ProtoTag,
forensics: &RelayForensicsState,
frame_counter: &mut u64,
idle_state: &mut RelayClientIdleState,
) -> Result<Option<(PooledBuffer, bool)>> {
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
read_client_payload_with_idle_policy(
crypto_reader,
proto,
1024,
&buffer_pool,
forensics,
frame_counter,
&stats,
&idle_policy,
idle_state,
&last_downstream_activity_ms,
forensics.started_at,
)
.await
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_pure_tiny_floods_all_fail_closed() {
let mut set = JoinSet::new();
for idx in 0..32u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(1000 + idx, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let flood_plaintext = vec![0u8; 1024];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = run_relay_test_step_timeout(
"tiny flood task",
read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
assert!(matches!(result, Err(ProxyError::Proxy(_))));
assert_eq!(frame_counter, 0);
});
}
while let Some(result) = set.join_next().await {
result.expect("parallel tiny flood worker must not panic");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_benign_tiny_burst_then_real_all_pass() {
let mut set = JoinSet::new();
for idx in 0..24u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(2048);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(2000 + idx, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let payload = [idx as u8, 2, 3, 4];
let mut plaintext = Vec::with_capacity(20);
for _ in 0..6 {
plaintext.push(0x00);
}
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
let result = run_relay_test_step_timeout(
"benign tiny burst read",
read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await
.expect("benign payload must parse")
.expect("benign payload must return frame");
assert_eq!(result.0.as_ref(), &payload);
assert_eq!(frame_counter, 1);
});
}
while let Some(result) = set.join_next().await {
result.expect("parallel benign worker must not panic");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_lockstep_alternating_attack_under_jitter_closes() {
let mut set = JoinSet::new();
for idx in 0..12u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(3000 + idx, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(2000);
for n in 0..180u8 {
plaintext.push(0x00);
plaintext.push(0x01);
plaintext.extend_from_slice(&[n, n ^ 0x21, n ^ 0x42, n ^ 0x84]);
}
let encrypted = encrypt_for_reader(&plaintext);
let writer_task = tokio::spawn(async move {
for chunk in encrypted.chunks(17) {
writer.write_all(chunk).await.unwrap();
sleep(TokioDuration::from_millis(1)).await;
}
drop(writer);
});
let mut closed = false;
for _ in 0..220 {
let result = run_relay_test_step_timeout(
"alternating jitter read step",
read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
match result {
Ok(Some((_payload, _))) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected error in alternating jitter case: {other}"),
}
}
writer_task
.await
.expect("writer jitter task must not panic");
assert!(closed, "alternating attack must close before EOF");
});
}
while let Some(result) = set.join_next().await {
result.expect("alternating jitter worker must not panic");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_mixed_population_attackers_close_benign_survive() {
let mut set = JoinSet::new();
for idx in 0..20u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(4000 + idx, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
if idx % 2 == 0 {
let mut plaintext = Vec::with_capacity(1280);
for n in 0..140u8 {
plaintext.push(0x00);
plaintext.push(0x01);
plaintext.extend_from_slice(&[n, n, n, n]);
}
writer
.write_all(&encrypt_for_reader(&plaintext))
.await
.unwrap();
drop(writer);
let mut closed = false;
for _ in 0..200 {
match read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
)
.await
{
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected attacker error: {other}"),
}
}
assert!(closed, "attacker session must fail closed");
} else {
let payload = [1u8, 9, 8, 7];
let mut plaintext = Vec::new();
for _ in 0..4 {
plaintext.push(0x00);
}
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
writer
.write_all(&encrypt_for_reader(&plaintext))
.await
.unwrap();
let got = read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
)
.await
.expect("benign session must parse")
.expect("benign session must return a frame");
assert_eq!(got.0.as_ref(), &payload);
}
});
}
while let Some(result) = set.join_next().await {
result.expect("mixed-population worker must not panic");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn light_fuzz_parallel_patterns_no_hang_or_panic() {
let mut set = JoinSet::new();
for case in 0..40u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(5000 + case, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut seed = 0x9E37_79B9u64 ^ (case << 8);
let mut plaintext = Vec::with_capacity(2048);
for _ in 0..256 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let is_tiny = (seed & 1) == 0;
if is_tiny {
plaintext.push(0x00);
} else {
plaintext.push(0x01);
plaintext.extend_from_slice(&[(seed >> 8) as u8, 2, 3, 4]);
}
}
writer
.write_all(&encrypt_for_reader(&plaintext))
.await
.unwrap();
drop(writer);
for _ in 0..320 {
let step = run_relay_test_step_timeout(
"fuzz case read step",
read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
match step {
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => break,
Ok(None) => break,
Err(other) => panic!("unexpected fuzz case error: {other}"),
}
}
});
}
while let Some(result) = set.join_next().await {
result.expect("fuzz worker must not panic");
}
}
@@ -0,0 +1,435 @@
use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, PooledBuffer};
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
use tokio::time::{Duration as TokioDuration, sleep};
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where
T: AsyncRead + Unpin + Send + 'static,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
let mut cipher = AesCtr::new(&key, iv);
cipher.encrypt(plaintext)
}
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
RelayForensicsState {
trace_id: 0xB300_0000 + conn_id,
conn_id,
user: format!("tiny-frame-debt-proto-chunk-user-{conn_id}"),
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
started_at,
bytes_c2me: 0,
bytes_me2c: Arc::new(AtomicU64::new(0)),
desync_all_full: false,
}
}
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
RelayClientIdlePolicy {
enabled: true,
soft_idle: Duration::from_millis(50),
hard_idle: Duration::from_millis(120),
grace_after_downstream_activity: Duration::from_secs(0),
legacy_frame_read_timeout: Duration::from_millis(50),
}
}
fn append_tiny_frame(plaintext: &mut Vec<u8>, proto: ProtoTag) {
match proto {
ProtoTag::Abridged => plaintext.push(0x00),
ProtoTag::Intermediate | ProtoTag::Secure => {
plaintext.extend_from_slice(&0u32.to_le_bytes())
}
}
}
fn append_real_frame(plaintext: &mut Vec<u8>, proto: ProtoTag, payload: [u8; 4]) {
match proto {
ProtoTag::Abridged => {
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
}
ProtoTag::Intermediate | ProtoTag::Secure => {
plaintext.extend_from_slice(&4u32.to_le_bytes());
plaintext.extend_from_slice(&payload);
}
}
}
async fn write_chunked_with_jitter(
writer: &mut tokio::io::DuplexStream,
bytes: &[u8],
mut seed: u64,
) {
let mut offset = 0usize;
while offset < bytes.len() {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let chunk_len = 1 + ((seed as usize) & 0x1f);
let end = (offset + chunk_len).min(bytes.len());
writer.write_all(&bytes[offset..end]).await.unwrap();
let delay_ms = ((seed >> 16) % 3) as u64;
if delay_ms > 0 {
sleep(TokioDuration::from_millis(delay_ms)).await;
}
offset = end;
}
}
async fn read_once_with_state(
crypto_reader: &mut CryptoReader<tokio::io::DuplexStream>,
proto: ProtoTag,
forensics: &RelayForensicsState,
frame_counter: &mut u64,
idle_state: &mut RelayClientIdleState,
) -> Result<Option<(PooledBuffer, bool)>> {
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
read_client_payload_with_idle_policy(
crypto_reader,
proto,
1024,
&buffer_pool,
forensics,
frame_counter,
&stats,
&idle_policy,
idle_state,
&last_downstream_activity_ms,
forensics.started_at,
)
.await
}
fn is_fail_closed_outcome(result: &Result<Option<(PooledBuffer, bool)>>) -> bool {
matches!(result, Err(ProxyError::Proxy(_)))
|| matches!(result, Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut)
}
#[tokio::test]
async fn intermediate_chunked_zero_flood_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6101, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(4 * 256);
for _ in 0..256 {
append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
}
let encrypted = encrypt_for_reader(&plaintext);
write_chunked_with_jitter(&mut writer, &encrypted, 0x1111_2222).await;
drop(writer);
let result = run_relay_test_step_timeout(
"intermediate flood read",
read_once_with_state(
&mut crypto_reader,
ProtoTag::Intermediate,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
assert!(
is_fail_closed_outcome(&result),
"zero-length flood must fail closed via debt guard or idle timeout"
);
assert_eq!(frame_counter, 0);
}
#[tokio::test]
async fn secure_chunked_zero_flood_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6102, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(4 * 256);
for _ in 0..256 {
append_tiny_frame(&mut plaintext, ProtoTag::Secure);
}
let encrypted = encrypt_for_reader(&plaintext);
write_chunked_with_jitter(&mut writer, &encrypted, 0x3333_4444).await;
drop(writer);
let result = run_relay_test_step_timeout(
"secure flood read",
read_once_with_state(
&mut crypto_reader,
ProtoTag::Secure,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
assert!(
is_fail_closed_outcome(&result),
"secure zero-length flood must fail closed via debt guard or idle timeout"
);
assert_eq!(frame_counter, 0);
}
#[tokio::test]
async fn intermediate_chunked_alternating_attack_closes_before_eof() {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6103, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(8 * 200);
for n in 0..180u8 {
append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
append_real_frame(
&mut plaintext,
ProtoTag::Intermediate,
[n, n ^ 1, n ^ 2, n ^ 3],
);
}
let encrypted = encrypt_for_reader(&plaintext);
let writer_task = tokio::spawn(async move {
write_chunked_with_jitter(&mut writer, &encrypted, 0x5555_6666).await;
drop(writer);
});
let mut closed = false;
for _ in 0..240 {
let step = run_relay_test_step_timeout(
"intermediate alternating read step",
read_once_with_state(
&mut crypto_reader,
ProtoTag::Intermediate,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
match step {
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected intermediate alternating error: {other}"),
}
}
writer_task
.await
.expect("intermediate writer task must not panic");
assert!(closed, "intermediate alternating attack must fail closed");
}
#[tokio::test]
async fn secure_chunked_alternating_attack_closes_before_eof() {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6104, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(8 * 200);
for n in 0..180u8 {
append_tiny_frame(&mut plaintext, ProtoTag::Secure);
append_real_frame(&mut plaintext, ProtoTag::Secure, [n, n ^ 7, n ^ 11, n ^ 19]);
}
let encrypted = encrypt_for_reader(&plaintext);
let writer_task = tokio::spawn(async move {
write_chunked_with_jitter(&mut writer, &encrypted, 0x7777_8888).await;
drop(writer);
});
let mut closed = false;
for _ in 0..240 {
let step = run_relay_test_step_timeout(
"secure alternating read step",
read_once_with_state(
&mut crypto_reader,
ProtoTag::Secure,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
match step {
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected secure alternating error: {other}"),
}
}
writer_task
.await
.expect("secure writer task must not panic");
assert!(closed, "secure alternating attack must fail closed");
}
#[tokio::test]
async fn intermediate_chunked_safe_small_burst_still_returns_real_frame() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6105, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let payload = [9u8, 8, 7, 6];
let mut plaintext = Vec::new();
for _ in 0..7 {
append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
}
append_real_frame(&mut plaintext, ProtoTag::Intermediate, payload);
let encrypted = encrypt_for_reader(&plaintext);
write_chunked_with_jitter(&mut writer, &encrypted, 0xAAAA_BBBB).await;
let result = read_once_with_state(
&mut crypto_reader,
ProtoTag::Intermediate,
&forensics,
&mut frame_counter,
&mut idle_state,
)
.await
.expect("intermediate safe burst should parse")
.expect("intermediate safe burst should return a frame");
assert_eq!(result.0.as_ref(), &payload);
assert_eq!(frame_counter, 1);
}
#[tokio::test]
async fn secure_chunked_safe_small_burst_still_returns_real_frame() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6106, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let payload = [3u8, 1, 4, 1];
let mut plaintext = Vec::new();
for _ in 0..7 {
append_tiny_frame(&mut plaintext, ProtoTag::Secure);
}
append_real_frame(&mut plaintext, ProtoTag::Secure, payload);
let encrypted = encrypt_for_reader(&plaintext);
write_chunked_with_jitter(&mut writer, &encrypted, 0xCCCC_DDDD).await;
let result = read_once_with_state(
&mut crypto_reader,
ProtoTag::Secure,
&forensics,
&mut frame_counter,
&mut idle_state,
)
.await
.expect("secure safe burst should parse")
.expect("secure safe burst should return a frame");
assert_eq!(result.0.as_ref(), &payload);
assert_eq!(frame_counter, 1);
}
#[tokio::test]
async fn light_fuzz_proto_chunking_outcomes_are_bounded() {
let mut seed = 0xDEAD_BEEF_2026_0322u64;
for case in 0..48u64 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let proto = if (seed & 1) == 0 {
ProtoTag::Intermediate
} else {
ProtoTag::Secure
};
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6200 + case, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut stream = Vec::new();
let mut local_seed = seed ^ case;
for _ in 0..220 {
local_seed ^= local_seed << 7;
local_seed ^= local_seed >> 9;
local_seed ^= local_seed << 8;
if (local_seed & 1) == 0 {
append_tiny_frame(&mut stream, proto);
} else {
let b = (local_seed >> 8) as u8;
append_real_frame(&mut stream, proto, [b, b ^ 0x12, b ^ 0x24, b ^ 0x48]);
}
}
let encrypted = encrypt_for_reader(&stream);
write_chunked_with_jitter(&mut writer, &encrypted, seed ^ 0x1234_5678).await;
drop(writer);
for _ in 0..260 {
let step = run_relay_test_step_timeout(
"fuzz proto read step",
read_once_with_state(
&mut crypto_reader,
proto,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
match step {
Ok(Some((_payload, _))) => {}
Err(ProxyError::Proxy(_)) => break,
Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut => break,
Ok(None) => break,
Err(other) => panic!("unexpected proto chunking fuzz error: {other}"),
}
}
}
}
@@ -0,0 +1,804 @@
use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader};
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where
T: AsyncRead + Unpin + Send + 'static,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
let mut cipher = AesCtr::new(&key, iv);
cipher.encrypt(plaintext)
}
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
RelayForensicsState {
trace_id: 0xB100_0000 + conn_id,
conn_id,
user: format!("tiny-frame-debt-user-{conn_id}"),
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
started_at,
bytes_c2me: 0,
bytes_me2c: Arc::new(AtomicU64::new(0)),
desync_all_full: false,
}
}
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
RelayClientIdlePolicy {
enabled: true,
soft_idle: Duration::from_millis(50),
hard_idle: Duration::from_millis(120),
grace_after_downstream_activity: Duration::from_secs(0),
legacy_frame_read_timeout: Duration::from_millis(50),
}
}
async fn read_bounded(
crypto_reader: &mut CryptoReader<tokio::io::DuplexStream>,
proto_tag: ProtoTag,
buffer_pool: &Arc<BufferPool>,
forensics: &RelayForensicsState,
frame_counter: &mut u64,
stats: &Stats,
idle_policy: &RelayClientIdlePolicy,
idle_state: &mut RelayClientIdleState,
last_downstream_activity_ms: &AtomicU64,
session_started_at: Instant,
) -> Result<Option<(PooledBuffer, bool)>> {
run_relay_test_step_timeout(
"tiny-frame debt read step",
read_client_payload_with_idle_policy(
crypto_reader,
proto_tag,
1024,
buffer_pool,
forensics,
frame_counter,
stats,
idle_policy,
idle_state,
last_downstream_activity_ms,
session_started_at,
),
)
.await
}
fn simulate_tiny_debt_pattern(pattern: &[bool], max_steps: usize) -> (Option<usize>, u32, usize) {
let mut debt = 0u32;
let mut reals = 0usize;
for (idx, is_tiny) in pattern.iter().copied().take(max_steps).enumerate() {
if is_tiny {
debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY);
if debt >= TINY_FRAME_DEBT_LIMIT {
return (Some(idx + 1), debt, reals);
}
} else {
reals = reals.saturating_add(1);
debt = debt.saturating_sub(1);
}
}
(None, debt, reals)
}
#[test]
fn tiny_frame_debt_constants_match_security_budget_expectations() {
assert_eq!(TINY_FRAME_DEBT_PER_TINY, 8);
assert_eq!(TINY_FRAME_DEBT_LIMIT, 512);
}
#[test]
fn relay_client_idle_state_initial_debt_is_zero() {
let state = RelayClientIdleState::new(Instant::now());
assert_eq!(state.tiny_frame_debt, 0);
}
#[test]
fn on_client_frame_does_not_reset_tiny_frame_debt() {
let now = Instant::now();
let mut state = RelayClientIdleState::new(now);
state.tiny_frame_debt = 77;
state.on_client_frame(now);
assert_eq!(state.tiny_frame_debt, 77);
}
#[test]
fn tiny_frame_debt_increment_is_saturating() {
let mut debt = u32::MAX - 1;
debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY);
assert_eq!(debt, u32::MAX);
}
#[test]
fn tiny_frame_debt_decrement_is_saturating() {
let mut debt = 0u32;
debt = debt.saturating_sub(1);
assert_eq!(debt, 0);
}
#[test]
fn consecutive_tiny_frames_close_exactly_at_threshold() {
let max_tiny_without_close = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize;
let pattern = vec![true; max_tiny_without_close];
let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert_eq!(closed_at, Some(max_tiny_without_close));
}
#[test]
fn one_less_than_threshold_tiny_frames_do_not_close() {
let tiny_count = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize - 1;
let pattern = vec![true; tiny_count];
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert_eq!(closed_at, None);
assert!(debt < TINY_FRAME_DEBT_LIMIT);
}
#[test]
fn alternating_one_to_one_closes_with_bounded_real_frame_count() {
let mut pattern = Vec::with_capacity(512);
for _ in 0..256 {
pattern.push(true);
pattern.push(false);
}
let (closed_at, _, reals) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert!(closed_at.is_some());
assert!(
reals <= 80,
"expected bounded real frames before close, got {reals}"
);
}
#[test]
fn alternating_one_to_eight_is_stable_for_long_runs() {
let mut pattern = Vec::with_capacity(9 * 5000);
for _ in 0..5000 {
pattern.push(true);
for _ in 0..8 {
pattern.push(false);
}
}
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert_eq!(closed_at, None);
assert!(debt <= TINY_FRAME_DEBT_PER_TINY);
}
#[test]
fn alternating_one_to_seven_eventually_closes() {
let mut pattern = Vec::with_capacity(8 * 2000);
for _ in 0..2000 {
pattern.push(true);
for _ in 0..7 {
pattern.push(false);
}
}
let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert!(
closed_at.is_some(),
"1:7 tiny-to-real must eventually close"
);
}
#[test]
fn two_tiny_one_real_closes_faster_than_one_to_one() {
let mut one_to_one = Vec::with_capacity(512);
for _ in 0..256 {
one_to_one.push(true);
one_to_one.push(false);
}
let mut two_to_one = Vec::with_capacity(768);
for _ in 0..256 {
two_to_one.push(true);
two_to_one.push(true);
two_to_one.push(false);
}
let (a_close, _, _) = simulate_tiny_debt_pattern(&one_to_one, one_to_one.len());
let (b_close, _, _) = simulate_tiny_debt_pattern(&two_to_one, two_to_one.len());
assert!(a_close.is_some() && b_close.is_some());
assert!(b_close.unwrap_or(usize::MAX) < a_close.unwrap_or(0));
}
#[test]
fn burst_then_drain_can_recover_without_close() {
let burst_tiny = ((TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) / 2) as usize;
let mut pattern = Vec::with_capacity(burst_tiny + 600);
for _ in 0..burst_tiny {
pattern.push(true);
}
pattern.extend(std::iter::repeat_n(false, 600));
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert_eq!(closed_at, None);
assert_eq!(debt, 0);
}
#[test]
fn light_fuzz_tiny_frame_debt_model_stays_within_bounds() {
let mut seed = 0xA5A5_91C3_2026_0322u64;
for _case in 0..128 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let len = 512 + ((seed as usize) & 0x3ff);
let mut pattern = Vec::with_capacity(len);
let mut local_seed = seed;
for _ in 0..len {
local_seed ^= local_seed << 7;
local_seed ^= local_seed >> 9;
local_seed ^= local_seed << 8;
pattern.push((local_seed & 1) == 0);
}
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
if closed_at.is_none() {
assert!(debt < TINY_FRAME_DEBT_LIMIT);
}
assert!(debt <= u32::MAX);
}
}
#[test]
fn stress_many_independent_simulations_keep_isolated_debt_state() {
for idx in 0..2048usize {
let mut pattern = Vec::with_capacity(64);
for j in 0..64usize {
pattern.push(((idx ^ j) & 3) == 0);
}
let (_closed_at, debt, _reals) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert!(debt <= TINY_FRAME_DEBT_LIMIT.saturating_add(TINY_FRAME_DEBT_PER_TINY));
}
}
#[tokio::test]
async fn idle_policy_enabled_intermediate_zero_length_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(11, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let flood_plaintext = vec![0u8; 4 * 256];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Intermediate,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(matches!(result, Err(ProxyError::Proxy(_))));
}
#[tokio::test]
async fn idle_policy_enabled_secure_zero_length_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(12, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let flood_plaintext = vec![0u8; 4 * 256];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Secure,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(matches!(result, Err(ProxyError::Proxy(_))));
}
#[tokio::test]
async fn intermediate_alternating_zero_and_real_eventually_closes() {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(13, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let mut plaintext = Vec::with_capacity(3000);
for idx in 0..160u8 {
plaintext.extend_from_slice(&0u32.to_le_bytes());
plaintext.extend_from_slice(&4u32.to_le_bytes());
plaintext.extend_from_slice(&[idx, idx ^ 0x11, idx ^ 0x22, idx ^ 0x33]);
}
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
drop(writer);
let mut closed = false;
for _ in 0..220 {
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Intermediate,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
match result {
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected error while probing alternating close: {other}"),
}
}
assert!(closed, "intermediate alternating attack must fail closed");
}
#[tokio::test]
async fn small_tiny_burst_followed_by_real_frame_does_not_spuriously_close() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(14, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let mut plaintext = Vec::with_capacity(64);
for _ in 0..8 {
plaintext.push(0x00);
}
plaintext.push(0x01);
plaintext.extend_from_slice(&[1, 2, 3, 4]);
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
let first = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
match first {
Ok(Some((payload, _))) => assert_eq!(payload.as_ref(), &[1, 2, 3, 4]),
Err(e) => panic!("unexpected close after small tiny burst: {e}"),
Ok(None) => panic!("unexpected EOF before real frame"),
}
}
#[tokio::test]
async fn idle_policy_enabled_zero_length_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(1, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let flood_plaintext = vec![0u8; 1024];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer
.write_all(&flood_encrypted)
.await
.expect("zero-length flood bytes must be writable");
drop(writer);
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(
matches!(result, Err(ProxyError::Proxy(_))),
"idle policy enabled must fail closed for pure zero-length flood"
);
}
#[tokio::test]
async fn idle_policy_enabled_alternating_tiny_real_eventually_closes() {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(2, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let mut plaintext = Vec::with_capacity(256 * 6);
for idx in 0..=255u8 {
plaintext.push(0x00);
plaintext.push(0x01);
plaintext.extend_from_slice(&[idx, idx ^ 0x55, idx ^ 0xAA, 0x11]);
}
let encrypted = encrypt_for_reader(&plaintext);
writer
.write_all(&encrypted)
.await
.expect("alternating flood bytes must be writable");
drop(writer);
let mut saw_proxy_close = false;
for _ in 0..300 {
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
match result {
Ok(Some((_payload, _quickack))) => {}
Err(ProxyError::Proxy(_)) => {
saw_proxy_close = true;
break;
}
Err(ProxyError::Io(e)) => panic!("unexpected IO error before close: {e}"),
Ok(None) => panic!("unexpected EOF before debt-based closure"),
Err(other) => panic!("unexpected error before close: {other}"),
}
}
assert!(
saw_proxy_close,
"alternating tiny/real sequence must eventually fail closed"
);
}
#[tokio::test]
async fn enabled_idle_policy_valid_nonzero_frame_still_passes() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(3, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let payload = [7u8, 8, 9, 10];
let mut plaintext = Vec::with_capacity(1 + payload.len());
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
let encrypted = encrypt_for_reader(&plaintext);
writer
.write_all(&encrypted)
.await
.expect("nonzero frame must be writable");
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await
.expect("valid frame should decode")
.expect("valid frame should return payload");
assert_eq!(result.0.as_ref(), &payload);
assert!(!result.1);
assert_eq!(frame_counter, 1);
}
#[tokio::test]
async fn abridged_quickack_tiny_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(21, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let flood_plaintext = vec![0x80u8; 256];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(
matches!(result, Err(ProxyError::Proxy(_))),
"quickack-marked zero-length flood must fail closed"
);
}
#[tokio::test]
async fn abridged_extended_zero_len_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(22, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let mut flood_plaintext = Vec::with_capacity(4 * 256);
for _ in 0..256 {
flood_plaintext.extend_from_slice(&[0x7f, 0x00, 0x00, 0x00]);
}
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(
matches!(result, Err(ProxyError::Proxy(_))),
"extended zero-length abridged flood must fail closed"
);
}
#[tokio::test]
async fn one_to_eight_abridged_wire_pattern_survives_without_false_positive_close() {
let mut plaintext = Vec::with_capacity(9 * 300);
for idx in 0..300usize {
plaintext.push(0x00);
for _ in 0..8 {
let b = idx as u8;
plaintext.push(0x01);
plaintext.extend_from_slice(&[b, b ^ 0x11, b ^ 0x22, b ^ 0x33]);
}
}
// Keep the test single-task and deterministic: make duplex capacity larger than the
// generated ciphertext so write_all cannot block waiting for a concurrent reader.
let duplex_capacity = plaintext.len().saturating_add(1024);
let (reader, mut writer) = duplex(duplex_capacity);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(23, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
drop(writer);
let mut closed = false;
for _ in 0..3000 {
match read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await
{
Ok(Some(_)) => {}
Ok(None) => break,
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Err(other) => panic!("unexpected error in 1:8 wire test: {other}"),
}
}
assert!(
!closed,
"wire-level 1:8 tiny-to-real pattern should not trigger debt close"
);
}
#[tokio::test]
async fn deterministic_light_fuzz_abridged_wire_behavior_matches_model() {
let mut seed = 0xD1CE_BAAD_2026_0322u64;
for case_idx in 0..32u64 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let events = 300 + ((seed as usize) & 0xff);
let mut pattern = Vec::with_capacity(events);
let mut local = seed;
for _ in 0..events {
local ^= local << 7;
local ^= local >> 9;
local ^= local << 8;
pattern.push((local & 0x03) == 0);
}
let mut plaintext = Vec::with_capacity(events * 6);
for (idx, tiny) in pattern.iter().copied().enumerate() {
if tiny {
plaintext.push(0x00);
} else {
let b = (idx as u8) ^ (case_idx as u8);
plaintext.push(0x01);
plaintext.extend_from_slice(&[b, b ^ 0x1F, b ^ 0x7A, b ^ 0xC3]);
}
}
let (reader, mut writer) = duplex(16 * 1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(500 + case_idx, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
writer
.write_all(&encrypt_for_reader(&plaintext))
.await
.unwrap();
drop(writer);
let (expected_close, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
let mut observed_close = false;
for _ in 0..(events + 8) {
match read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await
{
Ok(Some(_)) => {}
Ok(None) => break,
Err(ProxyError::Proxy(_)) => {
observed_close = true;
break;
}
Err(other) => panic!("unexpected fuzz error: {other}"),
}
}
assert_eq!(
observed_close,
expected_close.is_some(),
"wire parser behavior must match debt model for case {case_idx}"
);
}
}
@@ -0,0 +1,121 @@
use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader};
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where
T: AsyncRead + Unpin + Send + 'static,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
let mut cipher = AesCtr::new(&key, iv);
cipher.encrypt(plaintext)
}
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
RelayForensicsState {
trace_id: 0xB000_0000 + conn_id,
conn_id,
user: format!("zero-len-test-user-{conn_id}"),
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
started_at,
bytes_c2me: 0,
bytes_me2c: Arc::new(AtomicU64::new(0)),
desync_all_full: false,
}
}
#[tokio::test]
async fn adversarial_legacy_zero_length_flood_is_fail_closed() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(1, session_started_at);
let mut frame_counter = 0u64;
let flood_plaintext = vec![0u8; 128];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer
.write_all(&flood_encrypted)
.await
.expect("zero-length flood bytes must be writable");
drop(writer);
let result = read_client_payload_legacy(
&mut crypto_reader,
ProtoTag::Abridged,
1024,
Duration::from_millis(30),
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
)
.await;
match result {
Err(ProxyError::Proxy(msg)) => {
assert!(
msg.contains("Excessive zero-length"),
"legacy mode must close flood with explicit zero-length reason, got: {msg}"
);
}
Ok(None) => panic!("legacy zero-length flood must not be accepted as EOF"),
Ok(Some(_)) => panic!("legacy zero-length flood must not produce a data frame"),
Err(err) => panic!("legacy zero-length flood must be a Proxy error, got: {err}"),
}
}
#[tokio::test]
async fn business_abridged_nonzero_frame_still_passes() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(2, session_started_at);
let mut frame_counter = 0u64;
let payload = [1u8, 2, 3, 4];
let mut plaintext = Vec::with_capacity(1 + payload.len());
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
let encrypted = encrypt_for_reader(&plaintext);
writer
.write_all(&encrypted)
.await
.expect("nonzero abridged frame must be writable");
let result = read_client_payload_legacy(
&mut crypto_reader,
ProtoTag::Abridged,
1024,
Duration::from_millis(30),
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
)
.await
.expect("valid abridged frame should decode")
.expect("valid abridged frame should return payload");
assert_eq!(result.0.as_ref(), &payload);
assert!(!result.1, "quickack flag must remain false");
assert_eq!(frame_counter, 1);
}
+22 -5
View File
@@ -78,7 +78,8 @@ async fn relay_hol_blocking_prevention_regression() {
async fn relay_quota_mid_session_cutoff() {
let stats = Arc::new(Stats::new());
let user = "quota-mid-user";
let quota = 5000;
let quota = 5000u64;
let c2s_buf_size = 1024usize;
let (client_peer, relay_client) = duplex(8192);
let (relay_server, server_peer) = duplex(8192);
@@ -93,7 +94,7 @@ async fn relay_quota_mid_session_cutoff() {
client_writer,
server_reader,
server_writer,
1024,
c2s_buf_size,
1024,
user,
Arc::clone(&stats),
@@ -120,9 +121,25 @@ async fn relay_quota_mid_session_cutoff() {
other => panic!("Expected DataQuotaExceeded error, got: {:?}", other),
}
let mut small_buf = [0u8; 1];
let n = sp_reader.read(&mut small_buf).await.unwrap();
assert_eq!(n, 0, "Server must see EOF after quota reached");
let mut overshoot_bytes = 0usize;
let mut buf = [0u8; 256];
loop {
match timeout(Duration::from_millis(20), sp_reader.read(&mut buf)).await {
Ok(Ok(0)) => break,
Ok(Ok(n)) => overshoot_bytes = overshoot_bytes.saturating_add(n),
Ok(Err(e)) => panic!("server read must not fail after relay cutoff: {e}"),
Err(_) => break,
}
}
assert!(
overshoot_bytes <= c2s_buf_size,
"post-write cutoff may leak at most one C->S chunk after boundary, got {overshoot_bytes}"
);
assert!(
stats.get_user_quota_used(user) <= quota.saturating_add(c2s_buf_size as u64),
"accounted quota must remain bounded by one in-flight chunk overshoot"
);
}
#[tokio::test]
@@ -0,0 +1,243 @@
use super::*;
use std::collections::VecDeque;
use std::io;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::time::Instant;
struct ScriptedWriter {
scripted_writes: Arc<Mutex<VecDeque<usize>>>,
write_calls: Arc<AtomicUsize>,
}
impl ScriptedWriter {
fn new(script: &[usize], write_calls: Arc<AtomicUsize>) -> Self {
Self {
scripted_writes: Arc::new(Mutex::new(script.iter().copied().collect())),
write_calls,
}
}
}
impl AsyncWrite for ScriptedWriter {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
this.write_calls.fetch_add(1, Ordering::Relaxed);
let planned = this
.scripted_writes
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.pop_front()
.unwrap_or(buf.len());
Poll::Ready(Ok(planned.min(buf.len())))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
fn make_stats_io_with_script(
user: &str,
quota_limit: u64,
precharged_quota: u64,
script: &[usize],
) -> (
StatsIo<ScriptedWriter>,
Arc<Stats>,
Arc<AtomicUsize>,
Arc<AtomicBool>,
) {
let stats = Arc::new(Stats::new());
if precharged_quota > 0 {
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), precharged_quota);
}
let write_calls = Arc::new(AtomicUsize::new(0));
let quota_exceeded = Arc::new(AtomicBool::new(false));
let io = StatsIo::new(
ScriptedWriter::new(script, write_calls.clone()),
Arc::new(SharedCounters::new()),
stats.clone(),
user.to_string(),
Some(quota_limit),
quota_exceeded.clone(),
Instant::now(),
);
(io, stats, write_calls, quota_exceeded)
}
#[tokio::test]
async fn direct_partial_write_charges_only_committed_bytes_without_double_charge() {
let user = "direct-partial-charge-user";
let (mut io, stats, write_calls, quota_exceeded) =
make_stats_io_with_script(user, 1_048_576, 0, &[8 * 1024, 8 * 1024, 48 * 1024]);
let payload = vec![0xAB; 64 * 1024];
let n1 = io
.write(&payload)
.await
.expect("first partial write must succeed");
let n2 = io
.write(&payload)
.await
.expect("second partial write must succeed");
let n3 = io.write(&payload).await.expect("tail write must succeed");
assert_eq!(n1, 8 * 1024);
assert_eq!(n2, 8 * 1024);
assert_eq!(n3, 48 * 1024);
assert_eq!(write_calls.load(Ordering::Relaxed), 3);
assert_eq!(
stats.get_user_quota_used(user),
(n1 + n2 + n3) as u64,
"quota accounting must follow committed bytes only"
);
assert_eq!(
stats.get_user_total_octets(user),
(n1 + n2 + n3) as u64,
"telemetry octets should match committed bytes on successful writes"
);
assert!(
!quota_exceeded.load(Ordering::Acquire),
"quota flag should stay false under large remaining budget"
);
}
#[tokio::test]
async fn direct_hybrid_branch_selection_matches_contract() {
let near_limit = 256 * 1024u64;
let near_remaining = 32 * 1024u64;
let (mut near_io, _stats, _calls, _flag) = make_stats_io_with_script(
"direct-near-limit-hard-check-user",
near_limit,
near_limit - near_remaining,
&[4 * 1024],
);
let near_payload = vec![0x11; 4 * 1024];
let near_written = near_io
.write(&near_payload)
.await
.expect("near-limit write must succeed");
assert_eq!(near_written, 4 * 1024);
assert_eq!(
near_io.quota_bytes_since_check, 0,
"near-limit branch must go through immediate hard check"
);
let (mut far_small_io, _stats, _calls, _flag) =
make_stats_io_with_script("direct-far-small-amortized-user", 1_048_576, 0, &[4 * 1024]);
let far_small_payload = vec![0x22; 4 * 1024];
let far_small_written = far_small_io
.write(&far_small_payload)
.await
.expect("small far-from-limit write must succeed");
assert_eq!(far_small_written, 4 * 1024);
assert_eq!(
far_small_io.quota_bytes_since_check,
4 * 1024,
"small far-from-limit write must go through amortized path"
);
let (mut far_large_io, _stats, _calls, _flag) = make_stats_io_with_script(
"direct-far-large-hard-check-user",
1_048_576,
0,
&[32 * 1024],
);
let far_large_payload = vec![0x33; 32 * 1024];
let far_large_written = far_large_io
.write(&far_large_payload)
.await
.expect("large write must succeed");
assert_eq!(far_large_written, 32 * 1024);
assert_eq!(
far_large_io.quota_bytes_since_check, 0,
"large write must force immediate hard check even far from limit"
);
}
#[tokio::test]
async fn remaining_before_zero_rejects_without_calling_inner_writer() {
let user = "direct-zero-remaining-user";
let limit = 8u64;
let (mut io, stats, write_calls, quota_exceeded) =
make_stats_io_with_script(user, limit, limit, &[1]);
let err = io
.write(&[0x44])
.await
.expect_err("write must fail when remaining quota is zero");
assert!(
is_quota_io_error(&err),
"zero-remaining gate must return typed quota I/O error"
);
assert_eq!(
write_calls.load(Ordering::Relaxed),
0,
"inner poll_write must not be called when remaining quota is zero"
);
assert!(
quota_exceeded.load(Ordering::Acquire),
"zero-remaining gate must set exceeded flag"
);
assert_eq!(stats.get_user_quota_used(user), limit);
}
#[tokio::test]
async fn exceeded_flag_blocks_following_poll_before_inner_write() {
let user = "direct-exceeded-visibility-user";
let (mut io, stats, write_calls, quota_exceeded) =
make_stats_io_with_script(user, 1, 0, &[1, 1]);
let first = io
.write(&[0x55])
.await
.expect("first byte should consume remaining quota");
assert_eq!(first, 1);
assert!(
quota_exceeded.load(Ordering::Acquire),
"hard check should store quota_exceeded after boundary hit"
);
let second = io
.write(&[0x66])
.await
.expect_err("next write must be rejected by early exceeded gate");
assert!(
is_quota_io_error(&second),
"following write must fail with typed quota error"
);
assert_eq!(
write_calls.load(Ordering::Relaxed),
1,
"second write must be cut before touching inner writer"
);
assert_eq!(stats.get_user_quota_used(user), 1);
}
#[test]
fn adaptive_interval_clamp_matches_contract() {
assert_eq!(quota_adaptive_interval_bytes(0), 4 * 1024);
assert_eq!(quota_adaptive_interval_bytes(2 * 1024), 4 * 1024);
assert_eq!(quota_adaptive_interval_bytes(32 * 1024), 16 * 1024);
assert_eq!(quota_adaptive_interval_bytes(256 * 1024), 64 * 1024);
assert!(should_immediate_quota_check(32 * 1024, 4 * 1024));
assert!(should_immediate_quota_check(1_048_576, 32 * 1024));
assert!(!should_immediate_quota_check(1_048_576, 4 * 1024));
}
@@ -29,6 +29,11 @@ async fn read_available<R: AsyncRead + Unpin>(reader: &mut R, budget: Duration)
total
}
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
}
#[tokio::test]
async fn integration_full_duplex_exact_budget_then_hard_cutoff() {
let stats = Arc::new(Stats::new());
@@ -102,14 +107,14 @@ async fn integration_full_duplex_exact_budget_then_hard_cutoff() {
relay_result,
Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-full-duplex-boundary-user"
));
assert!(stats.get_user_total_octets(user) <= 10);
assert!(stats.get_user_quota_used(user) <= 10);
}
#[tokio::test]
async fn negative_preloaded_quota_blocks_both_directions_immediately() {
let stats = Arc::new(Stats::new());
let user = "quota-preloaded-cutoff-user";
stats.add_user_octets_from(user, 5);
preload_user_quota(stats.as_ref(), user, 5);
let (mut client_peer, relay_client) = duplex(2048);
let (relay_server, mut server_peer) = duplex(2048);
@@ -154,7 +159,7 @@ async fn negative_preloaded_quota_blocks_both_directions_immediately() {
relay_result,
Err(ProxyError::DataQuotaExceeded { .. })
));
assert!(stats.get_user_total_octets(user) <= 5);
assert!(stats.get_user_quota_used(user) <= 5);
}
#[tokio::test]
@@ -212,7 +217,7 @@ async fn edge_quota_one_bidirectional_race_allows_at_most_one_forwarded_octet()
relay_result,
Err(ProxyError::DataQuotaExceeded { .. })
));
assert!(stats.get_user_total_octets(user) <= 1);
assert!(stats.get_user_quota_used(user) <= 1);
}
#[tokio::test]
@@ -277,7 +282,7 @@ async fn adversarial_blackhat_alternating_fragmented_jitter_never_overshoots_glo
delivered_to_server + delivered_to_client <= quota as usize,
"combined forwarded bytes must never exceed configured quota"
);
assert!(stats.get_user_total_octets(user) <= quota);
assert!(stats.get_user_quota_used(user) <= quota);
}
#[tokio::test]
@@ -356,7 +361,7 @@ async fn light_fuzz_randomized_schedule_preserves_quota_and_forwarded_byte_invar
"fuzz case {case}: forwarded bytes must not exceed quota"
);
assert!(
stats.get_user_total_octets(&user) <= quota,
stats.get_user_quota_used(&user) <= quota,
"fuzz case {case}: accounted bytes must not exceed quota"
);
}
@@ -451,7 +456,7 @@ async fn stress_multi_relay_same_user_mixed_direction_jitter_respects_global_quo
}
assert!(
stats.get_user_total_octets(user) <= quota,
stats.get_user_quota_used(user) <= quota,
"global per-user quota must hold under concurrent mixed-direction relay stress"
);
assert!(
@@ -0,0 +1,399 @@
use super::relay_bidirectional;
use crate::error::ProxyError;
use crate::stats::Stats;
use crate::stream::BufferPool;
use rand::rngs::StdRng;
use rand::{RngExt, SeedableRng};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::time::{Duration, timeout};
async fn read_available<R: tokio::io::AsyncRead + Unpin>(
reader: &mut R,
budget: Duration,
) -> usize {
let start = tokio::time::Instant::now();
let mut total = 0usize;
let mut buf = [0u8; 128];
loop {
let elapsed = start.elapsed();
if elapsed >= budget {
break;
}
let remaining = budget.saturating_sub(elapsed);
match timeout(remaining, reader.read(&mut buf)).await {
Ok(Ok(0)) => break,
Ok(Ok(n)) => total = total.saturating_add(n),
Ok(Err(_)) | Err(_) => break,
}
}
total
}
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
}
#[tokio::test]
async fn positive_quota_path_forwards_both_directions_within_limit() {
let stats = Arc::new(Stats::new());
let user = "quota-extended-positive-user";
let (mut client_peer, relay_client) = duplex(2048);
let (relay_server, mut server_peer) = duplex(2048);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
256,
256,
user,
Arc::clone(&stats),
Some(16),
Arc::new(BufferPool::new()),
));
client_peer
.write_all(&[0xAA, 0xBB, 0xCC, 0xDD])
.await
.unwrap();
server_peer.read_exact(&mut [0u8; 4]).await.unwrap();
server_peer
.write_all(&[0x11, 0x22, 0x33, 0x44])
.await
.unwrap();
client_peer.read_exact(&mut [0u8; 4]).await.unwrap();
drop(client_peer);
drop(server_peer);
let relay_result = timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
assert!(relay_result.is_ok());
assert!(stats.get_user_quota_used(user) <= 16);
}
#[tokio::test]
async fn negative_preloaded_quota_forbids_any_forwarding() {
let stats = Arc::new(Stats::new());
let user = "quota-extended-negative-user";
preload_user_quota(stats.as_ref(), user, 8);
let (mut client_peer, relay_client) = duplex(1024);
let (relay_server, mut server_peer) = duplex(1024);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
128,
128,
user,
Arc::clone(&stats),
Some(8),
Arc::new(BufferPool::new()),
));
client_peer.write_all(&[0xAA]).await.unwrap();
server_peer.write_all(&[0xBB]).await.unwrap();
assert_eq!(
read_available(&mut server_peer, Duration::from_millis(120)).await,
0
);
assert_eq!(
read_available(&mut client_peer, Duration::from_millis(120)).await,
0
);
let relay_result = timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
assert!(matches!(
relay_result,
Err(ProxyError::DataQuotaExceeded { .. })
));
assert!(stats.get_user_quota_used(user) <= 8);
}
#[tokio::test]
async fn edge_quota_one_ensures_at_most_one_byte_across_directions() {
let stats = Arc::new(Stats::new());
let user = "quota-extended-edge-user";
let (mut client_peer, relay_client) = duplex(1024);
let (relay_server, mut server_peer) = duplex(1024);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
128,
128,
user,
Arc::clone(&stats),
Some(1),
Arc::new(BufferPool::new()),
));
let _ = tokio::join!(
client_peer.write_all(&[0xFE]),
server_peer.write_all(&[0xEF]),
);
let mut buf = [0u8; 1];
let delivered_s2c = timeout(Duration::from_millis(120), client_peer.read(&mut buf))
.await
.unwrap()
.unwrap_or(0);
let delivered_c2s = timeout(Duration::from_millis(120), server_peer.read(&mut buf))
.await
.unwrap()
.unwrap_or(0);
assert!(delivered_s2c + delivered_c2s <= 1);
let relay_result = timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
assert!(matches!(
relay_result,
Err(ProxyError::DataQuotaExceeded { .. })
));
}
#[tokio::test]
async fn adversarial_blackhat_alternating_jitter_does_not_overshoot_quota() {
let stats = Arc::new(Stats::new());
let user = "quota-extended-blackhat-user";
let quota = 24u64;
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
256,
256,
user,
Arc::clone(&stats),
Some(quota),
Arc::new(BufferPool::new()),
));
let mut total_forwarded = 0usize;
for i in 0..256usize {
if relay.is_finished() {
break;
}
if (i & 1) == 0 {
let _ = client_peer.write_all(&[(i as u8) ^ 0x57]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await {
total_forwarded += n;
}
} else {
let _ = server_peer.write_all(&[(i as u8) ^ 0xA8]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await {
total_forwarded += n;
}
}
tokio::time::sleep(Duration::from_millis(((i % 3) + 1) as u64)).await;
}
let relay_result = timeout(Duration::from_secs(3), relay)
.await
.unwrap()
.unwrap();
assert!(matches!(
relay_result,
Err(ProxyError::DataQuotaExceeded { .. })
));
assert!(total_forwarded <= quota as usize);
assert!(stats.get_user_quota_used(user) <= quota);
}
#[tokio::test]
async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() {
let mut rng = StdRng::seed_from_u64(0xBEEF_C0DE);
for case in 0..32u64 {
let stats = Arc::new(Stats::new());
let user = format!("quota-extended-fuzz-{case}");
let quota = rng.random_range(1u64..=35u64);
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_user = user.clone();
let relay_stats = Arc::clone(&stats);
let relay = tokio::spawn(async move {
relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
256,
256,
&relay_user,
Arc::clone(&relay_stats),
Some(quota),
Arc::new(BufferPool::new()),
)
.await
});
let mut total_forwarded = 0usize;
for _ in 0..96usize {
if relay.is_finished() {
break;
}
if rng.random::<bool>() {
let _ = client_peer.write_all(&[rng.random::<u8>()]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) =
timeout(Duration::from_millis(4), server_peer.read(&mut one)).await
{
total_forwarded += n;
}
} else {
let _ = server_peer.write_all(&[rng.random::<u8>()]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) =
timeout(Duration::from_millis(4), client_peer.read(&mut one)).await
{
total_forwarded += n;
}
}
}
drop(client_peer);
drop(server_peer);
let relay_result = timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
assert!(
relay_result.is_ok()
|| matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))
);
assert!(total_forwarded <= quota as usize);
assert!(stats.get_user_quota_used(&user) <= quota);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_relays_for_one_user_obey_global_quota() {
let stats = Arc::new(Stats::new());
let user = "quota-extended-stress-user".to_string();
let quota = 64u64;
let mut tasks = Vec::new();
for worker in 0..4u8 {
let stats = Arc::clone(&stats);
let user = user.clone();
tasks.push(tokio::spawn(async move {
let (mut client_peer, relay_client) = duplex(2048);
let (relay_server, mut server_peer) = duplex(2048);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_user = user.clone();
let relay_stats = Arc::clone(&stats);
let relay = tokio::spawn(async move {
relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
128,
128,
&relay_user,
Arc::clone(&relay_stats),
Some(quota),
Arc::new(BufferPool::new()),
)
.await
});
let mut total = 0usize;
for step in 0..64u8 {
if relay.is_finished() {
break;
}
if (step as usize + worker as usize) % 2 == 0 {
let _ = client_peer.write_all(&[(step ^ 0x5A)]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) =
timeout(Duration::from_millis(6), server_peer.read(&mut one)).await
{
total += n;
}
} else {
let _ = server_peer.write_all(&[(step ^ 0xA5)]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) =
timeout(Duration::from_millis(6), client_peer.read(&mut one)).await
{
total += n;
}
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
drop(client_peer);
drop(server_peer);
let relay_result = timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
assert!(
relay_result.is_ok()
|| matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))
);
total
}));
}
let mut delivered = 0usize;
for task in tasks {
delivered += task.await.unwrap();
}
assert!(stats.get_user_quota_used(&user) <= quota);
assert!(delivered <= quota as usize);
}
@@ -1,438 +0,0 @@
use super::*;
use crate::error::ProxyError;
use crate::stats::Stats;
use crate::stream::BufferPool;
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::sync::Barrier;
use tokio::time::Instant;
#[test]
fn quota_lock_same_user_returns_same_arc_instance() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let a = quota_user_lock("quota-lock-same-user");
let b = quota_user_lock("quota-lock-same-user");
assert!(Arc::ptr_eq(&a, &b));
}
#[test]
fn quota_lock_parallel_same_user_reuses_single_lock() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let user = "quota-lock-parallel-same";
let mut handles = Vec::new();
for _ in 0..64 {
handles.push(std::thread::spawn(move || quota_user_lock(user)));
}
let first = handles
.remove(0)
.join()
.expect("thread must return lock handle");
for handle in handles {
let got = handle.join().expect("thread must return lock handle");
assert!(Arc::ptr_eq(&first, &got));
}
}
#[test]
fn quota_lock_unique_users_materialize_distinct_entries() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let base = format!("quota-lock-distinct-{}", std::process::id());
let users: Vec<String> = (0..(QUOTA_USER_LOCKS_MAX / 2))
.map(|idx| format!("{base}-{idx}"))
.collect();
for user in &users {
let _ = quota_user_lock(user);
}
for user in &users {
assert!(
map.get(user).is_some(),
"lock cache must contain entry for {user}"
);
}
}
#[test]
fn quota_lock_unique_churn_stress_keeps_all_inserted_keys_addressable() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let base = format!("quota-lock-churn-{}", std::process::id());
for idx in 0..(QUOTA_USER_LOCKS_MAX + 256) {
let _ = quota_user_lock(&format!("{base}-{idx}"));
}
assert!(
map.len() <= QUOTA_USER_LOCKS_MAX,
"quota lock cache must stay bounded under unique-user churn"
);
}
#[test]
fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let prefix = format!("quota-held-{}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
}
assert_eq!(
map.len(),
QUOTA_USER_LOCKS_MAX,
"cache must be saturated for overflow check"
);
let overflow_user = format!("quota-overflow-{}", std::process::id());
let overflow_a = quota_user_lock(&overflow_user);
let overflow_b = quota_user_lock(&overflow_user);
assert_eq!(
map.len(),
QUOTA_USER_LOCKS_MAX,
"overflow path must not grow lock cache"
);
assert!(
map.get(&overflow_user).is_none(),
"overflow user lock must stay outside bounded cache under saturation"
);
assert!(
Arc::ptr_eq(&overflow_a, &overflow_b),
"overflow user must receive stable striped overflow lock while saturated"
);
drop(retained);
}
#[test]
fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
// Saturate with retained strong references first so parallel tests cannot
// reclaim our fixture entries before we validate the reclaim path.
let prefix = format!("quota-reclaim-drop-{}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
}
drop(retained);
let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id());
let overflow = quota_user_lock(&overflow_user);
assert!(
map.get(&overflow_user).is_some(),
"after reclaiming stale entries, overflow user should become cacheable"
);
assert!(
Arc::strong_count(&overflow) >= 2,
"cacheable overflow lock should be held by both map and caller"
);
}
#[test]
fn quota_lock_saturated_same_user_must_not_return_distinct_locks() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!(
"quota-saturated-held-{}-{idx}",
std::process::id()
)));
}
let overflow_user = format!("quota-saturated-same-user-{}", std::process::id());
let a = quota_user_lock(&overflow_user);
let b = quota_user_lock(&overflow_user);
assert!(
Arc::ptr_eq(&a, &b),
"same user must not receive distinct locks under saturation because that enables quota race bypass"
);
drop(retained);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn quota_lock_saturation_concurrent_same_user_never_overshoots_quota() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!(
"quota-saturated-race-held-{}-{idx}",
std::process::id()
)));
}
let stats = Arc::new(Stats::new());
let user = format!("quota-saturated-race-user-{}", std::process::id());
let gate = Arc::new(Barrier::new(2));
let worker = |label: u8, stats: Arc<Stats>, user: String, gate: Arc<Barrier>| {
tokio::spawn(async move {
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user,
Some(1),
quota_exceeded,
Instant::now(),
);
gate.wait().await;
io.write_all(&[label]).await
})
};
let one = worker(0x11, Arc::clone(&stats), user.clone(), Arc::clone(&gate));
let two = worker(0x22, Arc::clone(&stats), user.clone(), Arc::clone(&gate));
let _ = tokio::time::timeout(Duration::from_secs(2), async {
let _ = one.await.expect("task one must not panic");
let _ = two.await.expect("task two must not panic");
})
.await
.expect("quota race workers must complete");
assert!(
stats.get_user_total_octets(&user) <= 1,
"saturated lock path must never overshoot quota for same user"
);
drop(retained);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn quota_lock_saturation_stress_same_user_never_overshoots_quota() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!(
"quota-saturated-stress-held-{}-{idx}",
std::process::id()
)));
}
for round in 0..128u32 {
let stats = Arc::new(Stats::new());
let user = format!("quota-saturated-stress-user-{}-{round}", std::process::id());
let gate = Arc::new(Barrier::new(2));
let one = {
let stats = Arc::clone(&stats);
let user = user.clone();
let gate = Arc::clone(&gate);
tokio::spawn(async move {
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user,
Some(1),
quota_exceeded,
Instant::now(),
);
gate.wait().await;
io.write_all(&[0x31]).await
})
};
let two = {
let stats = Arc::clone(&stats);
let user = user.clone();
let gate = Arc::clone(&gate);
tokio::spawn(async move {
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user,
Some(1),
quota_exceeded,
Instant::now(),
);
gate.wait().await;
io.write_all(&[0x32]).await
})
};
let _ = one.await.expect("stress task one must not panic");
let _ = two.await.expect("stress task two must not panic");
assert!(
stats.get_user_total_octets(&user) <= 1,
"round {round}: saturated path must not overshoot quota"
);
}
drop(retained);
}
#[test]
fn quota_error_classifier_accepts_internal_quota_sentinel_only() {
let err = quota_io_error();
assert!(is_quota_io_error(&err));
}
#[test]
fn quota_error_classifier_rejects_plain_permission_denied() {
let err = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "permission denied");
assert!(!is_quota_io_error(&err));
}
#[test]
fn quota_lock_test_scope_recovers_after_guard_poison() {
let poison_result = std::thread::spawn(|| {
let _guard = super::quota_user_lock_test_scope();
panic!("intentional test-only guard poison");
})
.join();
assert!(poison_result.is_err(), "poison setup thread must panic");
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let a = quota_user_lock("quota-lock-poison-recovery-user");
let b = quota_user_lock("quota-lock-poison-recovery-user");
assert!(Arc::ptr_eq(&a, &b));
}
#[tokio::test]
async fn quota_lock_integration_zero_quota_cuts_off_without_forwarding() {
let stats = Arc::new(Stats::new());
let user = "quota-zero-user";
let (mut client_peer, relay_client) = duplex(2048);
let (relay_server, mut server_peer) = duplex(2048);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
512,
512,
user,
Arc::clone(&stats),
Some(0),
Arc::new(BufferPool::new()),
));
client_peer
.write_all(b"x")
.await
.expect("client write must succeed");
let mut probe = [0u8; 1];
let forwarded =
tokio::time::timeout(Duration::from_millis(80), server_peer.read(&mut probe)).await;
if let Ok(Ok(n)) = forwarded {
assert_eq!(n, 0, "zero quota path must not forward payload bytes");
}
let result = tokio::time::timeout(Duration::from_secs(2), relay)
.await
.expect("relay must terminate under zero quota")
.expect("relay task must not panic");
assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
}
#[tokio::test]
async fn quota_lock_integration_no_quota_relays_both_directions_under_burst() {
let stats = Arc::new(Stats::new());
let (mut client_peer, relay_client) = duplex(8192);
let (relay_server, mut server_peer) = duplex(8192);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
"quota-none-burst-user",
Arc::clone(&stats),
None,
Arc::new(BufferPool::new()),
));
let c2s = vec![0xA5; 2048];
let s2c = vec![0x5A; 1536];
client_peer
.write_all(&c2s)
.await
.expect("client burst write must succeed");
let mut got_c2s = vec![0u8; c2s.len()];
server_peer
.read_exact(&mut got_c2s)
.await
.expect("server must receive c2s burst");
assert_eq!(got_c2s, c2s);
server_peer
.write_all(&s2c)
.await
.expect("server burst write must succeed");
let mut got_s2c = vec![0u8; s2c.len()];
client_peer
.read_exact(&mut got_s2c)
.await
.expect("client must receive s2c burst");
assert_eq!(got_s2c, s2c);
drop(client_peer);
drop(server_peer);
let done = tokio::time::timeout(Duration::from_secs(2), relay)
.await
.expect("relay must terminate after peers close")
.expect("relay task must not panic");
assert!(done.is_ok());
}
@@ -32,6 +32,7 @@ async fn drain_available<R: AsyncRead + Unpin>(reader: &mut R, out: &mut Vec<u8>
#[tokio::test]
async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() {
let mut rng = StdRng::seed_from_u64(0xC0DE_CAFE_D15C_F00D);
const MAX_INPUT_CHUNK: usize = 12;
for case in 0..64u64 {
let stats = Arc::new(Stats::new());
@@ -92,12 +93,12 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget()
assert_is_prefix(&recv_at_server, &sent_c2s, "C->S");
assert_is_prefix(&recv_at_client, &sent_s2c, "S->C");
assert!(
recv_at_server.len() + recv_at_client.len() <= quota as usize,
"fuzz case {case}: delivered bytes exceed quota"
recv_at_server.len() + recv_at_client.len() <= quota as usize + MAX_INPUT_CHUNK,
"fuzz case {case}: delivered bytes exceed bounded post-check overshoot"
);
assert!(
stats.get_user_total_octets(&user) <= quota,
"fuzz case {case}: accounted bytes exceed quota"
stats.get_user_quota_used(&user) <= quota + MAX_INPUT_CHUNK as u64,
"fuzz case {case}: accounted bytes exceed bounded post-check overshoot"
);
}
@@ -117,8 +118,8 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget()
assert_is_prefix(&recv_at_server, &sent_c2s, "C->S final");
assert_is_prefix(&recv_at_client, &sent_s2c, "S->C final");
assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize);
assert!(stats.get_user_total_octets(&user) <= quota);
assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize + MAX_INPUT_CHUNK);
assert!(stats.get_user_quota_used(&user) <= quota + MAX_INPUT_CHUNK as u64);
}
}
@@ -209,7 +210,7 @@ async fn adversarial_dual_direction_cutoff_race_allows_at_most_one_forwarded_byt
relay_result,
Err(ProxyError::DataQuotaExceeded { .. })
));
assert!(stats.get_user_total_octets(user) <= 1);
assert!(stats.get_user_quota_used(user) <= 1);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
@@ -217,9 +218,12 @@ async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_mode
let stats = Arc::new(Stats::new());
let user = "quota-model-stress-user";
let quota = 96u64;
const WORKERS: usize = 6;
const MAX_WORKER_CHUNK: u64 = 10;
let max_parallel_post_write_overshoot = WORKERS as u64 * MAX_WORKER_CHUNK;
let mut workers = Vec::new();
for worker_id in 0..6u64 {
for worker_id in 0..WORKERS as u64 {
let stats = Arc::clone(&stats);
let user = user.to_string();
@@ -305,11 +309,11 @@ async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_mode
}
assert!(
stats.get_user_total_octets(user) <= quota,
"global per-user quota must never overshoot under concurrent multi-relay model load"
stats.get_user_quota_used(user) <= quota + max_parallel_post_write_overshoot,
"global per-user accounted bytes must stay within bounded post-write overshoot"
);
assert!(
delivered_sum <= quota as usize,
"aggregate delivered bytes across relays must remain within global quota"
delivered_sum as u64 <= quota + max_parallel_post_write_overshoot,
"aggregate delivered bytes must stay within bounded post-write overshoot"
);
}
@@ -19,13 +19,22 @@ async fn read_available<R: AsyncRead + Unpin>(reader: &mut R, budget_ms: u64) ->
total
}
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
}
#[tokio::test]
async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_accounting() {
let stats = Arc::new(Stats::new());
let user = "quota-overflow-regression-client-chunk";
let quota = 10u64;
let preloaded = 9u64;
let attempted_chunk = [0x11, 0x22, 0x33, 0x44];
let max_post_write_overshoot = attempted_chunk.len() as u64;
// Leave only 1 byte remaining under quota.
stats.add_user_octets_from(user, 9);
preload_user_quota(stats.as_ref(), user, preloaded);
let (mut client_peer, relay_client) = duplex(2048);
let (relay_server, mut server_peer) = duplex(2048);
@@ -41,15 +50,12 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_
512,
user,
Arc::clone(&stats),
Some(10),
Some(quota),
Arc::new(BufferPool::new()),
));
// Single chunk attempts to cross remaining budget (4 > 1).
client_peer
.write_all(&[0x11, 0x22, 0x33, 0x44])
.await
.unwrap();
client_peer.write_all(&attempted_chunk).await.unwrap();
client_peer.shutdown().await.unwrap();
let forwarded = read_available(&mut server_peer, 60).await;
@@ -59,17 +65,17 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_
.expect("relay must terminate after quota overflow attempt")
.expect("relay task must not panic");
assert_eq!(
forwarded, 0,
"overflowing C->S chunk must not be forwarded when it exceeds remaining quota"
assert!(
forwarded <= attempted_chunk.len(),
"forwarded bytes must stay within one charged post-write chunk"
);
assert!(matches!(
relay_result,
Err(ProxyError::DataQuotaExceeded { .. })
));
assert!(
stats.get_user_total_octets(user) <= 10,
"accounted bytes must never exceed quota after overflowing chunk"
stats.get_user_quota_used(user) <= quota + max_post_write_overshoot,
"accounted bytes must stay within bounded post-write overshoot"
);
}
@@ -79,7 +85,7 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of
let user = "quota-overflow-regression-boundary";
// Leave exactly 4 bytes remaining.
stats.add_user_octets_from(user, 6);
preload_user_quota(stats.as_ref(), user, 6);
let (mut client_peer, relay_client) = duplex(2048);
let (relay_server, mut server_peer) = duplex(2048);
@@ -131,7 +137,7 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of
relay_result,
Err(ProxyError::DataQuotaExceeded { .. })
));
assert!(stats.get_user_total_octets(user) <= 10);
assert!(stats.get_user_quota_used(user) <= 10);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
@@ -139,9 +145,12 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() {
let stats = Arc::new(Stats::new());
let user = "quota-overflow-regression-stress";
let quota = 12u64;
const WORKERS: usize = 4;
const BURST_LEN: usize = 64;
let max_parallel_post_write_overshoot = (WORKERS * BURST_LEN) as u64;
let mut handles = Vec::new();
for _ in 0..4usize {
for _ in 0..WORKERS {
let stats = Arc::clone(&stats);
let user = user.to_string();
@@ -170,7 +179,7 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() {
});
// Aggressive sender tries to overflow shared user quota.
let burst = vec![0x5Au8; 64];
let burst = vec![0x5Au8; BURST_LEN];
let _ = client_peer.write_all(&burst).await;
let _ = client_peer.shutdown().await;
@@ -197,11 +206,11 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() {
}
assert!(
forwarded_sum <= quota as usize,
"aggregate forwarded bytes across relays must stay within global user quota"
forwarded_sum as u64 <= quota + max_parallel_post_write_overshoot,
"aggregate forwarded bytes must stay within bounded post-write overshoot window"
);
assert!(
stats.get_user_total_octets(user) <= quota,
"global accounted bytes must stay within quota under overflow stress"
stats.get_user_quota_used(user) <= quota + max_parallel_post_write_overshoot,
"global accounted bytes must stay within bounded post-write overshoot window"
);
}
@@ -1,294 +0,0 @@
use super::*;
use crate::stats::Stats;
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::Barrier;
use tokio::time::{Duration, timeout};
fn saturate_lock_cache() -> Vec<Arc<std::sync::Mutex<()>>> {
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("quota-liveness-saturated-{idx}")));
}
retained
}
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
#[tokio::test]
async fn positive_writer_progresses_after_contention_release_without_external_wake() {
let _guard = quota_test_guard();
let _retained = saturate_lock_cache();
let user = "quota-liveness-writer-positive";
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold user quota lock before write");
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user.to_string(),
Some(1024),
quota_exceeded,
tokio::time::Instant::now(),
);
let writer = tokio::spawn(async move { io.write_all(&[0x11]).await });
// Let the initial deferred wake fire while contention is still active.
tokio::time::sleep(Duration::from_millis(4)).await;
drop(held_guard);
let completed = timeout(Duration::from_millis(250), writer)
.await
.expect("writer must be re-polled and complete after lock release")
.expect("writer task must not panic");
assert!(completed.is_ok(), "writer must complete after lock release");
}
#[tokio::test]
async fn edge_reader_progresses_after_contention_release_without_external_wake() {
let _guard = quota_test_guard();
let _retained = saturate_lock_cache();
let user = "quota-liveness-reader-edge";
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold user quota lock before read");
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::empty(),
counters,
Arc::clone(&stats),
user.to_string(),
Some(1024),
quota_exceeded,
tokio::time::Instant::now(),
);
let reader = tokio::spawn(async move {
let mut one = [0u8; 1];
io.read(&mut one).await
});
tokio::time::sleep(Duration::from_millis(4)).await;
drop(held_guard);
let completed = timeout(Duration::from_millis(250), reader)
.await
.expect("reader must be re-polled and complete after lock release")
.expect("reader task must not panic");
assert!(completed.is_ok(), "reader must complete after lock release");
}
#[tokio::test]
async fn adversarial_early_deferred_wake_consumption_does_not_deadlock_writer() {
let _guard = quota_test_guard();
let _retained = saturate_lock_cache();
let user = "quota-liveness-adversarial";
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold user quota lock before adversarial write");
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user.to_string(),
Some(1024),
quota_exceeded,
tokio::time::Instant::now(),
);
let writer = tokio::spawn(async move { io.write_all(&[0x22]).await });
// Force multiple scheduler rounds while lock remains held so the first
// deferred wake has already been consumed under contention.
for _ in 0..32 {
tokio::task::yield_now().await;
}
drop(held_guard);
let completed = timeout(Duration::from_millis(300), writer)
.await
.expect("writer must not stay parked forever after release")
.expect("writer task must not panic");
assert!(completed.is_ok());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_parallel_waiters_resume_after_single_release_event() {
let _guard = quota_test_guard();
let _retained = saturate_lock_cache();
let user = format!("quota-liveness-integration-{}", std::process::id());
let stats = Arc::new(Stats::new());
let barrier = Arc::new(Barrier::new(13));
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold user quota lock before launching waiters");
let mut waiters = Vec::new();
for _ in 0..12 {
let stats = Arc::clone(&stats);
let user = user.clone();
let barrier = Arc::clone(&barrier);
waiters.push(tokio::spawn(async move {
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
stats,
user,
Some(4096),
quota_exceeded,
tokio::time::Instant::now(),
);
barrier.wait().await;
io.write_all(&[0x33]).await
}));
}
barrier.wait().await;
tokio::time::sleep(Duration::from_millis(4)).await;
drop(held_guard);
timeout(Duration::from_secs(1), async {
for waiter in waiters {
let outcome = waiter.await.expect("waiter must not panic");
assert!(
outcome.is_ok(),
"waiter must resume and complete after release"
);
}
})
.await
.expect("all waiters must complete in bounded time");
}
#[tokio::test]
async fn light_fuzz_release_timing_matrix_preserves_liveness() {
let _guard = quota_test_guard();
let _retained = saturate_lock_cache();
let stats = Arc::new(Stats::new());
let mut seed = 0xD1CE_F00D_0123_4567u64;
for round in 0..64u32 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let delay_ms = 1 + (seed & 0x7) as u64;
let user = format!("quota-liveness-fuzz-{}-{round}", std::process::id());
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold user quota lock in fuzz round");
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user,
Some(2048),
quota_exceeded,
tokio::time::Instant::now(),
);
let writer = tokio::spawn(async move { io.write_all(&[0x44]).await });
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
drop(held_guard);
let done = timeout(Duration::from_millis(300), writer)
.await
.expect("fuzz round writer must complete")
.expect("fuzz writer task must not panic");
assert!(
done.is_ok(),
"fuzz round writer must not stall after release"
);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_repeated_contention_cycles_remain_live() {
let _guard = quota_test_guard();
let _retained = saturate_lock_cache();
let stats = Arc::new(Stats::new());
for cycle in 0..40u32 {
let user = format!("quota-liveness-stress-{}-{cycle}", std::process::id());
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold lock before stress cycle");
let mut tasks = Vec::new();
for _ in 0..6 {
let stats = Arc::clone(&stats);
let user = user.clone();
tasks.push(tokio::spawn(async move {
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
stats,
user,
Some(2048),
quota_exceeded,
tokio::time::Instant::now(),
);
io.write_all(&[0x55]).await
}));
}
tokio::task::yield_now().await;
drop(held_guard);
timeout(Duration::from_millis(700), async {
for task in tasks {
let outcome = task.await.expect("stress task must not panic");
assert!(outcome.is_ok(), "stress writer must complete");
}
})
.await
.expect("stress cycle must finish in bounded time");
}
}
@@ -1,310 +0,0 @@
use super::*;
use crate::stats::Stats;
use dashmap::DashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Waker};
use tokio::io::{AsyncWriteExt, ReadBuf};
use tokio::time::{Duration, timeout};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
fn saturate_quota_user_locks() -> Vec<Arc<std::sync::Mutex<()>>> {
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("quota-waker-saturate-{idx}")));
}
retained
}
#[tokio::test]
async fn positive_contended_writer_emits_deferred_wake_for_liveness() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let stats = Arc::new(Stats::new());
let user = "quota-waker-positive-user";
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold overflow lock before polling writer");
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user.to_string(),
Some(1024),
quota_exceeded,
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]);
assert!(pending.is_pending());
timeout(Duration::from_millis(100), async {
loop {
if wake_counter.wakes.load(Ordering::Relaxed) >= 1 {
break;
}
tokio::task::yield_now().await;
}
})
.await
.expect("contended writer must receive deferred wake");
drop(held_guard);
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]);
assert!(
ready.is_ready(),
"writer must progress after contention release"
);
}
#[tokio::test]
async fn adversarial_blackhat_writer_contention_does_not_create_waker_storm() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let stats = Arc::new(Stats::new());
let user = "quota-waker-blackhat-writer";
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold overflow lock before polling writer");
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user.to_string(),
Some(1024),
quota_exceeded,
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
for _ in 0..512 {
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xBE]);
assert!(
poll.is_pending(),
"writer must stay pending while lock is held"
);
tokio::task::yield_now().await;
}
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
assert!(
wakes <= 128,
"pending writer retries must not trigger wake storm; observed wakes={wakes}"
);
drop(held_guard);
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xEF]);
assert!(ready.is_ready());
}
#[tokio::test]
async fn edge_read_path_contention_keeps_wake_budget_bounded() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let stats = Arc::new(Stats::new());
let user = "quota-waker-read-edge";
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold overflow lock before polling reader");
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::empty(),
counters,
Arc::clone(&stats),
user.to_string(),
Some(1024),
quota_exceeded,
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let mut storage = [0u8; 1];
for _ in 0..512 {
let mut buf = ReadBuf::new(&mut storage);
let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
assert!(poll.is_pending());
tokio::task::yield_now().await;
}
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
assert!(
wakes <= 128,
"pending reader retries must not trigger wake storm; observed wakes={wakes}"
);
drop(held_guard);
let mut buf = ReadBuf::new(&mut storage);
let ready = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
assert!(ready.is_ready());
}
#[tokio::test]
async fn light_fuzz_mixed_poll_schedule_under_contention_stays_bounded() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let stats = Arc::new(Stats::new());
let user = "quota-waker-fuzz-user";
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold overflow lock before fuzz polling");
let counters_w = Arc::new(SharedCounters::new());
let mut writer_io = StatsIo::new(
tokio::io::sink(),
counters_w,
Arc::clone(&stats),
user.to_string(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let counters_r = Arc::new(SharedCounters::new());
let mut reader_io = StatsIo::new(
tokio::io::empty(),
counters_r,
Arc::clone(&stats),
user.to_string(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let mut seed = 0xBADC_0FFE_EE11_2211u64;
let mut storage = [0u8; 1];
for _ in 0..1024 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
if (seed & 1) == 0 {
let poll = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x44]);
assert!(poll.is_pending());
} else {
let mut buf = ReadBuf::new(&mut storage);
let poll = Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf);
assert!(poll.is_pending());
}
tokio::task::yield_now().await;
}
assert!(
wake_counter.wakes.load(Ordering::Relaxed) <= 192,
"mixed contention fuzz must keep deferred wake count tightly bounded"
);
drop(held_guard);
let ready_w = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x55]);
assert!(ready_w.is_ready());
let mut buf = ReadBuf::new(&mut storage);
let ready_r = Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf);
assert!(ready_r.is_ready());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[ignore = "red-team detector: reveals possible starvation if deferred wake fires before contention release"]
async fn stress_many_contended_writers_complete_after_release() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let user = "quota-waker-stress-user".to_string();
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold overflow lock before launching contended tasks");
let mut tasks = Vec::new();
for _ in 0..32 {
let stats = Arc::clone(&stats);
let user = user.clone();
tasks.push(tokio::spawn(async move {
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
stats,
user,
Some(2048),
quota_exceeded,
tokio::time::Instant::now(),
);
io.write_all(&[0xAA]).await
}));
}
for _ in 0..8 {
tokio::task::yield_now().await;
}
drop(held_guard);
timeout(Duration::from_secs(2), async {
for task in tasks {
let result = task.await.expect("stress task must not panic");
assert!(result.is_ok(), "task must complete after lock release");
}
})
.await
.expect("all contended writer tasks must finish in bounded time after release");
}
File diff suppressed because it is too large Load Diff
+307 -61
View File
@@ -238,10 +238,12 @@ pub struct Stats {
me_inline_recovery_total: AtomicU64,
ip_reservation_rollback_tcp_limit_total: AtomicU64,
ip_reservation_rollback_quota_limit_total: AtomicU64,
quota_write_fail_bytes_total: AtomicU64,
quota_write_fail_events_total: AtomicU64,
telemetry_core_enabled: AtomicBool,
telemetry_user_enabled: AtomicBool,
telemetry_me_level: AtomicU8,
user_stats: DashMap<String, UserStats>,
user_stats: DashMap<String, Arc<UserStats>>,
user_stats_last_cleanup_epoch_secs: AtomicU64,
start_time: parking_lot::RwLock<Option<Instant>>,
}
@@ -254,9 +256,51 @@ pub struct UserStats {
pub octets_to_client: AtomicU64,
pub msgs_from_client: AtomicU64,
pub msgs_to_client: AtomicU64,
/// Total bytes charged against per-user quota admission.
///
/// This counter is the single source of truth for quota enforcement and
/// intentionally tracks attempted traffic, not guaranteed delivery.
pub quota_used: AtomicU64,
pub last_seen_epoch_secs: AtomicU64,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuotaReserveError {
LimitExceeded,
Contended,
}
impl UserStats {
#[inline]
pub fn quota_used(&self) -> u64 {
self.quota_used.load(Ordering::Relaxed)
}
/// Attempts one CAS reservation step against the quota counter.
///
/// Callers control retry/yield policy. This primitive intentionally does
/// not block or sleep so both sync poll paths and async paths can wrap it
/// with their own contention strategy.
#[inline]
pub fn quota_try_reserve(&self, bytes: u64, limit: u64) -> Result<u64, QuotaReserveError> {
let current = self.quota_used.load(Ordering::Relaxed);
if bytes > limit.saturating_sub(current) {
return Err(QuotaReserveError::LimitExceeded);
}
let next = current.saturating_add(bytes);
match self.quota_used.compare_exchange_weak(
current,
next,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => Ok(next),
Err(_) => Err(QuotaReserveError::Contended),
}
}
}
impl Stats {
pub fn new() -> Self {
let stats = Self::default();
@@ -316,6 +360,74 @@ impl Stats {
.store(Self::now_epoch_secs(), Ordering::Relaxed);
}
pub(crate) fn get_or_create_user_stats_handle(&self, user: &str) -> Arc<UserStats> {
self.maybe_cleanup_user_stats();
if let Some(existing) = self.user_stats.get(user) {
let handle = Arc::clone(existing.value());
Self::touch_user_stats(handle.as_ref());
return handle;
}
let entry = self.user_stats.entry(user.to_string()).or_default();
if entry.last_seen_epoch_secs.load(Ordering::Relaxed) == 0 {
Self::touch_user_stats(entry.value().as_ref());
}
Arc::clone(entry.value())
}
#[inline]
pub(crate) fn add_user_octets_from_handle(&self, user_stats: &UserStats, bytes: u64) {
if !self.telemetry_user_enabled() {
return;
}
Self::touch_user_stats(user_stats);
user_stats
.octets_from_client
.fetch_add(bytes, Ordering::Relaxed);
}
#[inline]
pub(crate) fn add_user_octets_to_handle(&self, user_stats: &UserStats, bytes: u64) {
if !self.telemetry_user_enabled() {
return;
}
Self::touch_user_stats(user_stats);
user_stats
.octets_to_client
.fetch_add(bytes, Ordering::Relaxed);
}
#[inline]
pub(crate) fn increment_user_msgs_from_handle(&self, user_stats: &UserStats) {
if !self.telemetry_user_enabled() {
return;
}
Self::touch_user_stats(user_stats);
user_stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
}
#[inline]
pub(crate) fn increment_user_msgs_to_handle(&self, user_stats: &UserStats) {
if !self.telemetry_user_enabled() {
return;
}
Self::touch_user_stats(user_stats);
user_stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
}
/// Charges already committed bytes in a post-I/O path.
///
/// This helper is intentionally separate from `quota_try_reserve` to avoid
/// mixing reserve and post-charge on a single I/O event.
#[inline]
pub(crate) fn quota_charge_post_write(&self, user_stats: &UserStats, bytes: u64) -> u64 {
Self::touch_user_stats(user_stats);
user_stats
.quota_used
.fetch_add(bytes, Ordering::Relaxed)
.saturating_add(bytes)
}
fn maybe_cleanup_user_stats(&self) {
const USER_STATS_CLEANUP_INTERVAL_SECS: u64 = 60;
const USER_STATS_IDLE_TTL_SECS: u64 = 24 * 60 * 60;
@@ -704,7 +816,8 @@ impl Stats {
}
pub fn increment_me_d2c_data_frames_total(&self) {
if self.telemetry_me_allows_normal() {
self.me_d2c_data_frames_total.fetch_add(1, Ordering::Relaxed);
self.me_d2c_data_frames_total
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_me_d2c_ack_frames_total(&self) {
@@ -1114,6 +1227,18 @@ impl Stats {
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn add_quota_write_fail_bytes_total(&self, bytes: u64) {
if self.telemetry_core_enabled() {
self.quota_write_fail_bytes_total
.fetch_add(bytes, Ordering::Relaxed);
}
}
pub fn increment_quota_write_fail_events_total(&self) {
if self.telemetry_core_enabled() {
self.quota_write_fail_events_total
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_me_endpoint_quarantine_total(&self) {
if self.telemetry_me_allows_normal() {
self.me_endpoint_quarantine_total
@@ -1588,7 +1713,8 @@ impl Stats {
self.me_d2c_batch_bytes_bucket_1k_4k.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_bytes_bucket_4k_16k(&self) -> u64 {
self.me_d2c_batch_bytes_bucket_4k_16k.load(Ordering::Relaxed)
self.me_d2c_batch_bytes_bucket_4k_16k
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_bytes_bucket_16k_64k(&self) -> u64 {
self.me_d2c_batch_bytes_bucket_16k_64k
@@ -1764,19 +1890,19 @@ impl Stats {
self.ip_reservation_rollback_quota_limit_total
.load(Ordering::Relaxed)
}
pub fn get_quota_write_fail_bytes_total(&self) -> u64 {
self.quota_write_fail_bytes_total.load(Ordering::Relaxed)
}
pub fn get_quota_write_fail_events_total(&self) -> u64 {
self.quota_write_fail_events_total.load(Ordering::Relaxed)
}
pub fn increment_user_connects(&self, user: &str) {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.connects.fetch_add(1, Ordering::Relaxed);
return;
}
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
let stats = self.get_or_create_user_stats_handle(user);
Self::touch_user_stats(stats.as_ref());
stats.connects.fetch_add(1, Ordering::Relaxed);
}
@@ -1784,14 +1910,8 @@ impl Stats {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.curr_connects.fetch_add(1, Ordering::Relaxed);
return;
}
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
let stats = self.get_or_create_user_stats_handle(user);
Self::touch_user_stats(stats.as_ref());
stats.curr_connects.fetch_add(1, Ordering::Relaxed);
}
@@ -1800,9 +1920,8 @@ impl Stats {
return true;
}
self.maybe_cleanup_user_stats();
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
let stats = self.get_or_create_user_stats_handle(user);
Self::touch_user_stats(stats.as_ref());
let counter = &stats.curr_connects;
let mut current = counter.load(Ordering::Relaxed);
@@ -1827,7 +1946,7 @@ impl Stats {
pub fn decrement_user_curr_connects(&self, user: &str) {
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
Self::touch_user_stats(stats.value().as_ref());
let counter = &stats.curr_connects;
let mut current = counter.load(Ordering::Relaxed);
loop {
@@ -1858,60 +1977,32 @@ impl Stats {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
return;
}
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
let stats = self.get_or_create_user_stats_handle(user);
self.add_user_octets_from_handle(stats.as_ref(), bytes);
}
pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
return;
}
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
let stats = self.get_or_create_user_stats_handle(user);
self.add_user_octets_to_handle(stats.as_ref(), bytes);
}
pub fn increment_user_msgs_from(&self, user: &str) {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
return;
}
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
let stats = self.get_or_create_user_stats_handle(user);
self.increment_user_msgs_from_handle(stats.as_ref());
}
pub fn increment_user_msgs_to(&self, user: &str) {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
return;
}
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
let stats = self.get_or_create_user_stats_handle(user);
self.increment_user_msgs_to_handle(stats.as_ref());
}
pub fn get_user_total_octets(&self, user: &str) -> u64 {
@@ -1924,6 +2015,13 @@ impl Stats {
.unwrap_or(0)
}
pub fn get_user_quota_used(&self, user: &str) -> u64 {
self.user_stats
.get(user)
.map(|s| s.quota_used.load(Ordering::Relaxed))
.unwrap_or(0)
}
pub fn get_handshake_timeouts(&self) -> u64 {
self.handshake_timeouts.load(Ordering::Relaxed)
}
@@ -1989,7 +2087,7 @@ impl Stats {
.load(Ordering::Relaxed)
}
pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, UserStats> {
pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, Arc<UserStats>> {
self.user_stats.iter()
}
@@ -2137,6 +2235,22 @@ impl ReplayChecker {
found
}
fn check_only_internal(
&self,
data: &[u8],
shards: &[Mutex<ReplayShard>],
window: Duration,
) -> bool {
self.checks.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let mut shard = shards[idx].lock();
let found = shard.check(data, Instant::now(), window);
if found {
self.hits.fetch_add(1, Ordering::Relaxed);
}
found
}
fn add_only(&self, data: &[u8], shards: &[Mutex<ReplayShard>], window: Duration) {
self.additions.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
@@ -2160,7 +2274,7 @@ impl ReplayChecker {
self.add_only(data, &self.handshake_shards, self.window)
}
pub fn check_tls_digest(&self, data: &[u8]) -> bool {
self.check_and_add_tls_digest(data)
self.check_only_internal(data, &self.tls_shards, self.tls_window)
}
pub fn add_tls_digest(&self, data: &[u8]) {
self.add_only(data, &self.tls_shards, self.tls_window)
@@ -2264,6 +2378,7 @@ mod tests {
use super::*;
use crate::config::MeTelemetryLevel;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
#[test]
fn test_stats_shared_counters() {
@@ -2431,6 +2546,137 @@ mod tests {
}
assert_eq!(checker.stats().total_entries, 500);
}
#[test]
fn test_quota_reserve_under_contention_hits_limit_exactly() {
let user_stats = Arc::new(UserStats::default());
let successes = Arc::new(AtomicU64::new(0));
let limit = 8_192u64;
let mut workers = Vec::new();
for _ in 0..8 {
let user_stats = user_stats.clone();
let successes = successes.clone();
workers.push(std::thread::spawn(move || {
loop {
match user_stats.quota_try_reserve(1, limit) {
Ok(_) => {
successes.fetch_add(1, Ordering::Relaxed);
}
Err(QuotaReserveError::Contended) => {
std::hint::spin_loop();
}
Err(QuotaReserveError::LimitExceeded) => {
break;
}
}
}
}));
}
for worker in workers {
worker.join().expect("worker thread must finish");
}
assert_eq!(
successes.load(Ordering::Relaxed),
limit,
"successful reservations must stop exactly at limit"
);
assert_eq!(user_stats.quota_used(), limit);
}
#[test]
fn test_quota_reserve_200x_1k_reaches_100k_without_overshoot() {
let user_stats = Arc::new(UserStats::default());
let successes = Arc::new(AtomicU64::new(0));
let failures = Arc::new(AtomicU64::new(0));
let attempts = 200usize;
let reserve_bytes = 1_024u64;
let limit = 100 * 1_024u64;
let mut workers = Vec::with_capacity(attempts);
for _ in 0..attempts {
let user_stats = user_stats.clone();
let successes = successes.clone();
let failures = failures.clone();
workers.push(std::thread::spawn(move || {
loop {
match user_stats.quota_try_reserve(reserve_bytes, limit) {
Ok(_) => {
successes.fetch_add(1, Ordering::Relaxed);
return;
}
Err(QuotaReserveError::LimitExceeded) => {
failures.fetch_add(1, Ordering::Relaxed);
return;
}
Err(QuotaReserveError::Contended) => {
std::hint::spin_loop();
}
}
}
}));
}
for worker in workers {
worker.join().expect("reservation worker must finish");
}
assert_eq!(
successes.load(Ordering::Relaxed),
100,
"exactly 100 reservations of 1 KiB must fit into a 100 KiB quota"
);
assert_eq!(
failures.load(Ordering::Relaxed),
100,
"remaining workers must fail once quota is fully reserved"
);
assert_eq!(user_stats.quota_used(), limit);
}
#[test]
fn test_quota_used_is_authoritative_and_independent_from_octets_telemetry() {
let stats = Stats::new();
let user = "quota-authoritative-user";
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.add_user_octets_to_handle(&user_stats, 5);
assert_eq!(stats.get_user_total_octets(user), 5);
assert_eq!(stats.get_user_quota_used(user), 0);
stats.quota_charge_post_write(&user_stats, 7);
assert_eq!(stats.get_user_total_octets(user), 5);
assert_eq!(stats.get_user_quota_used(user), 7);
}
#[test]
fn test_cached_handle_survives_map_cleanup_until_last_drop() {
let stats = Stats::new();
let user = "quota-handle-lifetime-user";
let user_stats = stats.get_or_create_user_stats_handle(user);
let weak = Arc::downgrade(&user_stats);
stats.user_stats.remove(user);
assert!(
stats.user_stats.get(user).is_none(),
"map cleanup should remove idle entry"
);
assert!(
weak.upgrade().is_some(),
"cached handle must keep user stats object alive after map removal"
);
stats.quota_charge_post_write(user_stats.as_ref(), 3);
assert_eq!(user_stats.quota_used(), 3);
drop(user_stats);
assert!(
weak.upgrade().is_none(),
"user stats object must be dropped after the last cached handle is released"
);
}
}
#[cfg(test)]
@@ -14,7 +14,10 @@ fn padding_rounding_equivalent_for_extensive_safe_domain() {
let old = old_padding_round_up_to_4(len).expect("old expression must be safe");
let new = new_padding_round_up_to_4(len).expect("new expression must be safe");
assert_eq!(old, new, "mismatch for len={len}");
assert!(new >= len, "rounded length must not shrink: len={len}, out={new}");
assert!(
new >= len,
"rounded length must not shrink: len={len}, out={new}"
);
assert_eq!(new % 4, 0, "rounded length must stay 4-byte aligned");
}
}

Some files were not shown because too many files have changed in this diff Show More