Compare commits

...

41 Commits

Author SHA1 Message Date
Roman Martynov 6377a4ba18
Merge 4d83d02a8f into ec7e808daf 2026-03-24 09:09:02 +01:00
Alexey ec7e808daf
Update release.yml 2026-03-24 11:05:50 +03:00
Alexey e4b7e23e76
New TLS-Fetcher + TLS SNI Validator + Upstream-driver getProxySecret/Config + Workflow Tunings + Redesign Quotas on Atomics + Tests Swap: merge pull request #569 from telemt/flow
New TLS-Fetcher + TLS SNI Validator + Upstream-driver getProxySecret/Config + Workflow Tunings + Redesign Quotas on Atomics + Tests Swap
2026-03-24 10:56:15 +03:00
Alexey 8b92b80b4a
Rustks CryptoProvider fixes + Rustfmt 2026-03-24 10:33:06 +03:00
Alexey f7868aa00f
Advanced TLS Fetcher
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-24 09:58:24 +03:00
Alexey 655a08fa5c
TLS Fetcher fixes 2026-03-23 23:12:50 +03:00
Alexey 8bc432db49
Rustfmt 2026-03-23 23:00:46 +03:00
Alexey a40d6929e5
Upstream-driver getProxyConfig and getProxyConfig
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 22:41:17 +03:00
Alexey 8db566dbe9
TLS Validator
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 21:58:39 +03:00
Alexey bb71de0230
Missing proxy_protocol_trusted_cidrs as trust-
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 20:54:58 +03:00
Alexey 62a258f8e3
Update test.yml 2026-03-23 20:49:17 +03:00
Alexey c868eaae74
Update test.yml 2026-03-23 20:36:25 +03:00
Alexey 8e1860f912
Update test.yml 2026-03-23 20:34:59 +03:00
Alexey 814bef9d99
Rustfmt 2026-03-23 20:32:55 +03:00
Alexey 3ceda15073
Update relay_quota_model_adversarial_tests.rs
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 20:18:18 +03:00
Alexey a3a6ea2880
Update relay_quota_overflow_regression_tests.rs
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 20:06:11 +03:00
Alexey 24156b5067
Workflow for Docker and correct binary naming
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 17:42:18 +03:00
Alexey a1dfa5b11d
Merge branch 'flow' of https://github.com/telemt/telemt into flow 2026-03-23 17:05:26 +03:00
Alexey 800356c751
Rewiring tests
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 17:04:47 +03:00
Alexey 1546b012a6
Merge pull request #568 from avbor/main
DOCS: Update VPS_DOUBLE_HOP.*.md - AmneziaWG 2.0
2026-03-23 16:49:57 +03:00
Alexey e6b77af931
Workflows Swap
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 16:49:23 +03:00
Alexey 8cfaab9320
Fixes in tests
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 16:39:49 +03:00
Alexey 2d69b9d0ae
New wave of tests
Co-Authored-By: brekotis <93345790+brekotis@users.noreply.github.com>
2026-03-23 16:39:23 +03:00
Alexander 41c2b4de65
Update VPS_DOUBLE_HOP.en.md
Added S3-S4 parameters for AWG and update AWG generator.
2026-03-23 16:30:37 +03:00
Alexander 0a5e8a09fd
Update VPS_DOUBLE_HOP.ru.md
Added S3-S4 parameters for AWG and update AWG generator.
2026-03-23 16:29:08 +03:00
Alexey 2f9fddfa6f
Old Test Deletion 2026-03-23 16:21:53 +03:00
Alexey 6f4356f72a
Redesign Quotas on Atomics 2026-03-23 15:53:44 +03:00
Alexey 0c3c9009a9
Merge pull request #538 from DavidOsipov/flow
Cross-mode Quota Locks, Masking Prefetch & Tiny-Frame Debt Protection
2026-03-23 11:35:57 +03:00
Alexey 0475844701
Merge branch 'flow' into flow 2026-03-23 11:35:44 +03:00
David Osipov 1abf9bd05c
Refactor CI workflows: rename build job and streamline stress testing setup 2026-03-23 12:27:57 +04:00
David Osipov 6f17d4d231
Add comprehensive security tests for quota management and relay functionality
- Introduced `relay_dual_lock_race_harness_security_tests.rs` to validate user liveness during lock hold and release cycles.
- Added `relay_quota_extended_attack_surface_security_tests.rs` to cover various quota scenarios including positive, negative, edge cases, and adversarial conditions.
- Implemented `relay_quota_lock_eviction_lifecycle_tdd_tests.rs` to ensure proper eviction of stale entries and lifecycle management of quota locks.
- Created `relay_quota_lock_eviction_stress_security_tests.rs` to stress test the eviction mechanism under high churn conditions.
- Enhanced `relay_quota_lock_pressure_adversarial_tests.rs` to verify reclaiming of unreferenced entries after explicit eviction.
- Developed `relay_quota_retry_allocation_latency_security_tests.rs` to benchmark and validate latency and allocation behavior under contention.
2026-03-23 12:04:41 +04:00
David Osipov 91be148b72
Security hardening, concurrency fixes, and expanded test coverage
This commit introduces a comprehensive set of improvements to enhance
the security, reliability, and configurability of the proxy server,
specifically targeting adversarial resilience and high-load concurrency.

Security & Cryptography:
- Zeroize MTProto cryptographic key material (`dec_key`, `enc_key`)
  immediately after use to prevent memory leakage on early returns.
- Move TLS handshake replay tracking after full policy/ALPN validation
  to prevent cache poisoning by unauthenticated probes.
- Add `proxy_protocol_trusted_cidrs` configuration to restrict PROXY
  protocol headers to trusted networks, rejecting spoofed IPs.

Adversarial Resilience & DoS Mitigation:
- Implement "Tiny Frame Debt" tracking in the middle-relay to prevent
  CPU exhaustion from malicious 0-byte or 1-byte frame floods.
- Add `mask_relay_max_bytes` to strictly bound unauthenticated fallback
  connections, preventing the proxy from being abused as an open relay.
- Add a 5ms prefetch window (`mask_classifier_prefetch_timeout_ms`) to
  correctly assemble and classify fragmented HTTP/1.1 and HTTP/2 probes
  (e.g., `PRI * HTTP/2.0`) before routing them to masking heuristics.
- Prevent recursive masking loops (FD exhaustion) by verifying the mask
  target is not the proxy's own listener via local interface enumeration.

Concurrency & Reliability:
- Eliminate executor waker storms during quota lock contention by replacing
  the spin-waker task with inline `Sleep` and exponential backoff.
- Roll back user quota reservations (`rollback_me2c_quota_reservation`)
  if a network write fails, preventing Head-of-Line (HoL) blocking from
  permanently burning data quotas.
- Recover gracefully from idle-registry `Mutex` poisoning instead of
  panicking, ensuring isolated thread failures do not break the proxy.
- Fix `auth_probe_scan_start_offset` modulo logic to ensure bounds safety.

Testing:
- Add extensive adversarial, timing, fuzzing, and invariant test suites
  for both the client and handshake modules.
2026-03-22 23:09:49 +04:00
Alexander e46d2cfc52
Update VPS_DOUBLE_HOP.ru.md
Fix typo
2026-03-22 21:59:20 +03:00
David Osipov 6fc188f0c4
Update src/proxy/tests/handshake_more_clever_tests.rs
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-22 17:08:23 +04:00
David Osipov 5c9fea5850
Update src/proxy/tests/client_security_tests.rs
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-22 17:08:16 +04:00
Alexey 3011a9ef6d
Merge branch 'flow' into flow 2026-03-22 15:50:21 +03:00
David Osipov ead23608f0
Add stress and manual benchmark tests for handshake protocols
- Introduced `handshake_real_bug_stress_tests.rs` to validate TLS and MTProto handshake behaviors under various conditions, including ALPN rejection and session ID handling.
- Implemented tests to ensure replay cache integrity and proper handling of malicious input without panicking.
- Added `handshake_timing_manual_bench_tests.rs` for performance benchmarking of user authentication paths, comparing preferred user handling against full user scans in both MTProto and TLS contexts.
- Included timing-sensitive tests to measure the impact of SNI on handshake performance.
2026-03-22 15:39:57 +04:00
sintanial 4d83d02a8f
Apply [timeouts] tg_connect to upstream DC TCP connect attempts
Wire config.timeouts.tg_connect into UpstreamManager; per-attempt timeout uses
the same .max(1) pattern as connect_budget_ms.

Reject timeouts.tg_connect = 0 at config load (consistent with
general.upstream_connect_budget_ms and related checks). Default when the key
is omitted remains default_connect_timeout() via serde.

Fixes telemt/telemt#439
2026-03-21 16:26:51 +03:00
sintanial fea8bc63fd
Merge branch 'main' of https://github.com/telemt/telemt 2026-03-20 23:27:02 +03:00
sintanial d8f7173f15 Merge branch 'main' of https://github.com/telemt/telemt 2026-03-01 15:18:47 +03:00
sintanial b23d433e19 Merge branch 'main' of https://github.com/telemt/telemt 2026-03-01 13:48:59 +03:00
108 changed files with 12977 additions and 6913 deletions

39
.github/workflows/build.yml vendored Normal file
View File

@ -0,0 +1,39 @@
name: Build
on:
push:
branches: [ "*" ]
pull_request:
branches: [ "*" ]
env:
CARGO_TERM_COLOR: always
jobs:
build:
name: Build
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install latest stable Rust toolchain
uses: dtolnay/rust-toolchain@stable
- name: Cache cargo registry & build artifacts
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-
- name: Build Release
run: cargo build --release --verbose

View File

@ -26,6 +26,9 @@ jobs:
name: GNU ${{ matrix.target }}
runs-on: ubuntu-latest
container:
image: rust:slim-bookworm
strategy:
fail-fast: false
matrix:
@ -47,8 +50,8 @@ jobs:
- name: Install deps
run: |
sudo apt-get update
sudo apt-get install -y \
apt-get update
apt-get install -y \
build-essential \
clang \
lld \
@ -69,14 +72,10 @@ jobs:
if [ "${{ matrix.target }}" = "aarch64-unknown-linux-gnu" ]; then
export CC=aarch64-linux-gnu-gcc
export CXX=aarch64-linux-gnu-g++
export CC_aarch64_unknown_linux_gnu=aarch64-linux-gnu-gcc
export CXX_aarch64_unknown_linux_gnu=aarch64-linux-gnu-g++
export RUSTFLAGS="-C linker=aarch64-linux-gnu-gcc"
else
export CC=clang
export CXX=clang++
export CC_x86_64_unknown_linux_gnu=clang
export CXX_x86_64_unknown_linux_gnu=clang++
export RUSTFLAGS="-C linker=clang -C link-arg=-fuse-ld=lld"
fi
@ -85,20 +84,19 @@ jobs:
- name: Package
run: |
mkdir -p dist
BIN=target/${{ matrix.target }}/release/${{ env.BINARY_NAME }}
cp "$BIN" dist/${{ env.BINARY_NAME }}-${{ matrix.target }}
cp target/${{ matrix.target }}/release/${{ env.BINARY_NAME }} dist/telemt
cd dist
tar -czf ${{ matrix.asset }}.tar.gz ${{ env.BINARY_NAME }}-${{ matrix.target }}
tar -czf ${{ matrix.asset }}.tar.gz \
--owner=0 --group=0 --numeric-owner \
telemt
sha256sum ${{ matrix.asset }}.tar.gz > ${{ matrix.asset }}.sha256
- uses: actions/upload-artifact@v4
with:
name: ${{ matrix.asset }}
path: |
dist/${{ matrix.asset }}.tar.gz
dist/${{ matrix.asset }}.sha256
path: dist/*
# ==========================
# MUSL
@ -146,9 +144,9 @@ jobs:
URL="https://github.com/telemt/telemt/releases/download/toolchains/$ARCHIVE"
if [ -x "$TOOLCHAIN_DIR/bin/aarch64-linux-musl-gcc" ]; then
echo "✅ MUSL toolchain already installed"
echo "✅ MUSL toolchain cached"
else
echo "⬇️ Downloading musl toolchain from Telemt GitHub Releases..."
echo "⬇️ Downloading MUSL toolchain..."
curl -fL \
--retry 5 \
@ -191,69 +189,19 @@ jobs:
- name: Package
run: |
mkdir -p dist
BIN=target/${{ matrix.target }}/release/${{ env.BINARY_NAME }}
cp "$BIN" dist/${{ env.BINARY_NAME }}-${{ matrix.target }}
cp target/${{ matrix.target }}/release/${{ env.BINARY_NAME }} dist/telemt
cd dist
tar -czf ${{ matrix.asset }}.tar.gz ${{ env.BINARY_NAME }}-${{ matrix.target }}
tar -czf ${{ matrix.asset }}.tar.gz \
--owner=0 --group=0 --numeric-owner \
telemt
sha256sum ${{ matrix.asset }}.tar.gz > ${{ matrix.asset }}.sha256
- uses: actions/upload-artifact@v4
with:
name: ${{ matrix.asset }}
path: |
dist/${{ matrix.asset }}.tar.gz
dist/${{ matrix.asset }}.sha256
# ==========================
# Docker
# ==========================
docker:
name: Docker
runs-on: ubuntu-latest
needs: [build-gnu, build-musl]
continue-on-error: true
steps:
- uses: actions/checkout@v4
- uses: actions/download-artifact@v4
with:
path: artifacts
- name: Extract binaries
run: |
mkdir dist
find artifacts -name "*.tar.gz" -exec tar -xzf {} -C dist \;
cp dist/telemt-x86_64-unknown-linux-musl dist/telemt || true
- uses: docker/setup-qemu-action@v3
- uses: docker/setup-buildx-action@v3
- name: Login to GHCR
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract version
id: vars
run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
- name: Build & Push
uses: docker/build-push-action@v6
with:
context: .
push: true
platforms: linux/amd64,linux/arm64
tags: |
ghcr.io/${{ github.repository }}:${{ steps.vars.outputs.VERSION }}
ghcr.io/${{ github.repository }}:latest
build-args: |
BINARY=dist/telemt
path: dist/*
# ==========================
# Release
@ -271,7 +219,7 @@ jobs:
with:
path: artifacts
- name: Flatten artifacts
- name: Flatten
run: |
mkdir dist
find artifacts -type f -exec cp {} dist/ \;
@ -281,5 +229,61 @@ jobs:
with:
files: dist/*
generate_release_notes: true
draft: false
prerelease: ${{ contains(github.ref, '-rc') || contains(github.ref, '-beta') || contains(github.ref, '-alpha') }}
prerelease: ${{ contains(github.ref, '-') }}
# ==========================
# Docker (FROM RELEASE)
# ==========================
docker:
name: Docker (from release)
runs-on: ubuntu-latest
needs: release
permissions:
contents: read
packages: write
steps:
- uses: actions/checkout@v4
- name: Install gh
run: apt-get update && apt-get install -y gh
- name: Extract version
id: vars
run: echo "VERSION=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
- name: Download binary
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
mkdir dist
gh release download ${{ steps.vars.outputs.VERSION }} \
--repo ${{ github.repository }} \
--pattern "telemt-x86_64-linux-musl.tar.gz" \
--dir dist
tar -xzf dist/telemt-x86_64-linux-musl.tar.gz -C dist
chmod +x dist/telemt
- uses: docker/setup-qemu-action@v3
- uses: docker/setup-buildx-action@v3
- uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build & Push
uses: docker/build-push-action@v6
with:
context: .
push: true
platforms: linux/amd64,linux/arm64
tags: |
ghcr.io/${{ github.repository }}:${{ steps.vars.outputs.VERSION }}
ghcr.io/${{ github.repository }}:latest
build-args: |
BINARY=dist/telemt

View File

@ -1,66 +0,0 @@
name: Rust
on:
push:
branches: [ "*" ]
pull_request:
branches: [ "*" ]
env:
CARGO_TERM_COLOR: always
jobs:
build:
name: Build
runs-on: ubuntu-latest
permissions:
contents: read
actions: write
checks: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install latest stable Rust toolchain
uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt, clippy
- name: Cache cargo registry & build artifacts
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-
- name: Build Release
run: cargo build --release --verbose
- name: Run tests
run: cargo test --verbose
- name: Stress quota-lock suites (PR only)
if: github.event_name == 'pull_request'
env:
RUST_TEST_THREADS: 16
run: |
set -euo pipefail
for i in $(seq 1 12); do
echo "[quota-lock-stress] iteration ${i}/12"
cargo test quota_lock_ --bin telemt -- --nocapture --test-threads 16
cargo test relay_quota_wake --bin telemt -- --nocapture --test-threads 16
done
# clippy dont fail on warnings because of active development of telemt
# and many warnings
- name: Run clippy
run: cargo clippy -- --cap-lints warn
- name: Check for unused dependencies
run: cargo udeps || true

127
.github/workflows/test.yml vendored Normal file
View File

@ -0,0 +1,127 @@
name: Check
on:
push:
branches: [ "*" ]
pull_request:
branches: [ "*" ]
env:
CARGO_TERM_COLOR: always
concurrency:
group: test-${{ github.ref }}
cancel-in-progress: true
jobs:
# ==========================
# Formatting
# ==========================
fmt:
name: Fmt
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt
- run: cargo fmt -- --check
# ==========================
# Tests
# ==========================
test:
name: Test
runs-on: ubuntu-latest
permissions:
contents: read
actions: write
checks: write
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- name: Cache cargo
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-
- run: cargo test --verbose
# ==========================
# Clippy
# ==========================
clippy:
name: Clippy
runs-on: ubuntu-latest
permissions:
contents: read
checks: write
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
with:
components: clippy
- name: Cache cargo
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-
- run: cargo clippy -- --cap-lints warn
# ==========================
# Udeps
# ==========================
udeps:
name: Udeps
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- name: Cache cargo
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
restore-keys: |
${{ runner.os }}-cargo-
- name: Install cargo-udeps
run: cargo install cargo-udeps || true
# тоже не валит билд
- run: cargo udeps || true

38
Cargo.lock generated
View File

@ -1454,9 +1454,9 @@ dependencies = [
[[package]]
name = "iri-string"
version = "0.7.10"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a"
checksum = "d8e7418f59cc01c88316161279a7f665217ae316b388e58a0d10e29f54f1e5eb"
dependencies = [
"memchr",
"serde",
@ -1486,7 +1486,7 @@ dependencies = [
"cesu8",
"cfg-if",
"combine",
"jni-sys",
"jni-sys 0.3.1",
"log",
"thiserror 1.0.69",
"walkdir",
@ -1495,9 +1495,31 @@ dependencies = [
[[package]]
name = "jni-sys"
version = "0.3.0"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130"
checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258"
dependencies = [
"jni-sys 0.4.1",
]
[[package]]
name = "jni-sys"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2"
dependencies = [
"jni-sys-macros",
]
[[package]]
name = "jni-sys-macros"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264"
dependencies = [
"quote",
"syn",
]
[[package]]
name = "jobserver"
@ -1659,9 +1681,9 @@ dependencies = [
[[package]]
name = "moka"
version = "0.12.14"
version = "0.12.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85f8024e1c8e71c778968af91d43700ce1d11b219d127d79fb2934153b82b42b"
checksum = "957228ad12042ee839f93c8f257b62b4c0ab5eaae1d4fa60de53b27c9d7c5046"
dependencies = [
"crossbeam-channel",
"crossbeam-epoch",
@ -2771,7 +2793,7 @@ checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417"
[[package]]
name = "telemt"
version = "3.3.29"
version = "3.3.30"
dependencies = [
"aes",
"anyhow",

View File

@ -1,8 +1,11 @@
[package]
name = "telemt"
version = "3.3.29"
version = "3.3.30"
edition = "2024"
[features]
redteam_offline_expected_fail = []
[dependencies]
# C
libc = "0.2"

View File

@ -1,29 +1,9 @@
# syntax=docker/dockerfile:1
# ==========================
# Stage 1: Build
# ==========================
FROM rust:1.88-slim-bookworm AS builder
RUN apt-get update && apt-get install -y --no-install-recommends \
pkg-config \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /build
# Depcache
COPY Cargo.toml Cargo.lock* ./
RUN mkdir src && echo 'fn main() {}' > src/main.rs && \
cargo build --release 2>/dev/null || true && \
rm -rf src
# Build
COPY . .
RUN cargo build --release && strip target/release/telemt
ARG BINARY
# ==========================
# Stage 2: Compress (strip + UPX)
# Stage: minimal
# ==========================
FROM debian:12-slim AS minimal
@ -33,7 +13,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
&& rm -rf /var/lib/apt/lists/* \
\
# install UPX from Telemt releases
&& curl -fL \
--retry 5 \
--retry-delay 3 \
@ -46,15 +25,15 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
&& chmod +x /usr/local/bin/upx \
&& rm -rf /tmp/upx*
COPY --from=builder /build/target/release/telemt /telemt
COPY ${BINARY} /telemt
RUN strip /telemt || true
RUN upx --best --lzma /telemt || true
# ==========================
# Stage 3: Debug base
# Debug image
# ==========================
FROM debian:12-slim AS debug-base
FROM debian:12-slim AS debug
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
@ -64,48 +43,29 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
busybox \
&& rm -rf /var/lib/apt/lists/*
# ==========================
# Stage 4: Debug image
# ==========================
FROM debug-base AS debug
WORKDIR /app
COPY --from=minimal /telemt /app/telemt
COPY config.toml /app/config.toml
USER root
EXPOSE 443
EXPOSE 9090
EXPOSE 9091
EXPOSE 443 9090 9091
ENTRYPOINT ["/app/telemt"]
CMD ["config.toml"]
# ==========================
# Stage 5: Production (distroless)
# Production (REAL distroless)
# ==========================
FROM gcr.io/distroless/base-debian12 AS prod
FROM gcr.io/distroless/static-debian12 AS prod
WORKDIR /app
COPY --from=minimal /telemt /app/telemt
COPY config.toml /app/config.toml
# TLS + timezone + shell
COPY --from=debug-base /etc/ssl/certs /etc/ssl/certs
COPY --from=debug-base /usr/share/zoneinfo /usr/share/zoneinfo
COPY --from=debug-base /bin/busybox /bin/busybox
RUN ["/bin/busybox", "--install", "-s", "/bin"]
# distroless user
USER nonroot:nonroot
EXPOSE 443
EXPOSE 9090
EXPOSE 9091
EXPOSE 443 9090 9091
ENTRYPOINT ["/app/telemt"]
CMD ["config.toml"]

View File

@ -202,12 +202,15 @@ This document lists all configuration keys accepted by `config.toml`.
| listen_tcp | `bool \| null` | `null` (auto) | — | Explicit TCP listener enable/disable override. |
| proxy_protocol | `bool` | `false` | — | Enables HAProxy PROXY protocol parsing on incoming client connections. |
| proxy_protocol_header_timeout_ms | `u64` | `500` | Must be `> 0`. | Timeout for PROXY protocol header read/parse (ms). |
| proxy_protocol_trusted_cidrs | `IpNetwork[]` | `[]` | — | When non-empty, only connections from these proxy source CIDRs are allowed to provide PROXY protocol headers. If empty, PROXY headers are rejected by default (security hardening). |
| metrics_port | `u16 \| null` | `null` | — | Metrics endpoint port (enables metrics listener). |
| metrics_listen | `String \| null` | `null` | — | Full metrics bind address (`IP:PORT`), overrides `metrics_port`. |
| metrics_whitelist | `IpNetwork[]` | `["127.0.0.1/32", "::1/128"]` | — | CIDR whitelist for metrics endpoint access. |
| max_connections | `u32` | `10000` | — | Max concurrent client connections (`0` = unlimited). |
| accept_permit_timeout_ms | `u64` | `250` | `0..=60000`. | Maximum wait for acquiring a connection-slot permit before the accepted connection is dropped (`0` keeps legacy unbounded wait). |
Note: When `server.proxy_protocol` is enabled, incoming PROXY protocol headers are parsed from the first bytes of the connection and the client source address is replaced with `src_addr` from the header. For security, the peer source IP (the direct connection address) is verified against `server.proxy_protocol_trusted_cidrs`; if this list is empty, PROXY headers are rejected and the connection is considered untrusted.
## [server.api]
| Parameter | Type | Default | Constraints / validation | Description |
@ -271,6 +274,8 @@ This document lists all configuration keys accepted by `config.toml`.
| mask_shape_bucket_cap_bytes | `usize` | `4096` | Must be `>= mask_shape_bucket_floor_bytes`. | Maximum bucket size used by shape-channel hardening; traffic above cap is not padded further. |
| mask_shape_above_cap_blur | `bool` | `false` | Requires `mask_shape_hardening = true`; requires `mask_shape_above_cap_blur_max_bytes > 0`. | Adds bounded randomized tail bytes even when forwarded size already exceeds cap. |
| mask_shape_above_cap_blur_max_bytes | `usize` | `512` | Must be `<= 1048576`; must be `> 0` when `mask_shape_above_cap_blur = true`. | Maximum randomized extra bytes appended above cap. |
| mask_relay_max_bytes | `usize` | `5242880` | Must be `> 0`; must be `<= 67108864`. | Maximum relayed bytes per direction on unauthenticated masking fallback path. |
| mask_classifier_prefetch_timeout_ms | `u64` | `5` | Must be within `[5, 50]`. | Timeout budget (ms) for extending fragmented initial classifier window on masking fallback. |
| mask_timing_normalization_enabled | `bool` | `false` | Requires `mask_timing_normalization_floor_ms > 0`; requires `ceiling >= floor`. | Enables timing envelope normalization on masking outcomes. |
| mask_timing_normalization_floor_ms | `u64` | `0` | Must be `> 0` when timing normalization is enabled; must be `<= ceiling`. | Lower bound (ms) for masking outcome normalization target. |
| mask_timing_normalization_ceiling_ms | `u64` | `0` | Must be `>= floor`; must be `<= 60000`. | Upper bound (ms) for masking outcome normalization target. |

View File

@ -63,7 +63,7 @@ recommended range from 5 to 2147483647 inclusive
> [!IMPORTANT]
> It is recommended to use your own, unique values.\
> You can use the [generator](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/e8b269ff0089a27effd88f8d925179b78e5666c4/awg-gen.html) to select parameters.
> You can use the [generator](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/13f5517ca473b47c412b9a99407066de973732bd/awg-gen.html) to select parameters.
#### Server B Configuration (Netherlands):
@ -84,6 +84,8 @@ Jmin = 8
Jmax = 80
S1 = 29
S2 = 15
S3 = 18
S4 = 0
H1 = 2087563914
H2 = 188817757
H3 = 101784570
@ -121,6 +123,8 @@ Jmin = 8
Jmax = 80
S1 = 29
S2 = 15
S3 = 18
S4 = 0
H1 = 2087563914
H2 = 188817757
H3 = 101784570

View File

@ -44,7 +44,7 @@ awg genkey | tee private.key | awg pubkey > public.key
Параметры обфускации `S1`, `S2`, `H1`, `H2`, `H3`, `H4` должны быть строго идентичными на обоих серверах.\
Параметры `Jc`, `Jmin` и `Jmax` могут отличатся.\
Параметры `I1-I5` [(Custom Protocol Signature)](https://docs.amnezia.org/documentation/amnezia-wg/) нужно указывать на стороне _клиента_ (Сервер **А**).
Параметры `I1-I5` ([Custom Protocol Signature](https://docs.amnezia.org/documentation/amnezia-wg/)) нужно указывать на стороне _клиента_ (Сервер **А**).
Рекомендации по выбору значений:
```text
@ -62,7 +62,7 @@ H1/H2/H3/H4 — должны быть уникальны и отличаться
```
> [!IMPORTANT]
> Рекомендуется использовать собственные, уникальные значения.\
> Для выбора параметров можете воспользоваться [генератором](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/e8b269ff0089a27effd88f8d925179b78e5666c4/awg-gen.html).
> Для выбора параметров можете воспользоваться [генератором](https://htmlpreview.github.io/?https://gist.githubusercontent.com/avbor/955782b5c37b06240b243aa375baeac5/raw/13f5517ca473b47c412b9a99407066de973732bd/awg-gen.html).
#### Конфигурация Сервера B (_Нидерланды_):
@ -83,6 +83,8 @@ Jmin = 8
Jmax = 80
S1 = 29
S2 = 15
S3 = 18
S4 = 0
H1 = 2087563914
H2 = 188817757
H3 = 101784570
@ -121,6 +123,8 @@ Jmin = 8
Jmax = 80
S1 = 29
S2 = 15
S3 = 18
S4 = 0
H1 = 2087563914
H2 = 188817757
H3 = 101784570
@ -272,7 +276,7 @@ backend telemt_nodes
```
>[!WARNING]
>**Файл должен заканчиваться пустой строкой, иначе HAProxy не запуститься!**
>**Файл должен заканчиваться пустой строкой, иначе HAProxy не запустится!**
#### Разрешаем порт 443\tcp в фаерволе (если включен)
```bash

View File

@ -71,6 +71,22 @@ pub(crate) fn default_tls_fetch_scope() -> String {
String::new()
}
pub(crate) fn default_tls_fetch_attempt_timeout_ms() -> u64 {
5_000
}
pub(crate) fn default_tls_fetch_total_budget_ms() -> u64 {
15_000
}
pub(crate) fn default_tls_fetch_strict_route() -> bool {
true
}
pub(crate) fn default_tls_fetch_profile_cache_ttl_secs() -> u64 {
600
}
pub(crate) fn default_mask_port() -> u16 {
443
}
@ -185,6 +201,10 @@ pub(crate) fn default_proxy_protocol_header_timeout_ms() -> u64 {
500
}
pub(crate) fn default_proxy_protocol_trusted_cidrs() -> Vec<IpNetwork> {
vec!["0.0.0.0/0".parse().unwrap(), "::/0".parse().unwrap()]
}
pub(crate) fn default_server_max_connections() -> u32 {
10_000
}
@ -553,6 +573,20 @@ pub(crate) fn default_mask_shape_above_cap_blur_max_bytes() -> usize {
512
}
#[cfg(not(test))]
pub(crate) fn default_mask_relay_max_bytes() -> usize {
5 * 1024 * 1024
}
#[cfg(test)]
pub(crate) fn default_mask_relay_max_bytes() -> usize {
32 * 1024
}
pub(crate) fn default_mask_classifier_prefetch_timeout_ms() -> u64 {
5
}
pub(crate) fn default_mask_timing_normalization_enabled() -> bool {
false
}

View File

@ -228,7 +228,9 @@ impl HotFields {
me_d2c_flush_batch_max_delay_us: cfg.general.me_d2c_flush_batch_max_delay_us,
me_d2c_ack_flush_immediate: cfg.general.me_d2c_ack_flush_immediate,
me_quota_soft_overshoot_bytes: cfg.general.me_quota_soft_overshoot_bytes,
me_d2c_frame_buf_shrink_threshold_bytes: cfg.general.me_d2c_frame_buf_shrink_threshold_bytes,
me_d2c_frame_buf_shrink_threshold_bytes: cfg
.general
.me_d2c_frame_buf_shrink_threshold_bytes,
direct_relay_copy_buf_c2s_bytes: cfg.general.direct_relay_copy_buf_c2s_bytes,
direct_relay_copy_buf_s2c_bytes: cfg.general.direct_relay_copy_buf_s2c_bytes,
me_health_interval_ms_unhealthy: cfg.general.me_health_interval_ms_unhealthy,
@ -600,6 +602,9 @@ fn warn_non_hot_changes(old: &ProxyConfig, new: &ProxyConfig, non_hot_changed: b
|| old.censorship.mask_shape_above_cap_blur != new.censorship.mask_shape_above_cap_blur
|| old.censorship.mask_shape_above_cap_blur_max_bytes
!= new.censorship.mask_shape_above_cap_blur_max_bytes
|| old.censorship.mask_relay_max_bytes != new.censorship.mask_relay_max_bytes
|| old.censorship.mask_classifier_prefetch_timeout_ms
!= new.censorship.mask_classifier_prefetch_timeout_ms
|| old.censorship.mask_timing_normalization_enabled
!= new.censorship.mask_timing_normalization_enabled
|| old.censorship.mask_timing_normalization_floor_ms

View File

@ -1,6 +1,6 @@
#![allow(deprecated)]
use std::collections::{BTreeSet, HashMap};
use std::collections::{BTreeSet, HashMap, HashSet};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::net::{IpAddr, SocketAddr};
use std::path::{Path, PathBuf};
@ -346,6 +346,12 @@ impl ProxyConfig {
));
}
if config.timeouts.tg_connect == 0 {
return Err(ProxyError::Config(
"timeouts.tg_connect must be > 0".to_string(),
));
}
if config.general.upstream_unhealthy_fail_threshold == 0 {
return Err(ProxyError::Config(
"general.upstream_unhealthy_fail_threshold must be > 0".to_string(),
@ -430,6 +436,24 @@ impl ProxyConfig {
));
}
if config.censorship.mask_relay_max_bytes == 0 {
return Err(ProxyError::Config(
"censorship.mask_relay_max_bytes must be > 0".to_string(),
));
}
if config.censorship.mask_relay_max_bytes > 67_108_864 {
return Err(ProxyError::Config(
"censorship.mask_relay_max_bytes must be <= 67108864".to_string(),
));
}
if !(5..=50).contains(&config.censorship.mask_classifier_prefetch_timeout_ms) {
return Err(ProxyError::Config(
"censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]".to_string(),
));
}
if config.censorship.mask_timing_normalization_ceiling_ms
< config.censorship.mask_timing_normalization_floor_ms
{
@ -539,7 +563,9 @@ impl ProxyConfig {
));
}
if !(4096..=16 * 1024 * 1024).contains(&config.general.me_d2c_frame_buf_shrink_threshold_bytes) {
if !(4096..=16 * 1024 * 1024)
.contains(&config.general.me_d2c_frame_buf_shrink_threshold_bytes)
{
return Err(ProxyError::Config(
"general.me_d2c_frame_buf_shrink_threshold_bytes must be within [4096, 16777216]"
.to_string(),
@ -957,6 +983,28 @@ impl ProxyConfig {
// Normalize optional TLS fetch scope: whitespace-only values disable scoped routing.
config.censorship.tls_fetch_scope = config.censorship.tls_fetch_scope.trim().to_string();
if config.censorship.tls_fetch.profiles.is_empty() {
config.censorship.tls_fetch.profiles = TlsFetchConfig::default().profiles;
} else {
let mut seen = HashSet::new();
config
.censorship
.tls_fetch
.profiles
.retain(|profile| seen.insert(*profile));
}
if config.censorship.tls_fetch.attempt_timeout_ms == 0 {
return Err(ProxyError::Config(
"censorship.tls_fetch.attempt_timeout_ms must be > 0".to_string(),
));
}
if config.censorship.tls_fetch.total_budget_ms == 0 {
return Err(ProxyError::Config(
"censorship.tls_fetch.total_budget_ms must be > 0".to_string(),
));
}
// Merge primary + extra TLS domains, deduplicate (primary always first).
if !config.censorship.tls_domains.is_empty() {
let mut all = Vec::with_capacity(1 + config.censorship.tls_domains.len());
@ -1134,6 +1182,10 @@ mod load_security_tests;
#[path = "tests/load_mask_shape_security_tests.rs"]
mod load_mask_shape_security_tests;
#[cfg(test)]
#[path = "tests/load_mask_classifier_prefetch_timeout_security_tests.rs"]
mod load_mask_classifier_prefetch_timeout_security_tests;
#[cfg(test)]
mod tests {
use super::*;
@ -1239,6 +1291,11 @@ mod tests {
assert_eq!(cfg.general.update_every, default_update_every());
assert_eq!(cfg.server.listen_addr_ipv4, default_listen_addr_ipv4());
assert_eq!(cfg.server.listen_addr_ipv6, default_listen_addr_ipv6_opt());
assert_eq!(
cfg.server.proxy_protocol_trusted_cidrs,
default_proxy_protocol_trusted_cidrs()
);
assert_eq!(cfg.censorship.unknown_sni_action, UnknownSniAction::Drop);
assert_eq!(cfg.server.api.listen, default_api_listen());
assert_eq!(cfg.server.api.whitelist, default_api_whitelist());
assert_eq!(
@ -1371,6 +1428,14 @@ mod tests {
let server = ServerConfig::default();
assert_eq!(server.listen_addr_ipv6, Some(default_listen_addr_ipv6()));
assert_eq!(
server.proxy_protocol_trusted_cidrs,
default_proxy_protocol_trusted_cidrs()
);
assert_eq!(
AntiCensorshipConfig::default().unknown_sni_action,
UnknownSniAction::Drop
);
assert_eq!(server.api.listen, default_api_listen());
assert_eq!(server.api.whitelist, default_api_whitelist());
assert_eq!(
@ -1406,6 +1471,75 @@ mod tests {
assert_eq!(access.users, default_access_users());
}
#[test]
fn proxy_protocol_trusted_cidrs_missing_uses_trust_all_but_explicit_empty_stays_empty() {
let cfg_missing: ProxyConfig = toml::from_str(
r#"
[server]
[general]
[network]
[access]
"#,
)
.unwrap();
assert_eq!(
cfg_missing.server.proxy_protocol_trusted_cidrs,
default_proxy_protocol_trusted_cidrs()
);
let cfg_explicit_empty: ProxyConfig = toml::from_str(
r#"
[server]
proxy_protocol_trusted_cidrs = []
[general]
[network]
[access]
"#,
)
.unwrap();
assert!(
cfg_explicit_empty
.server
.proxy_protocol_trusted_cidrs
.is_empty()
);
}
#[test]
fn unknown_sni_action_parses_and_defaults_to_drop() {
let cfg_default: ProxyConfig = toml::from_str(
r#"
[server]
[general]
[network]
[access]
[censorship]
"#,
)
.unwrap();
assert_eq!(
cfg_default.censorship.unknown_sni_action,
UnknownSniAction::Drop
);
let cfg_mask: ProxyConfig = toml::from_str(
r#"
[server]
[general]
[network]
[access]
[censorship]
unknown_sni_action = "mask"
"#,
)
.unwrap();
assert_eq!(
cfg_mask.censorship.unknown_sni_action,
UnknownSniAction::Mask
);
}
#[test]
fn dc_overrides_allow_string_and_array() {
let toml = r#"
@ -1777,6 +1911,26 @@ mod tests {
let _ = std::fs::remove_file(path);
}
#[test]
fn tg_connect_zero_is_rejected() {
let toml = r#"
[timeouts]
tg_connect = 0
[censorship]
tls_domain = "example.com"
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_tg_connect_zero_test.toml");
std::fs::write(&path, toml).unwrap();
let err = ProxyConfig::load(&path).unwrap_err().to_string();
assert!(err.contains("timeouts.tg_connect must be > 0"));
let _ = std::fs::remove_file(path);
}
#[test]
fn rpc_proxy_req_every_out_of_range_is_rejected() {
let toml = r#"
@ -2353,6 +2507,94 @@ mod tests {
let _ = std::fs::remove_file(path);
}
#[test]
fn tls_fetch_defaults_are_applied() {
let toml = r#"
[censorship]
tls_domain = "example.com"
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_tls_fetch_defaults_test.toml");
std::fs::write(&path, toml).unwrap();
let cfg = ProxyConfig::load(&path).unwrap();
assert_eq!(
cfg.censorship.tls_fetch.profiles,
TlsFetchConfig::default().profiles
);
assert!(cfg.censorship.tls_fetch.strict_route);
assert_eq!(cfg.censorship.tls_fetch.attempt_timeout_ms, 5_000);
assert_eq!(cfg.censorship.tls_fetch.total_budget_ms, 15_000);
assert_eq!(cfg.censorship.tls_fetch.profile_cache_ttl_secs, 600);
let _ = std::fs::remove_file(path);
}
#[test]
fn tls_fetch_profiles_are_deduplicated_preserving_order() {
let toml = r#"
[censorship]
tls_domain = "example.com"
[censorship.tls_fetch]
profiles = ["compat_tls12", "modern_chrome_like", "compat_tls12", "legacy_minimal"]
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_tls_fetch_profiles_dedup_test.toml");
std::fs::write(&path, toml).unwrap();
let cfg = ProxyConfig::load(&path).unwrap();
assert_eq!(
cfg.censorship.tls_fetch.profiles,
vec![
TlsFetchProfile::CompatTls12,
TlsFetchProfile::ModernChromeLike,
TlsFetchProfile::LegacyMinimal
]
);
let _ = std::fs::remove_file(path);
}
#[test]
fn tls_fetch_attempt_timeout_zero_is_rejected() {
let toml = r#"
[censorship]
tls_domain = "example.com"
[censorship.tls_fetch]
attempt_timeout_ms = 0
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_tls_fetch_attempt_timeout_zero_test.toml");
std::fs::write(&path, toml).unwrap();
let err = ProxyConfig::load(&path).unwrap_err().to_string();
assert!(err.contains("censorship.tls_fetch.attempt_timeout_ms must be > 0"));
let _ = std::fs::remove_file(path);
}
#[test]
fn tls_fetch_total_budget_zero_is_rejected() {
let toml = r#"
[censorship]
tls_domain = "example.com"
[censorship.tls_fetch]
total_budget_ms = 0
[access.users]
user = "00000000000000000000000000000000"
"#;
let dir = std::env::temp_dir();
let path = dir.join("telemt_tls_fetch_total_budget_zero_test.toml");
std::fs::write(&path, toml).unwrap();
let err = ProxyConfig::load(&path).unwrap_err().to_string();
assert!(err.contains("censorship.tls_fetch.total_budget_ms must be > 0"));
let _ = std::fs::remove_file(path);
}
#[test]
fn invalid_ad_tag_is_disabled_during_load() {
let toml = r#"

View File

@ -0,0 +1,76 @@
use super::*;
use std::fs;
use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};
fn write_temp_config(contents: &str) -> PathBuf {
let nonce = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("system time must be after unix epoch")
.as_nanos();
let path = std::env::temp_dir().join(format!(
"telemt-load-mask-prefetch-timeout-security-{nonce}.toml"
));
fs::write(&path, contents).expect("temp config write must succeed");
path
}
fn remove_temp_config(path: &PathBuf) {
let _ = fs::remove_file(path);
}
#[test]
fn load_rejects_mask_classifier_prefetch_timeout_below_min_bound() {
let path = write_temp_config(
r#"
[censorship]
mask_classifier_prefetch_timeout_ms = 4
"#,
);
let err = ProxyConfig::load(&path)
.expect_err("prefetch timeout below minimum security bound must be rejected");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"),
"error must explain timeout bound invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_rejects_mask_classifier_prefetch_timeout_above_max_bound() {
let path = write_temp_config(
r#"
[censorship]
mask_classifier_prefetch_timeout_ms = 51
"#,
);
let err = ProxyConfig::load(&path)
.expect_err("prefetch timeout above max security bound must be rejected");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_classifier_prefetch_timeout_ms must be within [5, 50]"),
"error must explain timeout bound invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_accepts_mask_classifier_prefetch_timeout_within_bounds() {
let path = write_temp_config(
r#"
[censorship]
mask_classifier_prefetch_timeout_ms = 20
"#,
);
let cfg =
ProxyConfig::load(&path).expect("prefetch timeout within security bounds must be accepted");
assert_eq!(cfg.censorship.mask_classifier_prefetch_timeout_ms, 20);
remove_temp_config(&path);
}

View File

@ -236,3 +236,57 @@ mask_shape_above_cap_blur_max_bytes = 8
remove_temp_config(&path);
}
#[test]
fn load_rejects_zero_mask_relay_max_bytes() {
let path = write_temp_config(
r#"
[censorship]
mask_relay_max_bytes = 0
"#,
);
let err = ProxyConfig::load(&path).expect_err("mask_relay_max_bytes must be > 0");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_relay_max_bytes must be > 0"),
"error must explain non-zero relay cap invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_rejects_mask_relay_max_bytes_above_upper_bound() {
let path = write_temp_config(
r#"
[censorship]
mask_relay_max_bytes = 67108865
"#,
);
let err =
ProxyConfig::load(&path).expect_err("mask_relay_max_bytes above hard cap must be rejected");
let msg = err.to_string();
assert!(
msg.contains("censorship.mask_relay_max_bytes must be <= 67108864"),
"error must explain relay cap upper bound invariant, got: {msg}"
);
remove_temp_config(&path);
}
#[test]
fn load_accepts_valid_mask_relay_max_bytes() {
let path = write_temp_config(
r#"
[censorship]
mask_relay_max_bytes = 8388608
"#,
);
let cfg = ProxyConfig::load(&path).expect("valid mask_relay_max_bytes must be accepted");
assert_eq!(cfg.censorship.mask_relay_max_bytes, 8_388_608);
remove_temp_config(&path);
}

View File

@ -954,7 +954,8 @@ impl Default for GeneralConfig {
me_d2c_flush_batch_max_delay_us: default_me_d2c_flush_batch_max_delay_us(),
me_d2c_ack_flush_immediate: default_me_d2c_ack_flush_immediate(),
me_quota_soft_overshoot_bytes: default_me_quota_soft_overshoot_bytes(),
me_d2c_frame_buf_shrink_threshold_bytes: default_me_d2c_frame_buf_shrink_threshold_bytes(),
me_d2c_frame_buf_shrink_threshold_bytes:
default_me_d2c_frame_buf_shrink_threshold_bytes(),
direct_relay_copy_buf_c2s_bytes: default_direct_relay_copy_buf_c2s_bytes(),
direct_relay_copy_buf_s2c_bytes: default_direct_relay_copy_buf_s2c_bytes(),
me_warmup_stagger_enabled: default_true(),
@ -1239,9 +1240,10 @@ pub struct ServerConfig {
/// Trusted source CIDRs allowed to send incoming PROXY protocol headers.
///
/// When non-empty, connections from addresses outside this allowlist are
/// rejected before `src_addr` is applied.
#[serde(default)]
/// If this field is omitted in config, it defaults to trust-all CIDRs
/// (`0.0.0.0/0` and `::/0`). If it is explicitly set to an empty list,
/// all PROXY protocol headers are rejected.
#[serde(default = "default_proxy_protocol_trusted_cidrs")]
pub proxy_protocol_trusted_cidrs: Vec<IpNetwork>,
/// Port for the Prometheus-compatible metrics endpoint.
@ -1286,7 +1288,7 @@ impl Default for ServerConfig {
listen_tcp: None,
proxy_protocol: false,
proxy_protocol_header_timeout_ms: default_proxy_protocol_header_timeout_ms(),
proxy_protocol_trusted_cidrs: Vec::new(),
proxy_protocol_trusted_cidrs: default_proxy_protocol_trusted_cidrs(),
metrics_port: None,
metrics_listen: None,
metrics_whitelist: default_metrics_whitelist(),
@ -1357,6 +1359,90 @@ impl Default for TimeoutsConfig {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum UnknownSniAction {
#[default]
Drop,
Mask,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TlsFetchProfile {
ModernChromeLike,
ModernFirefoxLike,
CompatTls12,
LegacyMinimal,
}
impl TlsFetchProfile {
pub fn as_str(self) -> &'static str {
match self {
TlsFetchProfile::ModernChromeLike => "modern_chrome_like",
TlsFetchProfile::ModernFirefoxLike => "modern_firefox_like",
TlsFetchProfile::CompatTls12 => "compat_tls12",
TlsFetchProfile::LegacyMinimal => "legacy_minimal",
}
}
}
fn default_tls_fetch_profiles() -> Vec<TlsFetchProfile> {
vec![
TlsFetchProfile::ModernChromeLike,
TlsFetchProfile::ModernFirefoxLike,
TlsFetchProfile::CompatTls12,
TlsFetchProfile::LegacyMinimal,
]
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsFetchConfig {
/// Ordered list of ClientHello profiles used for adaptive fallback.
#[serde(default = "default_tls_fetch_profiles")]
pub profiles: Vec<TlsFetchProfile>,
/// When true and upstream route is configured, TLS fetch fails closed on
/// upstream connect errors and does not fallback to direct TCP.
#[serde(default = "default_tls_fetch_strict_route")]
pub strict_route: bool,
/// Timeout per one profile attempt in milliseconds.
#[serde(default = "default_tls_fetch_attempt_timeout_ms")]
pub attempt_timeout_ms: u64,
/// Total wall-clock budget in milliseconds across all profile attempts.
#[serde(default = "default_tls_fetch_total_budget_ms")]
pub total_budget_ms: u64,
/// Adds GREASE-style values into selected ClientHello extensions.
#[serde(default)]
pub grease_enabled: bool,
/// Produces deterministic ClientHello randomness for debugging/tests.
#[serde(default)]
pub deterministic: bool,
/// TTL for winner-profile cache entries in seconds.
/// Set to 0 to disable profile cache.
#[serde(default = "default_tls_fetch_profile_cache_ttl_secs")]
pub profile_cache_ttl_secs: u64,
}
impl Default for TlsFetchConfig {
fn default() -> Self {
Self {
profiles: default_tls_fetch_profiles(),
strict_route: default_tls_fetch_strict_route(),
attempt_timeout_ms: default_tls_fetch_attempt_timeout_ms(),
total_budget_ms: default_tls_fetch_total_budget_ms(),
grease_enabled: false,
deterministic: false,
profile_cache_ttl_secs: default_tls_fetch_profile_cache_ttl_secs(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AntiCensorshipConfig {
#[serde(default = "default_tls_domain")]
@ -1366,11 +1452,19 @@ pub struct AntiCensorshipConfig {
#[serde(default)]
pub tls_domains: Vec<String>,
/// Policy for TLS ClientHello with unknown (non-configured) SNI.
#[serde(default)]
pub unknown_sni_action: UnknownSniAction,
/// Upstream scope used for TLS front metadata fetches.
/// Empty value keeps default upstream routing behavior.
#[serde(default = "default_tls_fetch_scope")]
pub tls_fetch_scope: String,
/// Fetch strategy for TLS front metadata bootstrap and periodic refresh.
#[serde(default)]
pub tls_fetch: TlsFetchConfig,
#[serde(default = "default_true")]
pub mask: bool,
@ -1450,6 +1544,14 @@ pub struct AntiCensorshipConfig {
#[serde(default = "default_mask_shape_above_cap_blur_max_bytes")]
pub mask_shape_above_cap_blur_max_bytes: usize,
/// Maximum bytes relayed per direction on unauthenticated masking fallback paths.
#[serde(default = "default_mask_relay_max_bytes")]
pub mask_relay_max_bytes: usize,
/// Prefetch timeout (ms) for extending fragmented masking classifier window.
#[serde(default = "default_mask_classifier_prefetch_timeout_ms")]
pub mask_classifier_prefetch_timeout_ms: u64,
/// Enable outcome-time normalization envelope for masking fallback.
#[serde(default = "default_mask_timing_normalization_enabled")]
pub mask_timing_normalization_enabled: bool,
@ -1468,7 +1570,9 @@ impl Default for AntiCensorshipConfig {
Self {
tls_domain: default_tls_domain(),
tls_domains: Vec::new(),
unknown_sni_action: UnknownSniAction::Drop,
tls_fetch_scope: default_tls_fetch_scope(),
tls_fetch: TlsFetchConfig::default(),
mask: default_true(),
mask_host: None,
mask_port: default_mask_port(),
@ -1488,6 +1592,8 @@ impl Default for AntiCensorshipConfig {
mask_shape_bucket_cap_bytes: default_mask_shape_bucket_cap_bytes(),
mask_shape_above_cap_blur: default_mask_shape_above_cap_blur(),
mask_shape_above_cap_blur_max_bytes: default_mask_shape_above_cap_blur_max_bytes(),
mask_relay_max_bytes: default_mask_relay_max_bytes(),
mask_classifier_prefetch_timeout_ms: default_mask_classifier_prefetch_timeout_ms(),
mask_timing_normalization_enabled: default_mask_timing_normalization_enabled(),
mask_timing_normalization_floor_ms: default_mask_timing_normalization_floor_ms(),
mask_timing_normalization_ceiling_ms: default_mask_timing_normalization_ceiling_ms(),

View File

@ -216,6 +216,9 @@ pub enum ProxyError {
#[error("Invalid proxy protocol header")]
InvalidProxyProtocol,
#[error("Unknown TLS SNI")]
UnknownTlsSni,
#[error("Proxy error: {0}")]
Proxy(String),

View File

@ -8,8 +8,10 @@ use tracing::{debug, error, info, warn};
use crate::cli;
use crate::config::ProxyConfig;
use crate::transport::UpstreamManager;
use crate::transport::middle_proxy::{
ProxyConfigData, fetch_proxy_config_with_raw, load_proxy_config_cache, save_proxy_config_cache,
ProxyConfigData, fetch_proxy_config_with_raw_via_upstream, load_proxy_config_cache,
save_proxy_config_cache,
};
pub(crate) fn resolve_runtime_config_path(
@ -288,9 +290,10 @@ pub(crate) async fn load_startup_proxy_config_snapshot(
cache_path: Option<&str>,
me2dc_fallback: bool,
label: &'static str,
upstream: Option<std::sync::Arc<UpstreamManager>>,
) -> Option<ProxyConfigData> {
loop {
match fetch_proxy_config_with_raw(url).await {
match fetch_proxy_config_with_raw_via_upstream(url, upstream.clone()).await {
Ok((cfg, raw)) => {
if !cfg.map.is_empty() {
if let Some(path) = cache_path

View File

@ -63,9 +63,10 @@ pub(crate) async fn initialize_me_pool(
let proxy_secret_path = config.general.proxy_secret_path.as_deref();
let pool_size = config.general.middle_proxy_pool_size.max(1);
let proxy_secret = loop {
match crate::transport::middle_proxy::fetch_proxy_secret(
match crate::transport::middle_proxy::fetch_proxy_secret_with_upstream(
proxy_secret_path,
config.general.proxy_secret_len_max,
Some(upstream_manager.clone()),
)
.await
{
@ -129,6 +130,7 @@ pub(crate) async fn initialize_me_pool(
config.general.proxy_config_v4_cache_path.as_deref(),
me2dc_fallback,
"getProxyConfig",
Some(upstream_manager.clone()),
)
.await;
if cfg_v4.is_some() {
@ -160,6 +162,7 @@ pub(crate) async fn initialize_me_pool(
config.general.proxy_config_v6_cache_path.as_deref(),
me2dc_fallback,
"getProxyConfigV6",
Some(upstream_manager.clone()),
)
.await;
if cfg_v6.is_some() {

View File

@ -225,6 +225,7 @@ pub async fn run() -> std::result::Result<(), Box<dyn std::error::Error>> {
config.general.upstream_connect_retry_attempts,
config.general.upstream_connect_retry_backoff_ms,
config.general.upstream_connect_budget_ms,
config.timeouts.tg_connect,
config.general.upstream_unhealthy_fail_threshold,
config.general.upstream_connect_failfast_hard_errors,
stats.clone(),

View File

@ -7,6 +7,7 @@ use tracing::warn;
use crate::config::ProxyConfig;
use crate::startup::{COMPONENT_TLS_FRONT_BOOTSTRAP, StartupTracker};
use crate::tls_front::TlsFrontCache;
use crate::tls_front::fetcher::TlsFetchStrategy;
use crate::transport::UpstreamManager;
pub(crate) async fn bootstrap_tls_front(
@ -40,7 +41,17 @@ pub(crate) async fn bootstrap_tls_front(
let mask_unix_sock = config.censorship.mask_unix_sock.clone();
let tls_fetch_scope = (!config.censorship.tls_fetch_scope.is_empty())
.then(|| config.censorship.tls_fetch_scope.clone());
let fetch_timeout = Duration::from_secs(5);
let tls_fetch = config.censorship.tls_fetch.clone();
let fetch_strategy = TlsFetchStrategy {
profiles: tls_fetch.profiles,
strict_route: tls_fetch.strict_route,
attempt_timeout: Duration::from_millis(tls_fetch.attempt_timeout_ms.max(1)),
total_budget: Duration::from_millis(tls_fetch.total_budget_ms.max(1)),
grease_enabled: tls_fetch.grease_enabled,
deterministic: tls_fetch.deterministic,
profile_cache_ttl: Duration::from_secs(tls_fetch.profile_cache_ttl_secs),
};
let fetch_timeout = fetch_strategy.total_budget;
let cache_initial = cache.clone();
let domains_initial = tls_domains.to_vec();
@ -48,6 +59,7 @@ pub(crate) async fn bootstrap_tls_front(
let unix_sock_initial = mask_unix_sock.clone();
let scope_initial = tls_fetch_scope.clone();
let upstream_initial = upstream_manager.clone();
let strategy_initial = fetch_strategy.clone();
tokio::spawn(async move {
let mut join = tokio::task::JoinSet::new();
for domain in domains_initial {
@ -56,12 +68,13 @@ pub(crate) async fn bootstrap_tls_front(
let unix_sock_domain = unix_sock_initial.clone();
let scope_domain = scope_initial.clone();
let upstream_domain = upstream_initial.clone();
let strategy_domain = strategy_initial.clone();
join.spawn(async move {
match crate::tls_front::fetcher::fetch_real_tls(
match crate::tls_front::fetcher::fetch_real_tls_with_strategy(
&host_domain,
port,
&domain,
fetch_timeout,
&strategy_domain,
Some(upstream_domain),
scope_domain.as_deref(),
proxy_protocol,
@ -107,6 +120,7 @@ pub(crate) async fn bootstrap_tls_front(
let unix_sock_refresh = mask_unix_sock.clone();
let scope_refresh = tls_fetch_scope.clone();
let upstream_refresh = upstream_manager.clone();
let strategy_refresh = fetch_strategy.clone();
tokio::spawn(async move {
loop {
let base_secs = rand::rng().random_range(4 * 3600..=6 * 3600);
@ -120,12 +134,13 @@ pub(crate) async fn bootstrap_tls_front(
let unix_sock_domain = unix_sock_refresh.clone();
let scope_domain = scope_refresh.clone();
let upstream_domain = upstream_refresh.clone();
let strategy_domain = strategy_refresh.clone();
join.spawn(async move {
match crate::tls_front::fetcher::fetch_real_tls(
match crate::tls_front::fetcher::fetch_real_tls_with_strategy(
&host_domain,
port,
&domain,
fetch_timeout,
&strategy_domain,
Some(upstream_domain),
scope_domain.as_deref(),
proxy_protocol,

View File

@ -7,12 +7,12 @@ mod crypto;
mod error;
mod ip_tracker;
#[cfg(test)]
#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"]
mod ip_tracker_hotpath_adversarial_tests;
#[cfg(test)]
#[path = "tests/ip_tracker_encapsulation_adversarial_tests.rs"]
mod ip_tracker_encapsulation_adversarial_tests;
#[cfg(test)]
#[path = "tests/ip_tracker_hotpath_adversarial_tests.rs"]
mod ip_tracker_hotpath_adversarial_tests;
#[cfg(test)]
#[path = "tests/ip_tracker_regression_tests.rs"]
mod ip_tracker_regression_tests;
mod maestro;
@ -29,5 +29,6 @@ mod util;
#[tokio::main]
async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let _ = rustls::crypto::ring::default_provider().install_default();
maestro::run().await
}

View File

@ -1233,10 +1233,7 @@ async fn render_metrics(stats: &Stats, config: &ProxyConfig, ip_tracker: &UserIp
out,
"# HELP telemt_me_d2c_batch_bytes_bucket_total DC->Client batch byte size buckets"
);
let _ = writeln!(
out,
"# TYPE telemt_me_d2c_batch_bytes_bucket_total counter"
);
let _ = writeln!(out, "# TYPE telemt_me_d2c_batch_bytes_bucket_total counter");
let _ = writeln!(
out,
"telemt_me_d2c_batch_bytes_bucket_total{{bucket=\"0_1k\"}} {}",

View File

@ -186,6 +186,72 @@ fn handshake_timeout_with_mask_grace(config: &ProxyConfig) -> Duration {
}
}
const MASK_CLASSIFIER_PREFETCH_WINDOW: usize = 16;
#[cfg(test)]
const MASK_CLASSIFIER_PREFETCH_TIMEOUT: Duration = Duration::from_millis(5);
fn mask_classifier_prefetch_timeout(config: &ProxyConfig) -> Duration {
Duration::from_millis(config.censorship.mask_classifier_prefetch_timeout_ms)
}
fn should_prefetch_mask_classifier_window(initial_data: &[u8]) -> bool {
if initial_data.len() >= MASK_CLASSIFIER_PREFETCH_WINDOW {
return false;
}
if initial_data.is_empty() {
// Empty initial_data means there is no client probe prefix to refine.
// Prefetching in this case can consume fallback relay payload bytes and
// accidentally route them through shaping heuristics.
return false;
}
if initial_data[0] == 0x16 || initial_data.starts_with(b"SSH-") {
return false;
}
initial_data
.iter()
.all(|b| b.is_ascii_alphabetic() || *b == b' ')
}
#[cfg(test)]
async fn extend_masking_initial_window<R>(reader: &mut R, initial_data: &mut Vec<u8>)
where
R: AsyncRead + Unpin,
{
extend_masking_initial_window_with_timeout(
reader,
initial_data,
MASK_CLASSIFIER_PREFETCH_TIMEOUT,
)
.await;
}
async fn extend_masking_initial_window_with_timeout<R>(
reader: &mut R,
initial_data: &mut Vec<u8>,
prefetch_timeout: Duration,
) where
R: AsyncRead + Unpin,
{
if !should_prefetch_mask_classifier_window(initial_data) {
return;
}
let need = MASK_CLASSIFIER_PREFETCH_WINDOW.saturating_sub(initial_data.len());
if need == 0 {
return;
}
let mut extra = [0u8; MASK_CLASSIFIER_PREFETCH_WINDOW];
if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await
&& n > 0
{
initial_data.extend_from_slice(&extra[..n]);
}
}
fn masking_outcome<R, W>(
reader: R,
writer: W,
@ -200,6 +266,15 @@ where
W: AsyncWrite + Unpin + Send + 'static,
{
HandshakeOutcome::NeedsMasking(Box::pin(async move {
let mut reader = reader;
let mut initial_data = initial_data;
extend_masking_initial_window_with_timeout(
&mut reader,
&mut initial_data,
mask_classifier_prefetch_timeout(&config),
)
.await;
handle_bad_client(
reader,
writer,
@ -242,13 +317,20 @@ fn record_handshake_failure_class(
record_beobachten_class(beobachten, config, peer_ip, class);
}
#[inline]
fn increment_bad_on_unknown_tls_sni(stats: &Stats, error: &ProxyError) {
if matches!(error, ProxyError::UnknownTlsSni) {
stats.increment_connects_bad();
}
}
fn is_trusted_proxy_source(peer_ip: IpAddr, trusted: &[IpNetwork]) -> bool {
if trusted.is_empty() {
static EMPTY_PROXY_TRUST_WARNED: OnceLock<AtomicBool> = OnceLock::new();
let warned = EMPTY_PROXY_TRUST_WARNED.get_or_init(|| AtomicBool::new(false));
if !warned.swap(true, Ordering::Relaxed) {
warn!(
"PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers by default"
"PROXY protocol enabled but server.proxy_protocol_trusted_cidrs is empty; rejecting all PROXY headers"
);
}
return false;
@ -433,7 +515,10 @@ where
beobachten.clone(),
));
}
HandshakeResult::Error(e) => return Err(e),
HandshakeResult::Error(e) => {
increment_bad_on_unknown_tls_sni(stats.as_ref(), &e);
return Err(e);
}
};
debug!(peer = %peer, "Reading MTProto handshake through TLS");
@ -884,7 +969,10 @@ impl RunningClientHandler {
self.beobachten.clone(),
));
}
HandshakeResult::Error(e) => return Err(e),
HandshakeResult::Error(e) => {
increment_bad_on_unknown_tls_sni(stats.as_ref(), &e);
return Err(e);
}
};
debug!(peer = %peer, "Reading MTProto handshake through TLS");
@ -1153,7 +1241,7 @@ impl RunningClientHandler {
}
if let Some(quota) = config.access.user_data_quota.get(user)
&& stats.get_user_total_octets(user) >= *quota
&& stats.get_user_quota_used(user) >= *quota
{
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
@ -1212,7 +1300,7 @@ impl RunningClientHandler {
}
if let Some(quota) = config.access.user_data_quota.get(user)
&& stats.get_user_total_octets(user) >= *quota
&& stats.get_user_quota_used(user) >= *quota
{
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
@ -1321,6 +1409,38 @@ mod masking_shape_classifier_fuzz_redteam_expected_fail_tests;
#[path = "tests/client_masking_probe_evasion_blackhat_tests.rs"]
mod masking_probe_evasion_blackhat_tests;
#[cfg(test)]
#[path = "tests/client_masking_fragmented_classifier_security_tests.rs"]
mod masking_fragmented_classifier_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_replay_timing_security_tests.rs"]
mod masking_replay_timing_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_http2_fragmented_preface_security_tests.rs"]
mod masking_http2_fragmented_preface_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_invariant_security_tests.rs"]
mod masking_prefetch_invariant_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_timing_matrix_security_tests.rs"]
mod masking_prefetch_timing_matrix_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_config_runtime_security_tests.rs"]
mod masking_prefetch_config_runtime_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_config_pipeline_integration_security_tests.rs"]
mod masking_prefetch_config_pipeline_integration_security_tests;
#[cfg(test)]
#[path = "tests/client_masking_prefetch_strict_boundary_security_tests.rs"]
mod masking_prefetch_strict_boundary_security_tests;
#[cfg(test)]
#[path = "tests/client_beobachten_ttl_bounds_security_tests.rs"]
mod beobachten_ttl_bounds_security_tests;
@ -1328,3 +1448,15 @@ mod beobachten_ttl_bounds_security_tests;
#[cfg(test)]
#[path = "tests/client_tls_record_wrap_hardening_security_tests.rs"]
mod tls_record_wrap_hardening_security_tests;
#[cfg(test)]
#[path = "tests/client_clever_advanced_tests.rs"]
mod client_clever_advanced_tests;
#[cfg(test)]
#[path = "tests/client_more_advanced_tests.rs"]
mod client_more_advanced_tests;
#[cfg(test)]
#[path = "tests/client_deep_invariants_tests.rs"]
mod client_deep_invariants_tests;

View File

@ -16,7 +16,7 @@ use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tracing::{debug, trace, warn};
use zeroize::{Zeroize, Zeroizing};
use crate::config::ProxyConfig;
use crate::config::{ProxyConfig, UnknownSniAction};
use crate::crypto::{AesCtr, SecureRandom, sha256};
use crate::error::{HandshakeResult, ProxyError};
use crate::protocol::constants::*;
@ -121,6 +121,19 @@ fn auth_probe_eviction_offset(peer_ip: IpAddr, now: Instant) -> usize {
hasher.finish() as usize
}
fn auth_probe_scan_start_offset(
peer_ip: IpAddr,
now: Instant,
state_len: usize,
scan_limit: usize,
) -> usize {
if state_len == 0 || scan_limit == 0 {
return 0;
}
auth_probe_eviction_offset(peer_ip, now) % state_len
}
fn auth_probe_is_throttled(peer_ip: IpAddr, now: Instant) -> bool {
let peer_ip = normalize_auth_probe_ip(peer_ip);
let state = auth_probe_state_map();
@ -269,34 +282,9 @@ fn auth_probe_record_failure_with_state(
let mut eviction_candidate: Option<(IpAddr, u32, Instant)> = None;
let state_len = state.len();
let scan_limit = state_len.min(AUTH_PROBE_PRUNE_SCAN_LIMIT);
let start_offset = if state_len == 0 {
0
} else {
auth_probe_eviction_offset(peer_ip, now) % state_len
};
let mut scanned = 0usize;
for entry in state.iter().skip(start_offset) {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) => {}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(key);
}
scanned += 1;
if scanned >= scan_limit {
break;
}
}
if scanned < scan_limit {
for entry in state.iter().take(scan_limit - scanned) {
if state_len <= AUTH_PROBE_PRUNE_SCAN_LIMIT {
for entry in state.iter() {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
@ -310,6 +298,46 @@ fn auth_probe_record_failure_with_state(
stale_keys.push(key);
}
}
} else {
let start_offset =
auth_probe_scan_start_offset(peer_ip, now, state_len, scan_limit);
let mut scanned = 0usize;
for entry in state.iter().skip(start_offset) {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail && last_seen >= current_seen) => {}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(key);
}
scanned += 1;
if scanned >= scan_limit {
break;
}
}
if scanned < scan_limit {
for entry in state.iter().take(scan_limit - scanned) {
let key = *entry.key();
let fail_streak = entry.value().fail_streak;
let last_seen = entry.value().last_seen;
match eviction_candidate {
Some((_, current_fail, current_seen))
if fail_streak > current_fail
|| (fail_streak == current_fail
&& last_seen >= current_seen) => {}
_ => eviction_candidate = Some((key, fail_streak, last_seen)),
}
if auth_probe_state_expired(entry.value(), now) {
stale_keys.push(key);
}
}
}
}
for stale_key in stale_keys {
@ -501,6 +529,21 @@ fn decode_user_secrets(
secrets
}
#[inline]
fn find_matching_tls_domain<'a>(config: &'a ProxyConfig, sni: &str) -> Option<&'a str> {
if config.censorship.tls_domain.eq_ignore_ascii_case(sni) {
return Some(config.censorship.tls_domain.as_str());
}
for domain in &config.censorship.tls_domains {
if domain.eq_ignore_ascii_case(sni) {
return Some(domain.as_str());
}
}
None
}
async fn maybe_apply_server_hello_delay(config: &ProxyConfig) {
if config.censorship.server_hello_delay_max_ms == 0 {
return;
@ -584,70 +627,12 @@ where
}
let client_sni = tls::extract_sni_from_client_hello(handshake);
let secrets = decode_user_secrets(config, client_sni.as_deref());
let validation = match tls::validate_tls_handshake_with_replay_window(
handshake,
&secrets,
config.access.ignore_time_skew,
config.access.replay_window_secs,
) {
Some(v) => v,
None => {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(
peer = %peer,
ignore_time_skew = config.access.ignore_time_skew,
"TLS handshake validation failed - no matching user or time skew"
);
return HandshakeResult::BadClient { reader, writer };
}
};
// Replay tracking is applied only after successful authentication to avoid
// letting unauthenticated probes evict valid entries from the replay cache.
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_and_add_tls_digest(digest_half) {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer };
}
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s,
None => {
maybe_apply_server_hello_delay(config).await;
return HandshakeResult::BadClient { reader, writer };
}
};
let cached = if config.censorship.tls_emulation {
if let Some(cache) = tls_cache.as_ref() {
let selected_domain = if let Some(sni) = client_sni.as_ref() {
if cache.contains_domain(sni).await {
sni.clone()
} else {
config.censorship.tls_domain.clone()
}
} else {
config.censorship.tls_domain.clone()
};
let cached_entry = cache.get(&selected_domain).await;
let use_full_cert_payload = cache
.take_full_cert_budget_for_ip(
peer.ip(),
Duration::from_secs(config.censorship.tls_full_cert_ttl_secs),
)
.await;
Some((cached_entry, use_full_cert_payload))
} else {
None
}
} else {
None
};
let preferred_user_hint = client_sni
.as_deref()
.filter(|sni| config.access.users.contains_key(*sni));
let matched_tls_domain = client_sni
.as_deref()
.and_then(|sni| find_matching_tls_domain(config, sni));
let alpn_list = if config.censorship.alpn_enforce {
tls::extract_alpn_from_client_hello(handshake)
@ -670,6 +655,81 @@ where
None
};
if client_sni.is_some() && matched_tls_domain.is_none() && preferred_user_hint.is_none() {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(
peer = %peer,
sni = ?client_sni,
action = ?config.censorship.unknown_sni_action,
"TLS handshake rejected by unknown SNI policy"
);
return match config.censorship.unknown_sni_action {
UnknownSniAction::Drop => HandshakeResult::Error(ProxyError::UnknownTlsSni),
UnknownSniAction::Mask => HandshakeResult::BadClient { reader, writer },
};
}
let secrets = decode_user_secrets(config, preferred_user_hint);
let validation = match tls::validate_tls_handshake_with_replay_window(
handshake,
&secrets,
config.access.ignore_time_skew,
config.access.replay_window_secs,
) {
Some(v) => v,
None => {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
debug!(
peer = %peer,
ignore_time_skew = config.access.ignore_time_skew,
"TLS handshake validation failed - no matching user or time skew"
);
return HandshakeResult::BadClient { reader, writer };
}
};
// Reject known replay digests before expensive cache/domain/ALPN policy work.
let digest_half = &validation.digest[..tls::TLS_DIGEST_HALF_LEN];
if replay_checker.check_tls_digest(digest_half) {
auth_probe_record_failure(peer.ip(), Instant::now());
maybe_apply_server_hello_delay(config).await;
warn!(peer = %peer, "TLS replay attack detected (duplicate digest)");
return HandshakeResult::BadClient { reader, writer };
}
let secret = match secrets.iter().find(|(name, _)| *name == validation.user) {
Some((_, s)) => s,
None => {
maybe_apply_server_hello_delay(config).await;
return HandshakeResult::BadClient { reader, writer };
}
};
let cached = if config.censorship.tls_emulation {
if let Some(cache) = tls_cache.as_ref() {
let selected_domain =
matched_tls_domain.unwrap_or(config.censorship.tls_domain.as_str());
let cached_entry = cache.get(selected_domain).await;
let use_full_cert_payload = cache
.take_full_cert_budget_for_ip(
peer.ip(),
Duration::from_secs(config.censorship.tls_full_cert_ttl_secs),
)
.await;
Some((cached_entry, use_full_cert_payload))
} else {
None
}
} else {
None
};
// Add replay digest only for policy-valid handshakes.
replay_checker.add_tls_digest(digest_half);
let response = if let Some((cached_entry, use_full_cert_payload)) = cached {
emulator::build_emulated_server_hello(
secret,
@ -769,7 +829,7 @@ where
let mut dec_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let dec_key = Zeroizing::new(sha256(&dec_key_input));
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
@ -805,7 +865,7 @@ where
let mut enc_key_input = Zeroizing::new(Vec::with_capacity(PREKEY_LEN + secret.len()));
enc_key_input.extend_from_slice(enc_prekey);
enc_key_input.extend_from_slice(&secret);
let enc_key = sha256(&enc_key_input);
let enc_key = Zeroizing::new(sha256(&enc_key_input));
let mut enc_iv_arr = [0u8; IV_LEN];
enc_iv_arr.copy_from_slice(enc_iv_bytes);
@ -830,9 +890,9 @@ where
user: user.clone(),
dc_idx,
proto_tag,
dec_key,
dec_key: *dec_key,
dec_iv,
enc_key,
enc_key: *enc_key,
enc_iv,
peer,
is_tls,
@ -979,6 +1039,38 @@ mod saturation_poison_security_tests;
#[path = "tests/handshake_auth_probe_hardening_adversarial_tests.rs"]
mod auth_probe_hardening_adversarial_tests;
#[cfg(test)]
#[path = "tests/handshake_auth_probe_scan_budget_security_tests.rs"]
mod auth_probe_scan_budget_security_tests;
#[cfg(test)]
#[path = "tests/handshake_auth_probe_scan_offset_stress_tests.rs"]
mod auth_probe_scan_offset_stress_tests;
#[cfg(test)]
#[path = "tests/handshake_auth_probe_eviction_bias_security_tests.rs"]
mod auth_probe_eviction_bias_security_tests;
#[cfg(test)]
#[path = "tests/handshake_advanced_clever_tests.rs"]
mod advanced_clever_tests;
#[cfg(test)]
#[path = "tests/handshake_more_clever_tests.rs"]
mod more_clever_tests;
#[cfg(test)]
#[path = "tests/handshake_real_bug_stress_tests.rs"]
mod real_bug_stress_tests;
#[cfg(test)]
#[path = "tests/handshake_timing_manual_bench_tests.rs"]
mod timing_manual_bench_tests;
#[cfg(test)]
#[path = "tests/handshake_key_material_zeroization_security_tests.rs"]
mod handshake_key_material_zeroization_security_tests;
/// Compile-time guard: HandshakeSuccess holds cryptographic key material and
/// must never be Copy. A Copy impl would allow silent key duplication,
/// undermining the zeroize-on-drop guarantee.

View File

@ -4,14 +4,23 @@ use crate::config::ProxyConfig;
use crate::network::dns_overrides::resolve_socket_addr;
use crate::stats::beobachten::BeobachtenStore;
use crate::transport::proxy_protocol::{ProxyProtocolV1Builder, ProxyProtocolV2Builder};
use rand::{Rng, RngExt};
use std::net::SocketAddr;
#[cfg(unix)]
use nix::ifaddrs::getifaddrs;
use rand::rngs::StdRng;
use rand::{Rng, RngExt, SeedableRng};
use std::net::{IpAddr, SocketAddr};
use std::str;
use std::time::Duration;
#[cfg(test)]
use std::sync::atomic::{AtomicUsize, Ordering};
#[cfg(unix)]
use std::sync::{Mutex, OnceLock};
use std::time::{Duration, Instant as StdInstant};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
#[cfg(unix)]
use tokio::sync::Mutex as AsyncMutex;
use tokio::time::{Instant, timeout};
use tracing::debug;
@ -30,28 +39,55 @@ const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_secs(5);
#[cfg(test)]
const MASK_RELAY_IDLE_TIMEOUT: Duration = Duration::from_millis(100);
const MASK_BUFFER_SIZE: usize = 8192;
#[cfg(unix)]
#[cfg(not(test))]
const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(300);
#[cfg(all(unix, test))]
const LOCAL_INTERFACE_CACHE_TTL: Duration = Duration::from_secs(1);
struct CopyOutcome {
total: usize,
ended_by_eof: bool,
}
async fn copy_with_idle_timeout<R, W>(reader: &mut R, writer: &mut W) -> CopyOutcome
async fn copy_with_idle_timeout<R, W>(
reader: &mut R,
writer: &mut W,
byte_cap: usize,
shutdown_on_eof: bool,
) -> CopyOutcome
where
R: AsyncRead + Unpin,
W: AsyncWrite + Unpin,
{
let mut buf = [0u8; MASK_BUFFER_SIZE];
let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]);
let mut total = 0usize;
let mut ended_by_eof = false;
if byte_cap == 0 {
return CopyOutcome {
total,
ended_by_eof,
};
}
loop {
let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf)).await;
let remaining_budget = byte_cap.saturating_sub(total);
if remaining_budget == 0 {
break;
}
let read_len = remaining_budget.min(MASK_BUFFER_SIZE);
let read_res = timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await;
let n = match read_res {
Ok(Ok(n)) => n,
Ok(Err(_)) | Err(_) => break,
};
if n == 0 {
ended_by_eof = true;
if shutdown_on_eof {
let _ = timeout(MASK_RELAY_IDLE_TIMEOUT, writer.shutdown()).await;
}
break;
}
total = total.saturating_add(n);
@ -68,6 +104,31 @@ where
}
}
fn is_http_probe(data: &[u8]) -> bool {
// RFC 7540 section 3.5: HTTP/2 client preface starts with "PRI ".
const HTTP_METHODS: [&[u8]; 10] = [
b"GET ", b"POST", b"HEAD", b"PUT ", b"DELETE", b"OPTIONS", b"CONNECT", b"TRACE", b"PATCH",
b"PRI ",
];
if data.is_empty() {
return false;
}
let window = &data[..data.len().min(16)];
for method in HTTP_METHODS {
if data.len() >= method.len() && window.starts_with(method) {
return true;
}
if (2..=3).contains(&window.len()) && method.starts_with(window) {
return true;
}
}
false
}
fn next_mask_shape_bucket(total: usize, floor: usize, cap: usize) -> usize {
if total == 0 || floor == 0 || cap < floor {
return total;
@ -125,6 +186,11 @@ async fn maybe_write_shape_padding<W>(
let mut remaining = target_total - total_sent;
let mut pad_chunk = [0u8; 1024];
let deadline = Instant::now() + MASK_TIMEOUT;
// Use a Send RNG so relay futures remain spawn-safe under Tokio.
let mut rng = {
let mut seed_source = rand::rng();
StdRng::from_rng(&mut seed_source)
};
while remaining > 0 {
let now = Instant::now();
@ -133,10 +199,7 @@ async fn maybe_write_shape_padding<W>(
}
let write_len = remaining.min(pad_chunk.len());
{
let mut rng = rand::rng();
rng.fill_bytes(&mut pad_chunk[..write_len]);
}
rng.fill_bytes(&mut pad_chunk[..write_len]);
let write_budget = deadline.saturating_duration_since(now);
match timeout(write_budget, mask_write.write_all(&pad_chunk[..write_len])).await {
Ok(Ok(())) => {}
@ -167,11 +230,11 @@ where
}
}
async fn consume_client_data_with_timeout<R>(reader: R)
async fn consume_client_data_with_timeout_and_cap<R>(reader: R, byte_cap: usize)
where
R: AsyncRead + Unpin,
{
if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader))
if timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, byte_cap))
.await
.is_err()
{
@ -190,6 +253,13 @@ fn mask_outcome_target_budget(config: &ProxyConfig) -> Duration {
if config.censorship.mask_timing_normalization_enabled {
let floor = config.censorship.mask_timing_normalization_floor_ms;
let ceiling = config.censorship.mask_timing_normalization_ceiling_ms;
if floor == 0 {
if ceiling == 0 {
return Duration::from_millis(0);
}
let mut rng = rand::rng();
return Duration::from_millis(rng.random_range(0..=ceiling));
}
if ceiling > floor {
let mut rng = rand::rng();
return Duration::from_millis(rng.random_range(floor..=ceiling));
@ -219,14 +289,7 @@ async fn wait_mask_outcome_budget(started: Instant, config: &ProxyConfig) {
/// Detect client type based on initial data
fn detect_client_type(data: &[u8]) -> &'static str {
// Check for HTTP request
if data.len() > 4
&& (data.starts_with(b"GET ")
|| data.starts_with(b"POST")
|| data.starts_with(b"HEAD")
|| data.starts_with(b"PUT ")
|| data.starts_with(b"DELETE")
|| data.starts_with(b"OPTIONS"))
{
if is_http_probe(data) {
return "HTTP";
}
@ -248,6 +311,247 @@ fn detect_client_type(data: &[u8]) -> &'static str {
"unknown"
}
fn parse_mask_host_ip_literal(host: &str) -> Option<IpAddr> {
if host.starts_with('[') && host.ends_with(']') {
return host[1..host.len() - 1].parse::<IpAddr>().ok();
}
host.parse::<IpAddr>().ok()
}
fn canonical_ip(ip: IpAddr) -> IpAddr {
match ip {
IpAddr::V6(v6) => v6
.to_ipv4_mapped()
.map(IpAddr::V4)
.unwrap_or(IpAddr::V6(v6)),
IpAddr::V4(v4) => IpAddr::V4(v4),
}
}
#[cfg(unix)]
fn collect_local_interface_ips() -> Vec<IpAddr> {
#[cfg(test)]
LOCAL_INTERFACE_ENUMERATIONS.fetch_add(1, Ordering::Relaxed);
let mut out = Vec::new();
if let Ok(addrs) = getifaddrs() {
for iface in addrs {
if let Some(address) = iface.address {
if let Some(v4) = address.as_sockaddr_in() {
out.push(canonical_ip(IpAddr::V4(v4.ip())));
} else if let Some(v6) = address.as_sockaddr_in6() {
out.push(canonical_ip(IpAddr::V6(v6.ip())));
}
}
}
}
out
}
fn choose_interface_snapshot(previous: &[IpAddr], refreshed: Vec<IpAddr>) -> Vec<IpAddr> {
if refreshed.is_empty() && !previous.is_empty() {
return previous.to_vec();
}
refreshed
}
#[cfg(unix)]
#[derive(Default)]
struct LocalInterfaceCache {
ips: Vec<IpAddr>,
refreshed_at: Option<StdInstant>,
}
#[cfg(unix)]
static LOCAL_INTERFACE_CACHE: OnceLock<Mutex<LocalInterfaceCache>> = OnceLock::new();
#[cfg(unix)]
static LOCAL_INTERFACE_REFRESH_LOCK: OnceLock<AsyncMutex<()>> = OnceLock::new();
#[cfg(all(unix, test))]
fn local_interface_ips() -> Vec<IpAddr> {
let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default()));
let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
let stale = guard
.refreshed_at
.is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
if stale {
let refreshed = collect_local_interface_ips();
guard.ips = choose_interface_snapshot(&guard.ips, refreshed);
guard.refreshed_at = Some(StdInstant::now());
}
guard.ips.clone()
}
#[cfg(unix)]
async fn local_interface_ips_async() -> Vec<IpAddr> {
let cache = LOCAL_INTERFACE_CACHE.get_or_init(|| Mutex::new(LocalInterfaceCache::default()));
{
let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
let stale = guard
.refreshed_at
.is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
if !stale {
return guard.ips.clone();
}
}
let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(()));
let _refresh_guard = refresh_lock.lock().await;
{
let guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
let stale = guard
.refreshed_at
.is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
if !stale {
return guard.ips.clone();
}
}
let refreshed = tokio::task::spawn_blocking(collect_local_interface_ips)
.await
.unwrap_or_default();
let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
let stale = guard
.refreshed_at
.is_none_or(|at| at.elapsed() >= LOCAL_INTERFACE_CACHE_TTL);
if stale {
guard.ips = choose_interface_snapshot(&guard.ips, refreshed);
guard.refreshed_at = Some(StdInstant::now());
}
guard.ips.clone()
}
#[cfg(all(not(unix), test))]
fn local_interface_ips() -> Vec<IpAddr> {
Vec::new()
}
#[cfg(not(unix))]
async fn local_interface_ips_async() -> Vec<IpAddr> {
Vec::new()
}
#[cfg(test)]
static LOCAL_INTERFACE_ENUMERATIONS: AtomicUsize = AtomicUsize::new(0);
#[cfg(test)]
fn reset_local_interface_enumerations_for_tests() {
LOCAL_INTERFACE_ENUMERATIONS.store(0, Ordering::Relaxed);
#[cfg(unix)]
if let Some(cache) = LOCAL_INTERFACE_CACHE.get() {
let mut guard = cache.lock().unwrap_or_else(|poison| poison.into_inner());
guard.ips.clear();
guard.refreshed_at = None;
}
}
#[cfg(test)]
fn local_interface_enumerations_for_tests() -> usize {
LOCAL_INTERFACE_ENUMERATIONS.load(Ordering::Relaxed)
}
fn is_mask_target_local_listener_with_interfaces(
mask_host: &str,
mask_port: u16,
local_addr: SocketAddr,
resolved_override: Option<SocketAddr>,
interface_ips: &[IpAddr],
) -> bool {
if mask_port != local_addr.port() {
return false;
}
let local_ip = canonical_ip(local_addr.ip());
let literal_mask_ip = parse_mask_host_ip_literal(mask_host).map(canonical_ip);
if let Some(addr) = resolved_override {
let resolved_ip = canonical_ip(addr.ip());
if resolved_ip == local_ip {
return true;
}
if local_ip.is_unspecified()
&& (resolved_ip.is_loopback()
|| resolved_ip.is_unspecified()
|| interface_ips.contains(&resolved_ip))
{
return true;
}
}
if let Some(mask_ip) = literal_mask_ip {
if mask_ip == local_ip {
return true;
}
if local_ip.is_unspecified()
&& (mask_ip.is_loopback()
|| mask_ip.is_unspecified()
|| interface_ips.contains(&mask_ip))
{
return true;
}
}
false
}
#[cfg(test)]
fn is_mask_target_local_listener(
mask_host: &str,
mask_port: u16,
local_addr: SocketAddr,
resolved_override: Option<SocketAddr>,
) -> bool {
if mask_port != local_addr.port() {
return false;
}
let interfaces = local_interface_ips();
is_mask_target_local_listener_with_interfaces(
mask_host,
mask_port,
local_addr,
resolved_override,
&interfaces,
)
}
async fn is_mask_target_local_listener_async(
mask_host: &str,
mask_port: u16,
local_addr: SocketAddr,
resolved_override: Option<SocketAddr>,
) -> bool {
if mask_port != local_addr.port() {
return false;
}
let interfaces = local_interface_ips_async().await;
is_mask_target_local_listener_with_interfaces(
mask_host,
mask_port,
local_addr,
resolved_override,
&interfaces,
)
}
fn masking_beobachten_ttl(config: &ProxyConfig) -> Duration {
let minutes = config.general.beobachten_minutes;
let clamped = minutes.clamp(1, 24 * 60);
Duration::from_secs(clamped.saturating_mul(60))
}
fn build_mask_proxy_header(
version: u8,
peer: SocketAddr,
@ -290,13 +594,14 @@ pub async fn handle_bad_client<R, W>(
{
let client_type = detect_client_type(initial_data);
if config.general.beobachten {
let ttl = Duration::from_secs(config.general.beobachten_minutes.saturating_mul(60));
let ttl = masking_beobachten_ttl(config);
beobachten.record(client_type, peer.ip(), ttl);
}
if !config.censorship.mask {
// Masking disabled, just consume data
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes)
.await;
return;
}
@ -341,6 +646,7 @@ pub async fn handle_bad_client<R, W>(
config.censorship.mask_shape_above_cap_blur,
config.censorship.mask_shape_above_cap_blur_max_bytes,
config.censorship.mask_shape_hardening_aggressive_mode,
config.censorship.mask_relay_max_bytes,
),
)
.await
@ -353,12 +659,20 @@ pub async fn handle_bad_client<R, W>(
Ok(Err(e)) => {
wait_mask_connect_budget_if_needed(connect_started, config).await;
debug!(error = %e, "Failed to connect to mask unix socket");
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(
reader,
config.censorship.mask_relay_max_bytes,
)
.await;
wait_mask_outcome_budget(outcome_started, config).await;
}
Err(_) => {
debug!("Timeout connecting to mask unix socket");
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(
reader,
config.censorship.mask_relay_max_bytes,
)
.await;
wait_mask_outcome_budget(outcome_started, config).await;
}
}
@ -372,6 +686,29 @@ pub async fn handle_bad_client<R, W>(
.unwrap_or(&config.censorship.tls_domain);
let mask_port = config.censorship.mask_port;
// Fail closed when fallback points at our own listener endpoint.
// Self-referential masking can create recursive proxy loops under
// misconfiguration and leak distinguishable load spikes to adversaries.
let resolved_mask_addr = resolve_socket_addr(mask_host, mask_port);
if is_mask_target_local_listener_async(mask_host, mask_port, local_addr, resolved_mask_addr)
.await
{
let outcome_started = Instant::now();
debug!(
client_type = client_type,
host = %mask_host,
port = mask_port,
local = %local_addr,
"Mask target resolves to local listener; refusing self-referential masking fallback"
);
consume_client_data_with_timeout_and_cap(reader, config.censorship.mask_relay_max_bytes)
.await;
wait_mask_outcome_budget(outcome_started, config).await;
return;
}
let outcome_started = Instant::now();
debug!(
client_type = client_type,
host = %mask_host,
@ -381,10 +718,9 @@ pub async fn handle_bad_client<R, W>(
);
// Apply runtime DNS override for mask target when configured.
let mask_addr = resolve_socket_addr(mask_host, mask_port)
let mask_addr = resolved_mask_addr
.map(|addr| addr.to_string())
.unwrap_or_else(|| format!("{}:{}", mask_host, mask_port));
let outcome_started = Instant::now();
let connect_started = Instant::now();
let connect_result = timeout(MASK_TIMEOUT, TcpStream::connect(&mask_addr)).await;
match connect_result {
@ -413,6 +749,7 @@ pub async fn handle_bad_client<R, W>(
config.censorship.mask_shape_above_cap_blur,
config.censorship.mask_shape_above_cap_blur_max_bytes,
config.censorship.mask_shape_hardening_aggressive_mode,
config.censorship.mask_relay_max_bytes,
),
)
.await
@ -425,12 +762,20 @@ pub async fn handle_bad_client<R, W>(
Ok(Err(e)) => {
wait_mask_connect_budget_if_needed(connect_started, config).await;
debug!(error = %e, "Failed to connect to mask host");
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(
reader,
config.censorship.mask_relay_max_bytes,
)
.await;
wait_mask_outcome_budget(outcome_started, config).await;
}
Err(_) => {
debug!("Timeout connecting to mask host");
consume_client_data_with_timeout(reader).await;
consume_client_data_with_timeout_and_cap(
reader,
config.censorship.mask_relay_max_bytes,
)
.await;
wait_mask_outcome_budget(outcome_started, config).await;
}
}
@ -449,6 +794,7 @@ async fn relay_to_mask<R, W, MR, MW>(
shape_above_cap_blur: bool,
shape_above_cap_blur_max_bytes: usize,
shape_hardening_aggressive_mode: bool,
mask_relay_max_bytes: usize,
) where
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
@ -464,8 +810,18 @@ async fn relay_to_mask<R, W, MR, MW>(
}
let (upstream_copy, downstream_copy) = tokio::join!(
async { copy_with_idle_timeout(&mut reader, &mut mask_write).await },
async { copy_with_idle_timeout(&mut mask_read, &mut writer).await }
async {
copy_with_idle_timeout(
&mut reader,
&mut mask_write,
mask_relay_max_bytes,
!shape_hardening_enabled,
)
.await
},
async {
copy_with_idle_timeout(&mut mask_read, &mut writer, mask_relay_max_bytes, true).await
}
);
let total_sent = initial_data.len().saturating_add(upstream_copy.total);
@ -491,13 +847,36 @@ async fn relay_to_mask<R, W, MR, MW>(
let _ = writer.shutdown().await;
}
/// Just consume all data from client without responding
async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R) {
let mut buf = vec![0u8; MASK_BUFFER_SIZE];
while let Ok(n) = reader.read(&mut buf).await {
/// Just consume all data from client without responding.
async fn consume_client_data<R: AsyncRead + Unpin>(mut reader: R, byte_cap: usize) {
if byte_cap == 0 {
return;
}
// Keep drain path fail-closed under slow-loris stalls.
let mut buf = Box::new([0u8; MASK_BUFFER_SIZE]);
let mut total = 0usize;
loop {
let remaining_budget = byte_cap.saturating_sub(total);
if remaining_budget == 0 {
break;
}
let read_len = remaining_budget.min(MASK_BUFFER_SIZE);
let n = match timeout(MASK_RELAY_IDLE_TIMEOUT, reader.read(&mut buf[..read_len])).await {
Ok(Ok(n)) => n,
Ok(Err(_)) | Err(_) => break,
};
if n == 0 {
break;
}
total = total.saturating_add(n);
if total >= byte_cap {
break;
}
}
}
@ -521,6 +900,10 @@ mod masking_shape_above_cap_blur_security_tests;
#[path = "tests/masking_timing_normalization_security_tests.rs"]
mod masking_timing_normalization_security_tests;
#[cfg(test)]
#[path = "tests/masking_timing_budget_coupling_security_tests.rs"]
mod masking_timing_budget_coupling_security_tests;
#[cfg(test)]
#[path = "tests/masking_ab_envelope_blur_integration_security_tests.rs"]
mod masking_ab_envelope_blur_integration_security_tests;
@ -548,3 +931,75 @@ mod masking_aggressive_mode_security_tests;
#[cfg(test)]
#[path = "tests/masking_timing_sidechannel_redteam_expected_fail_tests.rs"]
mod masking_timing_sidechannel_redteam_expected_fail_tests;
#[cfg(test)]
#[path = "tests/masking_self_target_loop_security_tests.rs"]
mod masking_self_target_loop_security_tests;
#[cfg(test)]
#[path = "tests/masking_classification_completeness_security_tests.rs"]
mod masking_classification_completeness_security_tests;
#[cfg(test)]
#[path = "tests/masking_relay_guardrails_security_tests.rs"]
mod masking_relay_guardrails_security_tests;
#[cfg(test)]
#[path = "tests/masking_connect_failure_close_matrix_security_tests.rs"]
mod masking_connect_failure_close_matrix_security_tests;
#[cfg(test)]
#[path = "tests/masking_additional_hardening_security_tests.rs"]
mod masking_additional_hardening_security_tests;
#[cfg(test)]
#[path = "tests/masking_consume_idle_timeout_security_tests.rs"]
mod masking_consume_idle_timeout_security_tests;
#[cfg(test)]
#[path = "tests/masking_http2_probe_classification_security_tests.rs"]
mod masking_http2_probe_classification_security_tests;
#[cfg(test)]
#[path = "tests/masking_http_probe_boundary_security_tests.rs"]
mod masking_http_probe_boundary_security_tests;
#[cfg(test)]
#[path = "tests/masking_rng_hoist_perf_regression_tests.rs"]
mod masking_rng_hoist_perf_regression_tests;
#[cfg(test)]
#[path = "tests/masking_http2_preface_integration_security_tests.rs"]
mod masking_http2_preface_integration_security_tests;
#[cfg(test)]
#[path = "tests/masking_consume_stress_adversarial_tests.rs"]
mod masking_consume_stress_adversarial_tests;
#[cfg(test)]
#[path = "tests/masking_interface_cache_security_tests.rs"]
mod masking_interface_cache_security_tests;
#[cfg(test)]
#[path = "tests/masking_interface_cache_defense_in_depth_security_tests.rs"]
mod masking_interface_cache_defense_in_depth_security_tests;
#[cfg(test)]
#[path = "tests/masking_interface_cache_concurrency_security_tests.rs"]
mod masking_interface_cache_concurrency_security_tests;
#[cfg(test)]
#[path = "tests/masking_production_cap_regression_security_tests.rs"]
mod masking_production_cap_regression_security_tests;
#[cfg(test)]
#[path = "tests/masking_extended_attack_surface_security_tests.rs"]
mod masking_extended_attack_surface_security_tests;
#[cfg(test)]
#[path = "tests/masking_padding_timeout_adversarial_tests.rs"]
mod masking_padding_timeout_adversarial_tests;
#[cfg(all(test, feature = "redteam_offline_expected_fail"))]
#[path = "tests/masking_offline_target_redteam_expected_fail_tests.rs"]
mod masking_offline_target_redteam_expected_fail_tests;

View File

@ -1,5 +1,7 @@
use std::collections::hash_map::RandomState;
use std::collections::{BTreeSet, HashMap};
#[cfg(test)]
use std::future::Future;
use std::hash::{BuildHasher, Hash};
use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
@ -8,7 +10,7 @@ use std::time::{Duration, Instant};
use dashmap::DashMap;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::{Mutex as AsyncMutex, mpsc, oneshot, watch};
use tokio::sync::{mpsc, oneshot, watch};
use tokio::time::timeout;
use tracing::{debug, info, trace, warn};
@ -21,7 +23,9 @@ use crate::proxy::route_mode::{
ROUTE_SWITCH_ERROR_MSG, RelayRouteMode, RouteCutoverState, affected_cutover_state,
cutover_stagger_delay,
};
use crate::stats::{MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, Stats};
use crate::stats::{
MeD2cFlushReason, MeD2cQuotaRejectStage, MeD2cWriteMode, QuotaReserveError, Stats, UserStats,
};
use crate::stream::{BufferPool, CryptoReader, CryptoWriter, PooledBuffer};
use crate::transport::middle_proxy::{MePool, MeResponse, proto_flags_for_tag};
@ -39,28 +43,23 @@ const C2ME_CHANNEL_CAPACITY_FALLBACK: usize = 128;
const C2ME_SOFT_PRESSURE_MIN_FREE_SLOTS: usize = 64;
const C2ME_SENDER_FAIRNESS_BUDGET: usize = 32;
const RELAY_IDLE_IO_POLL_MAX: Duration = Duration::from_secs(1);
const TINY_FRAME_DEBT_PER_TINY: u32 = 8;
const TINY_FRAME_DEBT_LIMIT: u32 = 512;
#[cfg(test)]
const C2ME_SEND_TIMEOUT: Duration = Duration::from_millis(50);
#[cfg(not(test))]
const C2ME_SEND_TIMEOUT: Duration = Duration::from_secs(5);
#[cfg(test)]
const RELAY_TEST_STEP_TIMEOUT: Duration = Duration::from_secs(1);
const ME_D2C_FLUSH_BATCH_MAX_FRAMES_MIN: usize = 1;
const ME_D2C_FLUSH_BATCH_MAX_BYTES_MIN: usize = 4096;
const ME_D2C_FRAME_BUF_SHRINK_HYSTERESIS_FACTOR: usize = 2;
const ME_D2C_SINGLE_WRITE_COALESCE_MAX_BYTES: usize = 128 * 1024;
#[cfg(test)]
const QUOTA_USER_LOCKS_MAX: usize = 64;
#[cfg(not(test))]
const QUOTA_USER_LOCKS_MAX: usize = 4_096;
#[cfg(test)]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
#[cfg(not(test))]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
const QUOTA_RESERVE_SPIN_RETRIES: usize = 32;
static DESYNC_DEDUP: OnceLock<DashMap<u64, Instant>> = OnceLock::new();
static DESYNC_HASHER: OnceLock<RandomState> = OnceLock::new();
static DESYNC_FULL_CACHE_LAST_EMIT_AT: OnceLock<Mutex<Option<Instant>>> = OnceLock::new();
static DESYNC_DEDUP_EVER_SATURATED: OnceLock<AtomicBool> = OnceLock::new();
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<AsyncMutex<()>>>> = OnceLock::new();
static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<AsyncMutex<()>>>> = OnceLock::new();
static RELAY_IDLE_CANDIDATE_REGISTRY: OnceLock<Mutex<RelayIdleCandidateRegistry>> = OnceLock::new();
static RELAY_IDLE_MARK_SEQ: AtomicU64 = AtomicU64::new(0);
@ -94,10 +93,24 @@ fn relay_idle_candidate_registry() -> &'static Mutex<RelayIdleCandidateRegistry>
RELAY_IDLE_CANDIDATE_REGISTRY.get_or_init(|| Mutex::new(RelayIdleCandidateRegistry::default()))
}
fn relay_idle_candidate_registry_lock() -> std::sync::MutexGuard<'static, RelayIdleCandidateRegistry>
{
let registry = relay_idle_candidate_registry();
match registry.lock() {
Ok(guard) => guard,
Err(poisoned) => {
let mut guard = poisoned.into_inner();
// Fail closed after panic while holding registry lock: drop all
// candidates and pressure cursors to avoid stale cross-session state.
*guard = RelayIdleCandidateRegistry::default();
registry.clear_poison();
guard
}
}
}
fn mark_relay_idle_candidate(conn_id: u64) -> bool {
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
return false;
};
let mut guard = relay_idle_candidate_registry_lock();
if guard.by_conn_id.contains_key(&conn_id) {
return false;
@ -116,9 +129,7 @@ fn mark_relay_idle_candidate(conn_id: u64) -> bool {
}
fn clear_relay_idle_candidate(conn_id: u64) {
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
return;
};
let mut guard = relay_idle_candidate_registry_lock();
if let Some(meta) = guard.by_conn_id.remove(&conn_id) {
guard.ordered.remove(&(meta.mark_order_seq, conn_id));
@ -127,23 +138,17 @@ fn clear_relay_idle_candidate(conn_id: u64) {
#[cfg(test)]
fn oldest_relay_idle_candidate() -> Option<u64> {
let Ok(guard) = relay_idle_candidate_registry().lock() else {
return None;
};
let guard = relay_idle_candidate_registry_lock();
guard.ordered.iter().next().map(|(_, conn_id)| *conn_id)
}
fn note_relay_pressure_event() {
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
return;
};
let mut guard = relay_idle_candidate_registry_lock();
guard.pressure_event_seq = guard.pressure_event_seq.wrapping_add(1);
}
fn relay_pressure_event_seq() -> u64 {
let Ok(guard) = relay_idle_candidate_registry().lock() else {
return 0;
};
let guard = relay_idle_candidate_registry_lock();
guard.pressure_event_seq
}
@ -152,9 +157,7 @@ fn maybe_evict_idle_candidate_on_pressure(
seen_pressure_seq: &mut u64,
stats: &Stats,
) -> bool {
let Ok(mut guard) = relay_idle_candidate_registry().lock() else {
return false;
};
let mut guard = relay_idle_candidate_registry_lock();
let latest_pressure_seq = guard.pressure_event_seq;
if latest_pressure_seq == *seen_pressure_seq {
@ -199,13 +202,9 @@ fn maybe_evict_idle_candidate_on_pressure(
#[cfg(test)]
fn clear_relay_idle_pressure_state_for_testing() {
if let Some(registry) = RELAY_IDLE_CANDIDATE_REGISTRY.get()
&& let Ok(mut guard) = registry.lock()
{
guard.by_conn_id.clear();
guard.ordered.clear();
guard.pressure_event_seq = 0;
guard.pressure_consumed_seq = 0;
if RELAY_IDLE_CANDIDATE_REGISTRY.get().is_some() {
let mut guard = relay_idle_candidate_registry_lock();
*guard = RelayIdleCandidateRegistry::default();
}
RELAY_IDLE_MARK_SEQ.store(0, Ordering::Relaxed);
}
@ -259,6 +258,7 @@ impl RelayClientIdlePolicy {
struct RelayClientIdleState {
last_client_frame_at: Instant,
soft_idle_marked: bool,
tiny_frame_debt: u32,
}
impl RelayClientIdleState {
@ -266,6 +266,7 @@ impl RelayClientIdleState {
Self {
last_client_frame_at: now,
soft_idle_marked: false,
tiny_frame_debt: 0,
}
}
@ -531,48 +532,28 @@ fn should_yield_c2me_sender(sent_since_yield: usize, has_backlog: bool) -> bool
has_backlog && sent_since_yield >= C2ME_SENDER_FAIRNESS_BUDGET
}
fn quota_exceeded_for_user(stats: &Stats, user: &str, quota_limit: Option<u64>) -> bool {
quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota)
}
#[cfg_attr(not(test), allow(dead_code))]
fn quota_would_be_exceeded_for_user(
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
bytes: u64,
) -> bool {
quota_limit.is_some_and(|quota| {
let used = stats.get_user_total_octets(user);
used >= quota || bytes > quota.saturating_sub(used)
})
}
fn quota_soft_cap(limit: u64, overshoot: u64) -> u64 {
limit.saturating_add(overshoot)
}
fn quota_exceeded_for_user_soft(
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
overshoot: u64,
) -> bool {
quota_limit.is_some_and(|quota| stats.get_user_total_octets(user) >= quota_soft_cap(quota, overshoot))
}
fn quota_would_be_exceeded_for_user_soft(
stats: &Stats,
user: &str,
quota_limit: Option<u64>,
async fn reserve_user_quota_with_yield(
user_stats: &UserStats,
bytes: u64,
overshoot: u64,
) -> bool {
quota_limit.is_some_and(|quota| {
let cap = quota_soft_cap(quota, overshoot);
let used = stats.get_user_total_octets(user);
used >= cap || bytes > cap.saturating_sub(used)
})
limit: u64,
) -> std::result::Result<u64, QuotaReserveError> {
loop {
for _ in 0..QUOTA_RESERVE_SPIN_RETRIES {
match user_stats.quota_try_reserve(bytes, limit) {
Ok(total) => return Ok(total),
Err(QuotaReserveError::LimitExceeded) => {
return Err(QuotaReserveError::LimitExceeded);
}
Err(QuotaReserveError::Contended) => std::hint::spin_loop(),
}
}
tokio::task::yield_now().await;
}
}
fn classify_me_d2c_flush_reason(
@ -619,53 +600,18 @@ fn observe_me_d2c_flush_event(
}
#[cfg(test)]
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
fn relay_idle_pressure_test_guard() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[cfg(test)]
fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
quota_user_lock_test_guard()
pub(crate) fn relay_idle_pressure_test_scope() -> std::sync::MutexGuard<'static, ()> {
relay_idle_pressure_test_guard()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn quota_overflow_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
(0..QUOTA_OVERFLOW_LOCK_STRIPES)
.map(|_| Arc::new(AsyncMutex::new(())))
.collect()
});
let hash = crc32fast::hash(user.as_bytes()) as usize;
Arc::clone(&stripes[hash % stripes.len()])
}
fn quota_user_lock(user: &str) -> Arc<AsyncMutex<()>> {
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
if let Some(existing) = locks.get(user) {
return Arc::clone(existing.value());
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
locks.retain(|_, value| Arc::strong_count(value) > 1);
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
return quota_overflow_user_lock(user);
}
let created = Arc::new(AsyncMutex::new(()));
match locks.entry(user.to_string()) {
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
dashmap::mapref::entry::Entry::Vacant(entry) => {
entry.insert(Arc::clone(&created));
created
}
}
}
async fn enqueue_c2me_command(
tx: &mpsc::Sender<C2MeCommand>,
cmd: C2MeCommand,
@ -691,6 +637,16 @@ async fn enqueue_c2me_command(
}
}
#[cfg(test)]
async fn run_relay_test_step_timeout<F, T>(context: &'static str, fut: F) -> T
where
F: Future<Output = T>,
{
timeout(RELAY_TEST_STEP_TIMEOUT, fut)
.await
.unwrap_or_else(|_| panic!("{context} exceeded {}s", RELAY_TEST_STEP_TIMEOUT.as_secs()))
}
pub(crate) async fn handle_via_middle_proxy<R, W>(
mut crypto_reader: CryptoReader<R>,
crypto_writer: CryptoWriter<W>,
@ -711,6 +667,7 @@ where
{
let user = success.user.clone();
let quota_limit = config.access.user_data_quota.get(&user).copied();
let quota_user_stats = quota_limit.map(|_| stats.get_or_create_user_stats_handle(&user));
let peer = success.peer;
let proto_tag = success.proto_tag;
let pool_generation = me_pool.current_generation();
@ -837,6 +794,7 @@ where
let stats_clone = stats.clone();
let rng_clone = rng.clone();
let user_clone = user.clone();
let quota_user_stats_me_writer = quota_user_stats.clone();
let last_downstream_activity_ms_clone = last_downstream_activity_ms.clone();
let bytes_me2c_clone = bytes_me2c.clone();
let d2c_flush_policy = MeD2cFlushPolicy::from_config(&config);
@ -866,6 +824,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
bytes_me2c_clone.as_ref(),
@ -924,6 +883,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
bytes_me2c_clone.as_ref(),
@ -985,6 +945,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
bytes_me2c_clone.as_ref(),
@ -1048,6 +1009,7 @@ where
&mut frame_buf,
stats_clone.as_ref(),
&user_clone,
quota_user_stats_me_writer.as_deref(),
quota_limit,
d2c_flush_policy.quota_soft_overshoot_bytes,
bytes_me2c_clone.as_ref(),
@ -1219,16 +1181,23 @@ where
forensics.bytes_c2me = forensics
.bytes_c2me
.saturating_add(payload.len() as u64);
if let Some(limit) = quota_limit {
let quota_lock = quota_user_lock(&user);
let _quota_guard = quota_lock.lock().await;
stats.add_user_octets_from(&user, payload.len() as u64);
if quota_exceeded_for_user(stats.as_ref(), &user, Some(limit)) {
if let (Some(limit), Some(user_stats)) =
(quota_limit, quota_user_stats.as_deref())
{
if reserve_user_quota_with_yield(
user_stats,
payload.len() as u64,
limit,
)
.await
.is_err()
{
main_result = Err(ProxyError::DataQuotaExceeded {
user: user.clone(),
});
break;
}
stats.add_user_octets_from_handle(user_stats, payload.len() as u64);
} else {
stats.add_user_octets_from(&user, payload.len() as u64);
}
@ -1321,6 +1290,8 @@ async fn read_client_payload_with_idle_policy<R>(
where
R: AsyncRead + Unpin + Send + 'static,
{
const LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES: u32 = 4;
async fn read_exact_with_policy<R>(
client_reader: &mut CryptoReader<R>,
buf: &mut [u8],
@ -1459,6 +1430,7 @@ where
Ok(())
}
let mut consecutive_zero_len_frames = 0u32;
loop {
let (len, quickack, raw_len_bytes) = match proto_tag {
ProtoTag::Abridged => {
@ -1539,6 +1511,26 @@ where
};
if len == 0 {
idle_state.tiny_frame_debt = idle_state
.tiny_frame_debt
.saturating_add(TINY_FRAME_DEBT_PER_TINY);
if idle_state.tiny_frame_debt >= TINY_FRAME_DEBT_LIMIT {
stats.increment_relay_protocol_desync_close_total();
return Err(ProxyError::Proxy(format!(
"Tiny frame overhead limit exceeded: debt={}, conn_id={}",
idle_state.tiny_frame_debt, forensics.conn_id
)));
}
if !idle_policy.enabled {
consecutive_zero_len_frames = consecutive_zero_len_frames.saturating_add(1);
if consecutive_zero_len_frames > LEGACY_MAX_CONSECUTIVE_ZERO_LEN_FRAMES {
stats.increment_relay_protocol_desync_close_total();
return Err(ProxyError::Proxy(
"Excessive zero-length abridged frames".to_string(),
));
}
}
continue;
}
if len < 4 && proto_tag != ProtoTag::Abridged {
@ -1607,6 +1599,7 @@ where
}
*frame_counter += 1;
idle_state.on_client_frame(Instant::now());
idle_state.tiny_frame_debt = idle_state.tiny_frame_debt.saturating_sub(1);
clear_relay_idle_candidate(forensics.conn_id);
return Ok(Some((payload, quickack)));
}
@ -1690,6 +1683,7 @@ async fn process_me_writer_response<W>(
frame_buf: &mut Vec<u8>,
stats: &Stats,
user: &str,
quota_user_stats: Option<&UserStats>,
quota_limit: Option<u64>,
quota_soft_overshoot_bytes: u64,
bytes_me2c: &AtomicU64,
@ -1708,40 +1702,42 @@ where
trace!(conn_id, bytes = data.len(), flags, "ME->C data");
}
let data_len = data.len() as u64;
if quota_would_be_exceeded_for_user_soft(
stats,
user,
quota_limit,
data_len,
quota_soft_overshoot_bytes,
) {
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
if let (Some(limit), Some(user_stats)) = (quota_limit, quota_user_stats) {
let soft_limit = quota_soft_cap(limit, quota_soft_overshoot_bytes);
if reserve_user_quota_with_yield(user_stats, data_len, soft_limit)
.await
.is_err()
{
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PreWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
}
}
let write_mode =
write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await?;
stats.increment_me_d2c_write_mode(write_mode);
match write_client_payload(client_writer, proto_tag, flags, &data, rng, frame_buf)
.await
{
Ok(mode) => mode,
Err(err) => {
if quota_limit.is_some() {
stats.add_quota_write_fail_bytes_total(data_len);
stats.increment_quota_write_fail_events_total();
}
return Err(err);
}
};
bytes_me2c.fetch_add(data.len() as u64, Ordering::Relaxed);
stats.add_user_octets_to(user, data.len() as u64);
stats.increment_me_d2c_data_frames_total();
stats.add_me_d2c_payload_bytes_total(data.len() as u64);
if quota_exceeded_for_user_soft(
stats,
user,
quota_limit,
quota_soft_overshoot_bytes,
) {
stats.increment_me_d2c_quota_reject_total(MeD2cQuotaRejectStage::PostWrite);
return Err(ProxyError::DataQuotaExceeded {
user: user.to_string(),
});
bytes_me2c.fetch_add(data_len, Ordering::Relaxed);
if let Some(user_stats) = quota_user_stats {
stats.add_user_octets_to_handle(user_stats, data_len);
} else {
stats.add_user_octets_to(user, data_len);
}
stats.increment_me_d2c_data_frames_total();
stats.add_me_d2c_payload_bytes_total(data_len);
stats.increment_me_d2c_write_mode(write_mode);
Ok(MeWriterResponseOutcome::Continue {
frames: 1,
@ -1841,8 +1837,14 @@ where
MeD2cWriteMode::Coalesced
} else {
let header = [first];
client_writer.write_all(&header).await.map_err(ProxyError::Io)?;
client_writer.write_all(data).await.map_err(ProxyError::Io)?;
client_writer
.write_all(&header)
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
MeD2cWriteMode::Split
}
} else if len_words < (1 << 24) {
@ -1864,8 +1866,14 @@ where
MeD2cWriteMode::Coalesced
} else {
let header = [first, lw[0], lw[1], lw[2]];
client_writer.write_all(&header).await.map_err(ProxyError::Io)?;
client_writer.write_all(data).await.map_err(ProxyError::Io)?;
client_writer
.write_all(&header)
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
MeD2cWriteMode::Split
}
} else {
@ -1907,8 +1915,14 @@ where
MeD2cWriteMode::Coalesced
} else {
let header = len_val.to_le_bytes();
client_writer.write_all(&header).await.map_err(ProxyError::Io)?;
client_writer.write_all(data).await.map_err(ProxyError::Io)?;
client_writer
.write_all(&header)
.await
.map_err(ProxyError::Io)?;
client_writer
.write_all(data)
.await
.map_err(ProxyError::Io)?;
if padding_len > 0 {
frame_buf.clear();
if frame_buf.capacity() < padding_len {
@ -1948,10 +1962,6 @@ where
.map_err(ProxyError::Io)
}
#[cfg(test)]
#[path = "tests/middle_relay_security_tests.rs"]
mod security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_idle_policy_security_tests.rs"]
mod idle_policy_security_tests;
@ -1964,18 +1974,30 @@ mod desync_all_full_dedup_security_tests;
#[path = "tests/middle_relay_stub_completion_security_tests.rs"]
mod stub_completion_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_coverage_high_risk_security_tests.rs"]
mod coverage_high_risk_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_quota_overflow_lock_security_tests.rs"]
mod quota_overflow_lock_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_length_cast_hardening_security_tests.rs"]
mod length_cast_hardening_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_blackhat_campaign_integration_tests.rs"]
mod blackhat_campaign_integration_tests;
#[path = "tests/middle_relay_idle_registry_poison_security_tests.rs"]
mod middle_relay_idle_registry_poison_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_zero_length_frame_security_tests.rs"]
mod middle_relay_zero_length_frame_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_tiny_frame_debt_security_tests.rs"]
mod middle_relay_tiny_frame_debt_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_tiny_frame_debt_concurrency_security_tests.rs"]
mod middle_relay_tiny_frame_debt_concurrency_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_tiny_frame_debt_proto_chunking_security_tests.rs"]
mod middle_relay_tiny_frame_debt_proto_chunking_security_tests;
#[cfg(test)]
#[path = "tests/middle_relay_atomic_quota_invariant_tests.rs"]
mod middle_relay_atomic_quota_invariant_tests;

View File

@ -4,58 +4,58 @@
#![cfg_attr(test, allow(warnings))]
#![cfg_attr(not(test), forbid(clippy::undocumented_unsafe_blocks))]
#![cfg_attr(
not(test),
deny(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::todo,
clippy::unimplemented,
clippy::correctness,
clippy::option_if_let_else,
clippy::or_fun_call,
clippy::branches_sharing_code,
clippy::single_option_map,
clippy::useless_let_if_seq,
clippy::redundant_locals,
clippy::cloned_ref_to_slice_refs,
unsafe_code,
clippy::await_holding_lock,
clippy::await_holding_refcell_ref,
clippy::debug_assert_with_mut_call,
clippy::macro_use_imports,
clippy::cast_ptr_alignment,
clippy::cast_lossless,
clippy::ptr_as_ptr,
clippy::large_stack_arrays,
clippy::same_functions_in_if_condition,
trivial_casts,
trivial_numeric_casts,
unused_extern_crates,
unused_import_braces,
rust_2018_idioms
)
not(test),
deny(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::todo,
clippy::unimplemented,
clippy::correctness,
clippy::option_if_let_else,
clippy::or_fun_call,
clippy::branches_sharing_code,
clippy::single_option_map,
clippy::useless_let_if_seq,
clippy::redundant_locals,
clippy::cloned_ref_to_slice_refs,
unsafe_code,
clippy::await_holding_lock,
clippy::await_holding_refcell_ref,
clippy::debug_assert_with_mut_call,
clippy::macro_use_imports,
clippy::cast_ptr_alignment,
clippy::cast_lossless,
clippy::ptr_as_ptr,
clippy::large_stack_arrays,
clippy::same_functions_in_if_condition,
trivial_casts,
trivial_numeric_casts,
unused_extern_crates,
unused_import_braces,
rust_2018_idioms
)
)]
#![cfg_attr(
not(test),
allow(
clippy::use_self,
clippy::redundant_closure,
clippy::too_many_arguments,
clippy::doc_markdown,
clippy::missing_const_for_fn,
clippy::unnecessary_operation,
clippy::redundant_pub_crate,
clippy::derive_partial_eq_without_eq,
clippy::type_complexity,
clippy::new_ret_no_self,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::significant_drop_tightening,
clippy::significant_drop_in_scrutinee,
clippy::float_cmp,
clippy::nursery
)
not(test),
allow(
clippy::use_self,
clippy::redundant_closure,
clippy::too_many_arguments,
clippy::doc_markdown,
clippy::missing_const_for_fn,
clippy::unnecessary_operation,
clippy::redundant_pub_crate,
clippy::derive_partial_eq_without_eq,
clippy::type_complexity,
clippy::new_ret_no_self,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::significant_drop_tightening,
clippy::significant_drop_in_scrutinee,
clippy::float_cmp,
clippy::nursery
)
)]
pub mod adaptive_buffers;

View File

@ -52,13 +52,12 @@
//! - `SharedCounters` (atomics) let the watchdog read stats without locking
use crate::error::{ProxyError, Result};
use crate::stats::Stats;
use crate::stats::{Stats, UserStats};
use crate::stream::BufferPool;
use dashmap::DashMap;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf, copy_bidirectional_with_sizes};
@ -209,12 +208,10 @@ struct StatsIo<S> {
counters: Arc<SharedCounters>,
stats: Arc<Stats>,
user: String,
user_stats: Arc<UserStats>,
quota_limit: Option<u64>,
quota_exceeded: Arc<AtomicBool>,
quota_read_wake_scheduled: bool,
quota_write_wake_scheduled: bool,
quota_read_retry_active: Arc<AtomicBool>,
quota_write_retry_active: Arc<AtomicBool>,
quota_bytes_since_check: u64,
epoch: Instant,
}
@ -230,30 +227,21 @@ impl<S> StatsIo<S> {
) -> Self {
// Mark initial activity so the watchdog doesn't fire before data flows
counters.touch(Instant::now(), epoch);
let user_stats = stats.get_or_create_user_stats_handle(&user);
Self {
inner,
counters,
stats,
user,
user_stats,
quota_limit,
quota_exceeded,
quota_read_wake_scheduled: false,
quota_write_wake_scheduled: false,
quota_read_retry_active: Arc::new(AtomicBool::new(false)),
quota_write_retry_active: Arc::new(AtomicBool::new(false)),
quota_bytes_since_check: 0,
epoch,
}
}
}
impl<S> Drop for StatsIo<S> {
fn drop(&mut self) {
self.quota_read_retry_active.store(false, Ordering::Relaxed);
self.quota_write_retry_active
.store(false, Ordering::Relaxed);
}
}
#[derive(Debug)]
struct QuotaIoSentinel;
@ -277,84 +265,22 @@ fn is_quota_io_error(err: &io::Error) -> bool {
.is_some()
}
#[cfg(test)]
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(1);
#[cfg(not(test))]
const QUOTA_CONTENTION_RETRY_INTERVAL: Duration = Duration::from_millis(2);
const QUOTA_NEAR_LIMIT_BYTES: u64 = 64 * 1024;
const QUOTA_LARGE_CHARGE_BYTES: u64 = 16 * 1024;
const QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES: u64 = 4 * 1024;
const QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES: u64 = 64 * 1024;
fn spawn_quota_retry_waker(retry_active: Arc<AtomicBool>, waker: std::task::Waker) {
tokio::task::spawn(async move {
loop {
if !retry_active.load(Ordering::Relaxed) {
break;
}
tokio::time::sleep(QUOTA_CONTENTION_RETRY_INTERVAL).await;
if !retry_active.load(Ordering::Relaxed) {
break;
}
waker.wake_by_ref();
}
});
#[inline]
fn quota_adaptive_interval_bytes(remaining_before: u64) -> u64 {
remaining_before.saturating_div(2).clamp(
QUOTA_ADAPTIVE_INTERVAL_MIN_BYTES,
QUOTA_ADAPTIVE_INTERVAL_MAX_BYTES,
)
}
static QUOTA_USER_LOCKS: OnceLock<DashMap<String, Arc<Mutex<()>>>> = OnceLock::new();
static QUOTA_USER_OVERFLOW_LOCKS: OnceLock<Vec<Arc<Mutex<()>>>> = OnceLock::new();
#[cfg(test)]
const QUOTA_USER_LOCKS_MAX: usize = 64;
#[cfg(not(test))]
const QUOTA_USER_LOCKS_MAX: usize = 4_096;
#[cfg(test)]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 16;
#[cfg(not(test))]
const QUOTA_OVERFLOW_LOCK_STRIPES: usize = 256;
#[cfg(test)]
fn quota_user_lock_test_guard() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
#[cfg(test)]
fn quota_user_lock_test_scope() -> std::sync::MutexGuard<'static, ()> {
quota_user_lock_test_guard()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn quota_overflow_user_lock(user: &str) -> Arc<Mutex<()>> {
let stripes = QUOTA_USER_OVERFLOW_LOCKS.get_or_init(|| {
(0..QUOTA_OVERFLOW_LOCK_STRIPES)
.map(|_| Arc::new(Mutex::new(())))
.collect()
});
let hash = crc32fast::hash(user.as_bytes()) as usize;
Arc::clone(&stripes[hash % stripes.len()])
}
fn quota_user_lock(user: &str) -> Arc<Mutex<()>> {
let locks = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
if let Some(existing) = locks.get(user) {
return Arc::clone(existing.value());
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
locks.retain(|_, value| Arc::strong_count(value) > 1);
}
if locks.len() >= QUOTA_USER_LOCKS_MAX {
return quota_overflow_user_lock(user);
}
let created = Arc::new(Mutex::new(()));
match locks.entry(user.to_string()) {
dashmap::mapref::entry::Entry::Occupied(entry) => Arc::clone(entry.get()),
dashmap::mapref::entry::Entry::Vacant(entry) => {
entry.insert(Arc::clone(&created));
created
}
}
#[inline]
fn should_immediate_quota_check(remaining_before: u64, charge_bytes: u64) -> bool {
remaining_before <= QUOTA_NEAR_LIMIT_BYTES || charge_bytes >= QUOTA_LARGE_CHARGE_BYTES
}
impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
@ -364,80 +290,60 @@ impl<S: AsyncRead + Unpin> AsyncRead for StatsIo<S> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
if this.quota_exceeded.load(Ordering::Relaxed) {
if this.quota_exceeded.load(Ordering::Acquire) {
return Poll::Ready(Err(quota_io_error()));
}
let quota_lock = this
.quota_limit
.is_some()
.then(|| quota_user_lock(&this.user));
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => {
this.quota_read_wake_scheduled = false;
this.quota_read_retry_active.store(false, Ordering::Relaxed);
Some(guard)
}
Err(_) => {
if !this.quota_read_wake_scheduled {
this.quota_read_wake_scheduled = true;
this.quota_read_retry_active.store(true, Ordering::Relaxed);
spawn_quota_retry_waker(
Arc::clone(&this.quota_read_retry_active),
cx.waker().clone(),
);
}
return Poll::Pending;
}
let mut remaining_before = None;
if let Some(limit) = this.quota_limit {
let used_before = this.user_stats.quota_used();
let remaining = limit.saturating_sub(used_before);
if remaining == 0 {
this.quota_exceeded.store(true, Ordering::Release);
return Poll::Ready(Err(quota_io_error()));
}
} else {
None
};
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
remaining_before = Some(remaining);
}
let before = buf.filled().len();
match Pin::new(&mut this.inner).poll_read(cx, buf) {
Poll::Ready(Ok(())) => {
let n = buf.filled().len() - before;
if n > 0 {
let mut reached_quota_boundary = false;
if let Some(limit) = this.quota_limit {
let used = this.stats.get_user_total_octets(&this.user);
if used >= limit {
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
let remaining = limit - used;
if (n as u64) > remaining {
// Fail closed: when a single read chunk would cross quota,
// stop relay immediately without accounting beyond the cap.
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
reached_quota_boundary = (n as u64) == remaining;
}
let n_to_charge = n as u64;
// C→S: client sent data
this.counters
.c2s_bytes
.fetch_add(n as u64, Ordering::Relaxed);
.fetch_add(n_to_charge, Ordering::Relaxed);
this.counters.c2s_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch);
this.stats.add_user_octets_from(&this.user, n as u64);
this.stats.increment_user_msgs_from(&this.user);
this.stats
.add_user_octets_from_handle(this.user_stats.as_ref(), n_to_charge);
this.stats
.increment_user_msgs_from_handle(this.user_stats.as_ref());
if reached_quota_boundary {
this.quota_exceeded.store(true, Ordering::Relaxed);
if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) {
this.stats
.quota_charge_post_write(this.user_stats.as_ref(), n_to_charge);
if should_immediate_quota_check(remaining, n_to_charge) {
this.quota_bytes_since_check = 0;
if this.user_stats.quota_used() >= limit {
this.quota_exceeded.store(true, Ordering::Release);
}
} else {
this.quota_bytes_since_check =
this.quota_bytes_since_check.saturating_add(n_to_charge);
let interval = quota_adaptive_interval_bytes(remaining);
if this.quota_bytes_since_check >= interval {
this.quota_bytes_since_check = 0;
if this.user_stats.quota_used() >= limit {
this.quota_exceeded.store(true, Ordering::Release);
}
}
}
}
trace!(user = %this.user, bytes = n, "C->S");
@ -456,75 +362,57 @@ impl<S: AsyncWrite + Unpin> AsyncWrite for StatsIo<S> {
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
if this.quota_exceeded.load(Ordering::Relaxed) {
if this.quota_exceeded.load(Ordering::Acquire) {
return Poll::Ready(Err(quota_io_error()));
}
let quota_lock = this
.quota_limit
.is_some()
.then(|| quota_user_lock(&this.user));
let _quota_guard = if let Some(lock) = quota_lock.as_ref() {
match lock.try_lock() {
Ok(guard) => {
this.quota_write_wake_scheduled = false;
this.quota_write_retry_active
.store(false, Ordering::Relaxed);
Some(guard)
}
Err(_) => {
if !this.quota_write_wake_scheduled {
this.quota_write_wake_scheduled = true;
this.quota_write_retry_active.store(true, Ordering::Relaxed);
spawn_quota_retry_waker(
Arc::clone(&this.quota_write_retry_active),
cx.waker().clone(),
);
}
return Poll::Pending;
}
}
} else {
None
};
let write_buf = if let Some(limit) = this.quota_limit {
let used = this.stats.get_user_total_octets(&this.user);
if used >= limit {
this.quota_exceeded.store(true, Ordering::Relaxed);
let mut remaining_before = None;
if let Some(limit) = this.quota_limit {
let used_before = this.user_stats.quota_used();
let remaining = limit.saturating_sub(used_before);
if remaining == 0 {
this.quota_exceeded.store(true, Ordering::Release);
return Poll::Ready(Err(quota_io_error()));
}
remaining_before = Some(remaining);
}
let remaining = (limit - used) as usize;
if buf.len() > remaining {
// Fail closed: do not emit partial S->C payload when remaining
// quota cannot accommodate the pending write request.
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
}
buf
} else {
buf
};
match Pin::new(&mut this.inner).poll_write(cx, write_buf) {
match Pin::new(&mut this.inner).poll_write(cx, buf) {
Poll::Ready(Ok(n)) => {
if n > 0 {
let n_to_charge = n as u64;
// S→C: data written to client
this.counters
.s2c_bytes
.fetch_add(n as u64, Ordering::Relaxed);
.fetch_add(n_to_charge, Ordering::Relaxed);
this.counters.s2c_ops.fetch_add(1, Ordering::Relaxed);
this.counters.touch(Instant::now(), this.epoch);
this.stats.add_user_octets_to(&this.user, n as u64);
this.stats.increment_user_msgs_to(&this.user);
this.stats
.add_user_octets_to_handle(this.user_stats.as_ref(), n_to_charge);
this.stats
.increment_user_msgs_to_handle(this.user_stats.as_ref());
if let Some(limit) = this.quota_limit
&& this.stats.get_user_total_octets(&this.user) >= limit
{
this.quota_exceeded.store(true, Ordering::Relaxed);
return Poll::Ready(Err(quota_io_error()));
if let (Some(limit), Some(remaining)) = (this.quota_limit, remaining_before) {
this.stats
.quota_charge_post_write(this.user_stats.as_ref(), n_to_charge);
if should_immediate_quota_check(remaining, n_to_charge) {
this.quota_bytes_since_check = 0;
if this.user_stats.quota_used() >= limit {
this.quota_exceeded.store(true, Ordering::Release);
}
} else {
this.quota_bytes_since_check =
this.quota_bytes_since_check.saturating_add(n_to_charge);
let interval = quota_adaptive_interval_bytes(remaining);
if this.quota_bytes_since_check >= interval {
this.quota_bytes_since_check = 0;
if this.user_stats.quota_used() >= limit {
this.quota_exceeded.store(true, Ordering::Release);
}
}
}
}
trace!(user = %this.user, bytes = n, "S->C");
@ -618,7 +506,7 @@ where
let now = Instant::now();
let idle = wd_counters.idle_duration(now, epoch);
if wd_quota_exceeded.load(Ordering::Relaxed) {
if wd_quota_exceeded.load(Ordering::Acquire) {
warn!(user = %wd_user, "User data quota reached, closing relay");
return;
}
@ -756,18 +644,10 @@ where
}
}
#[cfg(test)]
#[path = "tests/relay_security_tests.rs"]
mod security_tests;
#[cfg(test)]
#[path = "tests/relay_adversarial_tests.rs"]
mod adversarial_tests;
#[cfg(test)]
#[path = "tests/relay_quota_lock_pressure_adversarial_tests.rs"]
mod relay_quota_lock_pressure_adversarial_tests;
#[cfg(test)]
#[path = "tests/relay_quota_boundary_blackhat_tests.rs"]
mod relay_quota_boundary_blackhat_tests;
@ -780,14 +660,14 @@ mod relay_quota_model_adversarial_tests;
#[path = "tests/relay_quota_overflow_regression_tests.rs"]
mod relay_quota_overflow_regression_tests;
#[cfg(test)]
#[path = "tests/relay_quota_extended_attack_surface_security_tests.rs"]
mod relay_quota_extended_attack_surface_security_tests;
#[cfg(test)]
#[path = "tests/relay_watchdog_delta_security_tests.rs"]
mod relay_watchdog_delta_security_tests;
#[cfg(test)]
#[path = "tests/relay_quota_waker_storm_adversarial_tests.rs"]
mod relay_quota_waker_storm_adversarial_tests;
#[cfg(test)]
#[path = "tests/relay_quota_wake_liveness_regression_tests.rs"]
mod relay_quota_wake_liveness_regression_tests;
#[path = "tests/relay_atomic_quota_invariant_tests.rs"]
mod relay_atomic_quota_invariant_tests;

View File

@ -0,0 +1,467 @@
use super::*;
use crate::config::{ProxyConfig, UpstreamConfig, UpstreamType};
use crate::protocol::constants::{MAX_TLS_PLAINTEXT_SIZE, MIN_TLS_CLIENT_HELLO_SIZE};
use crate::stats::Stats;
use crate::transport::UpstreamManager;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf, duplex};
use tokio::net::TcpListener;
#[test]
fn edge_mask_reject_delay_min_greater_than_max_does_not_panic() {
let mut config = ProxyConfig::default();
config.censorship.server_hello_delay_min_ms = 5000;
config.censorship.server_hello_delay_max_ms = 1000;
let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async {
let start = std::time::Instant::now();
maybe_apply_mask_reject_delay(&config).await;
let elapsed = start.elapsed();
assert!(elapsed >= Duration::from_millis(1000));
assert!(elapsed < Duration::from_millis(1500));
});
}
#[test]
fn edge_handshake_timeout_with_mask_grace_saturating_add_prevents_overflow() {
let mut config = ProxyConfig::default();
config.timeouts.client_handshake = u64::MAX;
config.censorship.mask = true;
let timeout = handshake_timeout_with_mask_grace(&config);
assert_eq!(timeout.as_secs(), u64::MAX);
}
#[test]
fn edge_tls_clienthello_len_in_bounds_exact_boundaries() {
assert!(tls_clienthello_len_in_bounds(MIN_TLS_CLIENT_HELLO_SIZE));
assert!(!tls_clienthello_len_in_bounds(
MIN_TLS_CLIENT_HELLO_SIZE - 1
));
assert!(tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE));
assert!(!tls_clienthello_len_in_bounds(MAX_TLS_PLAINTEXT_SIZE + 1));
}
#[test]
fn edge_synthetic_local_addr_boundaries() {
assert_eq!(synthetic_local_addr(0).port(), 0);
assert_eq!(synthetic_local_addr(80).port(), 80);
assert_eq!(synthetic_local_addr(u16::MAX).port(), u16::MAX);
}
#[test]
fn edge_beobachten_record_handshake_failure_class_stream_error_eof() {
let beobachten = BeobachtenStore::new();
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = 1;
let eof_err = ProxyError::Stream(crate::error::StreamError::UnexpectedEof);
let peer_ip: IpAddr = "198.51.100.100".parse().unwrap();
record_handshake_failure_class(&beobachten, &config, peer_ip, &eof_err);
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[expected_64_got_0]"));
}
#[tokio::test]
async fn adversarial_tls_handshake_timeout_during_masking_delay() {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.timeouts.client_handshake = 1;
cfg.censorship.mask = true;
cfg.censorship.server_hello_delay_min_ms = 3000;
cfg.censorship.server_hello_delay_max_ms = 3000;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handle = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.1:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
client_side
.write_all(&[0x16, 0x03, 0x01, 0xFF, 0xFF])
.await
.unwrap();
let result = tokio::time::timeout(Duration::from_secs(4), handle)
.await
.unwrap()
.unwrap();
assert!(matches!(result, Err(ProxyError::TgHandshakeTimeout)));
assert_eq!(stats.get_handshake_timeouts(), 1);
}
#[tokio::test]
async fn blackhat_proxy_protocol_slowloris_timeout() {
let mut cfg = ProxyConfig::default();
cfg.server.proxy_protocol_header_timeout_ms = 200;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handle = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.2:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
true,
));
client_side.write_all(b"PROXY TCP4 192.").await.unwrap();
tokio::time::sleep(Duration::from_millis(300)).await;
let result = tokio::time::timeout(Duration::from_secs(2), handle)
.await
.unwrap()
.unwrap();
assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol)));
assert_eq!(stats.get_connects_bad(), 1);
}
#[test]
fn blackhat_ipv4_mapped_ipv6_proxy_source_bypass_attempt() {
let trusted = vec!["192.0.2.0/24".parse().unwrap()];
let peer_ip = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc000, 0x0201));
assert!(!is_trusted_proxy_source(peer_ip, &trusted));
}
#[tokio::test]
async fn negative_proxy_protocol_enabled_but_client_sends_tls_hello() {
let mut cfg = ProxyConfig::default();
cfg.server.proxy_protocol_header_timeout_ms = 500;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handle = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.3:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
true,
));
client_side
.write_all(&[0x16, 0x03, 0x01, 0x02, 0x00])
.await
.unwrap();
let result = tokio::time::timeout(Duration::from_secs(2), handle)
.await
.unwrap()
.unwrap();
assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol)));
assert_eq!(stats.get_connects_bad(), 1);
}
#[tokio::test]
async fn edge_client_stream_exactly_4_bytes_eof() {
let config = Arc::new(ProxyConfig::default());
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let handle = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.4:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
client_side
.write_all(&[0x16, 0x03, 0x01, 0x00])
.await
.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handle).await;
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[expected_64_got_0]"));
}
#[tokio::test]
async fn edge_client_stream_tls_header_valid_but_body_1_byte_short_eof() {
let config = Arc::new(ProxyConfig::default());
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handle = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.5:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
client_side
.write_all(&[0x16, 0x03, 0x01, 0x00, 100])
.await
.unwrap();
client_side.write_all(&vec![0x41; 99]).await.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handle).await;
assert_eq!(stats.get_connects_bad(), 1);
}
#[tokio::test]
async fn integration_non_tls_modes_disabled_immediately_masks() {
let mut cfg = ProxyConfig::default();
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
cfg.censorship.mask = true;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handle = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.6:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
client_side.write_all(b"GET / HTTP/1.1\r\n").await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handle).await;
assert_eq!(stats.get_connects_bad(), 1);
}
struct YieldingReader {
data: Vec<u8>,
pos: usize,
yields_left: usize,
}
impl AsyncRead for YieldingReader {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let this = self.get_mut();
if this.yields_left > 0 {
this.yields_left -= 1;
cx.waker().wake_by_ref();
return Poll::Pending;
}
if this.pos >= this.data.len() {
return Poll::Ready(Ok(()));
}
buf.put_slice(&this.data[this.pos..this.pos + 1]);
this.pos += 1;
this.yields_left = 2;
Poll::Ready(Ok(()))
}
}
#[tokio::test]
async fn fuzz_read_with_progress_heavy_yielding() {
let expected_data = b"HEAVY_YIELD_TEST_DATA".to_vec();
let mut reader = YieldingReader {
data: expected_data.clone(),
pos: 0,
yields_left: 2,
};
let mut buf = vec![0u8; expected_data.len()];
let read_bytes = read_with_progress(&mut reader, &mut buf).await.unwrap();
assert_eq!(read_bytes, expected_data.len());
assert_eq!(buf, expected_data);
}
#[test]
fn edge_wrap_tls_application_record_exactly_u16_max() {
let payload = vec![0u8; 65535];
let wrapped = wrap_tls_application_record(&payload);
assert_eq!(wrapped.len(), 65540);
assert_eq!(wrapped[0], TLS_RECORD_APPLICATION);
assert_eq!(&wrapped[3..5], &65535u16.to_be_bytes());
}
#[test]
fn fuzz_wrap_tls_application_record_lengths() {
let lengths = [0, 1, 65534, 65535, 65536, 131070, 131071, 131072];
for len in lengths {
let payload = vec![0u8; len];
let wrapped = wrap_tls_application_record(&payload);
let expected_chunks = len.div_ceil(65535).max(1);
assert_eq!(wrapped.len(), len + 5 * expected_chunks);
}
}
#[tokio::test]
async fn stress_user_connection_reservation_concurrent_same_ip_exhaustion() {
let user = "stress-same-ip-user";
let mut config = ProxyConfig::default();
config.access.user_max_tcp_conns.insert(user.to_string(), 5);
let config = Arc::new(config);
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 10).await;
let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77)), 55000);
let mut tasks = tokio::task::JoinSet::new();
let mut reservations = Vec::new();
for _ in 0..10 {
let config = config.clone();
let stats = stats.clone();
let ip_tracker = ip_tracker.clone();
tasks.spawn(async move {
RunningClientHandler::acquire_user_connection_reservation_static(
user, &config, stats, peer, ip_tracker,
)
.await
});
}
let mut successes = 0;
let mut failures = 0;
while let Some(res) = tasks.join_next().await {
match res.unwrap() {
Ok(r) => {
successes += 1;
reservations.push(r);
}
Err(_) => failures += 1,
}
}
assert_eq!(successes, 5);
assert_eq!(failures, 5);
assert_eq!(stats.get_user_curr_connects(user), 5);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 1);
for reservation in reservations {
reservation.release().await;
}
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}

View File

@ -0,0 +1,222 @@
use super::*;
use crate::config::ProxyConfig;
use crate::protocol::constants::MIN_TLS_CLIENT_HELLO_SIZE;
use crate::stats::Stats;
use crate::transport::UpstreamManager;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncWriteExt, duplex};
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
}
#[test]
fn invariant_wrap_tls_application_record_exact_multiples() {
let chunk_size = u16::MAX as usize;
let payload = vec![0xAA; chunk_size * 2];
let wrapped = wrap_tls_application_record(&payload);
assert_eq!(wrapped.len(), 2 * (5 + chunk_size));
assert_eq!(wrapped[0], TLS_RECORD_APPLICATION);
assert_eq!(&wrapped[3..5], &65535u16.to_be_bytes());
let second_header_idx = 5 + chunk_size;
assert_eq!(wrapped[second_header_idx], TLS_RECORD_APPLICATION);
assert_eq!(
&wrapped[second_header_idx + 3..second_header_idx + 5],
&65535u16.to_be_bytes()
);
}
#[tokio::test]
async fn invariant_tls_clienthello_truncation_exact_boundary_triggers_masking() {
let config = Arc::new(ProxyConfig::default());
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.20:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
let claimed_len = MIN_TLS_CLIENT_HELLO_SIZE as u16;
let mut header = vec![0x16, 0x03, 0x01];
header.extend_from_slice(&claimed_len.to_be_bytes());
client_side.write_all(&header).await.unwrap();
client_side
.write_all(&vec![0x42; MIN_TLS_CLIENT_HELLO_SIZE - 1])
.await
.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handler)
.await
.unwrap();
assert_eq!(stats.get_connects_bad(), 1);
}
#[tokio::test]
async fn invariant_acquire_reservation_ip_limit_rollback() {
let user = "rollback-test-user";
let mut config = ProxyConfig::default();
config
.access
.user_max_tcp_conns
.insert(user.to_string(), 10);
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 1).await;
let peer_a = "198.51.100.21:55000".parse().unwrap();
let _res_a = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
peer_a,
ip_tracker.clone(),
)
.await
.unwrap();
assert_eq!(stats.get_user_curr_connects(user), 1);
let peer_b = "203.0.113.22:55000".parse().unwrap();
let res_b = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
peer_b,
ip_tracker.clone(),
)
.await;
assert!(matches!(
res_b,
Err(ProxyError::ConnectionLimitExceeded { .. })
));
assert_eq!(stats.get_user_curr_connects(user), 1);
}
#[tokio::test]
async fn invariant_quota_exact_boundary_inclusive() {
let user = "quota-strict-user";
let mut config = ProxyConfig::default();
config.access.user_data_quota.insert(user.to_string(), 1000);
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
let peer = "198.51.100.23:55000".parse().unwrap();
preload_user_quota(stats.as_ref(), user, 999);
let res1 = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
peer,
ip_tracker.clone(),
)
.await;
assert!(res1.is_ok());
res1.unwrap().release().await;
preload_user_quota(stats.as_ref(), user, 1);
let res2 = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
peer,
ip_tracker.clone(),
)
.await;
assert!(matches!(res2, Err(ProxyError::DataQuotaExceeded { .. })));
}
#[tokio::test]
async fn invariant_direct_mode_partial_header_eof_is_error_not_bad_connect() {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = true;
cfg.general.beobachten_minutes = 1;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.25:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
client_side.write_all(&[0xEF, 0xEF, 0xEF]).await.unwrap();
client_side.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(2), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_err());
assert_eq!(stats.get_connects_bad(), 0);
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[expected_64_got_0]"));
}
#[tokio::test]
async fn invariant_route_mode_snapshot_picks_up_latest_mode() {
let route_runtime = Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct));
assert!(matches!(
route_runtime.snapshot().mode,
RelayRouteMode::Direct
));
route_runtime.set_mode(RelayRouteMode::Middle);
assert!(matches!(
route_runtime.snapshot().mode,
RelayRouteMode::Middle
));
}

View File

@ -0,0 +1,100 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::Duration;
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
#[tokio::test]
async fn fragmented_connect_probe_is_classified_as_http_via_prefetch_window() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
stream.read_to_end(&mut got).await.unwrap();
got
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = true;
cfg.general.beobachten_minutes = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let peer: SocketAddr = "198.51.100.251:57501".parse().unwrap();
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
new_upstream_manager(stats),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
client_side.write_all(b"CONNE").await.unwrap();
client_side
.write_all(b"CT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n")
.await
.unwrap();
client_side.shutdown().await.unwrap();
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
assert!(
forwarded.starts_with(b"CONNECT example.org:443 HTTP/1.1"),
"mask backend must receive the full fragmented CONNECT probe"
);
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[HTTP]"));
assert!(snapshot.contains("198.51.100.251-1"));
}

View File

@ -0,0 +1,122 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, sleep};
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
async fn run_http2_fragment_case(split_at: usize, delay_ms: u64, peer: SocketAddr) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
stream.read_to_end(&mut got).await.unwrap();
got
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = true;
cfg.general.beobachten_minutes = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
new_upstream_manager(stats),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
let first = split_at.min(preface.len());
client_side.write_all(&preface[..first]).await.unwrap();
if first < preface.len() {
sleep(Duration::from_millis(delay_ms)).await;
client_side.write_all(&preface[first..]).await.unwrap();
}
client_side.shutdown().await.unwrap();
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
assert!(
forwarded.starts_with(&preface),
"mask backend must receive an intact HTTP/2 preface prefix"
);
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[HTTP]"));
assert!(snapshot.contains(&format!("{}-1", peer.ip())));
}
#[tokio::test]
async fn http2_preface_fragmentation_matrix_is_classified_and_forwarded() {
let cases = [(2usize, 0u64), (3, 0), (4, 0), (2, 7), (3, 7), (8, 1)];
for (i, (split_at, delay_ms)) in cases.into_iter().enumerate() {
let peer: SocketAddr = format!("198.51.100.{}:58{}", 140 + i, 100 + i)
.parse()
.unwrap();
run_http2_fragment_case(split_at, delay_ms, peer).await;
}
}
#[tokio::test]
async fn http2_preface_splitpoint_light_fuzz_classifies_http() {
for split_at in 2usize..=12 {
let delay_ms = if split_at % 3 == 0 { 7 } else { 1 };
let peer: SocketAddr = format!("198.51.101.{}:59{}", split_at, 10 + split_at)
.parse()
.unwrap();
run_http2_fragment_case(split_at, delay_ms, peer).await;
}
}

View File

@ -0,0 +1,150 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, sleep};
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
async fn run_pipeline_prefetch_case(
prefetch_timeout_ms: u64,
delayed_tail_ms: u64,
peer: SocketAddr,
) -> (Vec<u8>, String) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = Vec::new();
stream.read_to_end(&mut got).await.unwrap();
got
});
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = true;
cfg.general.beobachten_minutes = 1;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = backend_addr.port();
cfg.censorship.mask_classifier_prefetch_timeout_ms = prefetch_timeout_ms;
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
new_upstream_manager(stats),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
client_side.write_all(b"C").await.unwrap();
sleep(Duration::from_millis(delayed_tail_ms)).await;
client_side
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n")
.await
.unwrap();
client_side.shutdown().await.unwrap();
let forwarded = tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
let result = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
assert!(result.is_ok());
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
(forwarded, snapshot)
}
#[tokio::test]
async fn tdd_pipeline_prefetch_5ms_misses_15ms_tail_and_classifies_as_port_scanner() {
let peer: SocketAddr = "198.51.100.171:58071".parse().unwrap();
let (forwarded, snapshot) = run_pipeline_prefetch_case(5, 15, peer).await;
assert!(
forwarded.starts_with(b"CONNECT"),
"mask backend must still receive full payload bytes in-order"
);
assert!(
snapshot.contains("[HTTP]") || snapshot.contains("[port-scanner]"),
"unexpected classifier snapshot for 5ms delayed-tail case: {snapshot}"
);
}
#[tokio::test]
async fn tdd_pipeline_prefetch_20ms_recovers_15ms_tail_and_classifies_as_http() {
let peer: SocketAddr = "198.51.100.172:58072".parse().unwrap();
let (forwarded, snapshot) = run_pipeline_prefetch_case(20, 15, peer).await;
assert!(
forwarded.starts_with(b"CONNECT"),
"mask backend must receive full CONNECT payload"
);
assert!(
snapshot.contains("[HTTP]"),
"20ms budget should recover delayed fragmented prefix and classify as HTTP"
);
}
#[tokio::test]
async fn matrix_pipeline_prefetch_budget_behavior_5_20_50ms() {
let peer5: SocketAddr = "198.51.100.173:58073".parse().unwrap();
let peer20: SocketAddr = "198.51.100.174:58074".parse().unwrap();
let peer50: SocketAddr = "198.51.100.175:58075".parse().unwrap();
let (_, snap5) = run_pipeline_prefetch_case(5, 35, peer5).await;
let (_, snap20) = run_pipeline_prefetch_case(20, 35, peer20).await;
let (_, snap50) = run_pipeline_prefetch_case(50, 35, peer50).await;
assert!(
snap5.contains("[HTTP]") || snap5.contains("[port-scanner]"),
"unexpected 5ms snapshot: {snap5}"
);
assert!(
snap20.contains("[HTTP]") || snap20.contains("[port-scanner]"),
"unexpected 20ms snapshot: {snap20}"
);
assert!(snap50.contains("[HTTP]"));
}

View File

@ -0,0 +1,88 @@
use super::*;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, sleep};
#[test]
fn prefetch_timeout_budget_reads_from_config() {
let mut cfg = ProxyConfig::default();
assert_eq!(
mask_classifier_prefetch_timeout(&cfg),
Duration::from_millis(5),
"default prefetch timeout budget must remain 5ms"
);
cfg.censorship.mask_classifier_prefetch_timeout_ms = 20;
assert_eq!(
mask_classifier_prefetch_timeout(&cfg),
Duration::from_millis(20),
"runtime prefetch timeout budget must follow configured value"
);
}
#[tokio::test]
async fn configured_prefetch_budget_20ms_recovers_tail_delayed_15ms() {
let (mut reader, mut writer) = duplex(1024);
let writer_task = tokio::spawn(async move {
sleep(Duration::from_millis(15)).await;
writer
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
.await
.expect("tail bytes must be writable");
writer
.shutdown()
.await
.expect("writer shutdown must succeed");
});
let mut initial_data = b"C".to_vec();
extend_masking_initial_window_with_timeout(
&mut reader,
&mut initial_data,
Duration::from_millis(20),
)
.await;
writer_task
.await
.expect("writer task must not panic in runtime timeout test");
assert!(
initial_data.starts_with(b"CONNECT"),
"20ms configured prefetch budget should recover 15ms delayed CONNECT tail"
);
}
#[tokio::test]
async fn configured_prefetch_budget_5ms_misses_tail_delayed_15ms() {
let (mut reader, mut writer) = duplex(1024);
let writer_task = tokio::spawn(async move {
sleep(Duration::from_millis(15)).await;
writer
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
.await
.expect("tail bytes must be writable");
writer
.shutdown()
.await
.expect("writer shutdown must succeed");
});
let mut initial_data = b"C".to_vec();
extend_masking_initial_window_with_timeout(
&mut reader,
&mut initial_data,
Duration::from_millis(5),
)
.await;
writer_task
.await
.expect("writer task must not panic in runtime timeout test");
assert!(
!initial_data.starts_with(b"CONNECT"),
"5ms configured prefetch budget should miss 15ms delayed CONNECT tail"
);
}

View File

@ -0,0 +1,264 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
use crate::protocol::tls;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
struct PipelineHarness {
config: Arc<ProxyConfig>,
stats: Arc<Stats>,
upstream_manager: Arc<UpstreamManager>,
replay_checker: Arc<ReplayChecker>,
buffer_pool: Arc<BufferPool>,
rng: Arc<SecureRandom>,
route_runtime: Arc<RouteRuntimeController>,
ip_tracker: Arc<UserIpTracker>,
beobachten: Arc<BeobachtenStore>,
}
fn build_harness(secret_hex: &str, mask_port: u16) -> PipelineHarness {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = mask_port;
cfg.censorship.mask_proxy_protocol = 0;
cfg.access.ignore_time_skew = true;
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let upstream_manager = Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats.clone(),
));
PipelineHarness {
config,
stats,
upstream_manager,
replay_checker: Arc::new(ReplayChecker::new(256, Duration::from_secs(60))),
buffer_pool: Arc::new(BufferPool::new()),
rng: Arc::new(SecureRandom::new()),
route_runtime: Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
ip_tracker: Arc::new(UserIpTracker::new()),
beobachten: Arc::new(BeobachtenStore::new()),
}
}
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
fn wrap_tls_application_data(payload: &[u8]) -> Vec<u8> {
let mut record = Vec::with_capacity(5 + payload.len());
record.push(0x17);
record.extend_from_slice(&TLS_VERSION);
record.extend_from_slice(&(payload.len() as u16).to_be_bytes());
record.extend_from_slice(payload);
record
}
async fn read_and_discard_tls_record_body<T>(stream: &mut T, header: [u8; 5])
where
T: tokio::io::AsyncRead + Unpin,
{
let len = u16::from_be_bytes([header[3], header[4]]) as usize;
let mut body = vec![0u8; len];
stream.read_exact(&mut body).await.unwrap();
}
#[test]
fn empty_initial_data_prefetch_gate_is_fail_closed() {
assert!(
!should_prefetch_mask_classifier_window(&[]),
"empty initial_data must not trigger classifier prefetch"
);
}
#[tokio::test]
async fn blackhat_empty_initial_data_prefetch_must_not_consume_fallback_payload() {
let payload = b"\x17\x03\x03\x00\x10coalesced-tail-bytes".to_vec();
let (mut reader, mut writer) = duplex(1024);
writer.write_all(&payload).await.unwrap();
writer.shutdown().await.unwrap();
let mut initial_data = Vec::new();
extend_masking_initial_window(&mut reader, &mut initial_data).await;
assert!(
initial_data.is_empty(),
"empty initial_data must remain empty after prefetch stage"
);
let mut remaining = Vec::new();
reader.read_to_end(&mut remaining).await.unwrap();
assert_eq!(
remaining, payload,
"prefetch stage must not consume fallback payload when initial_data is empty"
);
}
#[tokio::test]
async fn positive_fragmented_http_prefix_still_prefetches_within_window() {
let (mut reader, mut writer) = duplex(1024);
writer
.write_all(b"NECT example.org:443 HTTP/1.1\r\n")
.await
.unwrap();
writer.shutdown().await.unwrap();
let mut initial_data = b"CON".to_vec();
extend_masking_initial_window(&mut reader, &mut initial_data).await;
assert!(
initial_data.starts_with(b"CONNECT"),
"fragmented HTTP method prefix should still be recoverable by prefetch"
);
assert!(
initial_data.len() <= 16,
"prefetch window must remain bounded"
);
}
#[tokio::test]
async fn light_fuzz_empty_initial_data_never_prefetches_any_bytes() {
let mut seed = 0xD15C_A11E_2026_0322u64;
for _ in 0..128 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let len = ((seed & 0x3f) as usize).saturating_add(1);
let mut payload = vec![0u8; len];
for (idx, byte) in payload.iter_mut().enumerate() {
*byte = (seed as u8).wrapping_add(idx as u8).wrapping_mul(17);
}
let (mut reader, mut writer) = duplex(1024);
writer.write_all(&payload).await.unwrap();
writer.shutdown().await.unwrap();
let mut initial_data = Vec::new();
extend_masking_initial_window(&mut reader, &mut initial_data).await;
assert!(initial_data.is_empty());
let mut remaining = Vec::new();
reader.read_to_end(&mut remaining).await.unwrap();
assert_eq!(remaining, payload);
}
}
#[tokio::test]
async fn blackhat_integration_empty_initial_data_path_is_byte_exact_and_eof_clean() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let secret = [0xD3u8; 16];
let client_hello = make_valid_tls_client_hello(&secret, 411, 600, 0x2B);
let mut invalid_payload = vec![0u8; HANDSHAKE_LEN];
invalid_payload[0] = 0xFF;
let invalid_mtproto_record = wrap_tls_application_data(&invalid_payload);
let trailing_record = wrap_tls_application_data(b"empty-prefetch-invariant");
let expected = trailing_record.clone();
let accept_task = tokio::spawn(async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; expected.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, expected);
let mut one = [0u8; 1];
let n = stream.read(&mut one).await.unwrap();
assert_eq!(
n, 0,
"fallback stream must not append synthetic bytes on empty initial_data path"
);
});
let harness = build_harness("d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3d3", backend_addr.port());
let (server_side, mut client_side) = duplex(131072);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.245:56145".parse().unwrap(),
harness.config,
harness.stats,
harness.upstream_manager,
harness.replay_checker,
harness.buffer_pool,
harness.rng,
None,
harness.route_runtime,
None,
harness.ip_tracker,
harness.beobachten,
false,
));
client_side.write_all(&client_hello).await.unwrap();
let mut head = [0u8; 5];
client_side.read_exact(&mut head).await.unwrap();
assert_eq!(head[0], 0x16);
read_and_discard_tls_record_body(&mut client_side, head).await;
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side.write_all(&trailing_record).await.unwrap();
client_side.shutdown().await.unwrap();
tokio::time::timeout(Duration::from_secs(3), accept_task)
.await
.unwrap()
.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(3), handler)
.await
.unwrap()
.unwrap();
}

View File

@ -0,0 +1,72 @@
use super::*;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, advance, sleep};
async fn run_strict_prefetch_case(prefetch_ms: u64, tail_delay_ms: u64) -> Vec<u8> {
let (mut reader, mut writer) = duplex(1024);
let writer_task = tokio::spawn(async move {
sleep(Duration::from_millis(tail_delay_ms)).await;
let _ = writer
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
.await;
let _ = writer.shutdown().await;
});
let mut initial_data = b"C".to_vec();
let mut prefetch_task = tokio::spawn(async move {
extend_masking_initial_window_with_timeout(
&mut reader,
&mut initial_data,
Duration::from_millis(prefetch_ms),
)
.await;
initial_data
});
tokio::task::yield_now().await;
if tail_delay_ms > 0 {
advance(Duration::from_millis(tail_delay_ms)).await;
tokio::task::yield_now().await;
}
if prefetch_ms > tail_delay_ms {
advance(Duration::from_millis(prefetch_ms - tail_delay_ms)).await;
tokio::task::yield_now().await;
}
let result = prefetch_task.await.expect("prefetch task must not panic");
writer_task.await.expect("writer task must not panic");
result
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_5ms_misses_15ms_tail() {
let got = run_strict_prefetch_case(5, 15).await;
assert_eq!(got, b"C".to_vec());
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_20ms_recovers_15ms_tail() {
let got = run_strict_prefetch_case(20, 15).await;
assert!(got.starts_with(b"CONNECT"));
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_50ms_recovers_35ms_tail() {
let got = run_strict_prefetch_case(50, 35).await;
assert!(got.starts_with(b"CONNECT"));
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_equal_budget_and_delay_recovers_tail() {
let got = run_strict_prefetch_case(20, 20).await;
assert!(got.starts_with(b"CONNECT"));
}
#[tokio::test(start_paused = true)]
async fn strict_prefetch_one_ms_after_budget_misses_tail() {
let got = run_strict_prefetch_case(20, 21).await;
assert_eq!(got, b"C".to_vec());
}

View File

@ -0,0 +1,98 @@
use super::*;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt, duplex};
use tokio::time::{Duration, sleep, timeout};
async fn extend_masking_initial_window_with_budget<R>(
reader: &mut R,
initial_data: &mut Vec<u8>,
prefetch_timeout: Duration,
) where
R: AsyncRead + Unpin,
{
if !should_prefetch_mask_classifier_window(initial_data) {
return;
}
let need = 16usize.saturating_sub(initial_data.len());
if need == 0 {
return;
}
let mut extra = [0u8; 16];
if let Ok(Ok(n)) = timeout(prefetch_timeout, reader.read(&mut extra[..need])).await
&& n > 0
{
initial_data.extend_from_slice(&extra[..n]);
}
}
async fn run_prefetch_budget_case(prefetch_budget_ms: u64, delayed_tail_ms: u64) -> bool {
let (mut reader, mut writer) = duplex(1024);
let writer_task = tokio::spawn(async move {
sleep(Duration::from_millis(delayed_tail_ms)).await;
writer
.write_all(b"ONNECT example.org:443 HTTP/1.1\r\n")
.await
.expect("tail bytes must be writable");
writer
.shutdown()
.await
.expect("writer shutdown must succeed");
});
let mut initial_data = b"C".to_vec();
extend_masking_initial_window_with_budget(
&mut reader,
&mut initial_data,
Duration::from_millis(prefetch_budget_ms),
)
.await;
writer_task
.await
.expect("writer task must not panic during matrix case");
initial_data.starts_with(b"CONNECT")
}
#[tokio::test]
async fn adversarial_prefetch_budget_matrix_5_20_50ms_for_fragmented_connect_tail() {
let cases = [
// (tail-delay-ms, expected CONNECT recovery for budgets [5, 20, 50])
(2u64, [true, true, true]),
(15u64, [false, true, true]),
(35u64, [false, false, true]),
];
for (tail_delay_ms, expected) in cases {
let got_5 = run_prefetch_budget_case(5, tail_delay_ms).await;
let got_20 = run_prefetch_budget_case(20, tail_delay_ms).await;
let got_50 = run_prefetch_budget_case(50, tail_delay_ms).await;
assert_eq!(
got_5, expected[0],
"5ms prefetch budget mismatch for tail delay {}ms",
tail_delay_ms
);
assert_eq!(
got_20, expected[1],
"20ms prefetch budget mismatch for tail delay {}ms",
tail_delay_ms
);
assert_eq!(
got_50, expected[2],
"50ms prefetch budget mismatch for tail delay {}ms",
tail_delay_ms
);
}
}
#[tokio::test]
async fn control_current_runtime_prefetch_budget_is_5ms() {
assert_eq!(
MASK_CLASSIFIER_PREFETCH_TIMEOUT,
Duration::from_millis(5),
"matrix assumptions require current runtime prefetch budget to stay at 5ms"
);
}

View File

@ -0,0 +1,167 @@
use super::*;
use crate::config::{UpstreamConfig, UpstreamType};
use crate::crypto::sha256_hmac;
use crate::protocol::constants::{HANDSHAKE_LEN, TLS_VERSION};
use crate::protocol::tls;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant};
fn new_upstream_manager(stats: Arc<Stats>) -> Arc<UpstreamManager> {
Arc::new(UpstreamManager::new(
vec![UpstreamConfig {
upstream_type: UpstreamType::Direct {
interface: None,
bind_addresses: None,
},
weight: 1,
enabled: true,
scopes: String::new(),
selected_scope: String::new(),
}],
1,
1,
1,
1,
false,
stats,
))
}
fn make_valid_tls_client_hello(secret: &[u8], timestamp: u32, tls_len: usize, fill: u8) -> Vec<u8> {
let total_len = 5 + tls_len;
let mut handshake = vec![fill; total_len];
handshake[0] = 0x16;
handshake[1] = 0x03;
handshake[2] = 0x01;
handshake[3..5].copy_from_slice(&(tls_len as u16).to_be_bytes());
let session_id_len: usize = 32;
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
async fn run_replay_candidate_session(
replay_checker: Arc<ReplayChecker>,
hello: &[u8],
peer: SocketAddr,
drive_mtproto_fail: bool,
) -> Duration {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = 1;
cfg.censorship.mask_timing_normalization_enabled = false;
cfg.access.ignore_time_skew = true;
cfg.access.users.insert(
"user".to_string(),
"abababababababababababababababab".to_string(),
);
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(65536);
let started = Instant::now();
let task = tokio::spawn(handle_client_stream(
server_side,
peer,
config,
stats.clone(),
new_upstream_manager(stats),
replay_checker,
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten,
false,
));
client_side.write_all(hello).await.unwrap();
if drive_mtproto_fail {
let mut server_hello_head = [0u8; 5];
client_side
.read_exact(&mut server_hello_head)
.await
.unwrap();
assert_eq!(server_hello_head[0], 0x16);
let body_len = u16::from_be_bytes([server_hello_head[3], server_hello_head[4]]) as usize;
let mut body = vec![0u8; body_len];
client_side.read_exact(&mut body).await.unwrap();
let mut invalid_mtproto_record = Vec::with_capacity(5 + HANDSHAKE_LEN);
invalid_mtproto_record.push(0x17);
invalid_mtproto_record.extend_from_slice(&TLS_VERSION);
invalid_mtproto_record.extend_from_slice(&(HANDSHAKE_LEN as u16).to_be_bytes());
invalid_mtproto_record.extend_from_slice(&vec![0u8; HANDSHAKE_LEN]);
client_side
.write_all(&invalid_mtproto_record)
.await
.unwrap();
client_side
.write_all(b"GET /replay-fallback HTTP/1.1\r\nHost: x\r\n\r\n")
.await
.unwrap();
}
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(4), task)
.await
.unwrap()
.unwrap();
started.elapsed()
}
#[tokio::test]
async fn replay_reject_still_honors_masking_timing_budget() {
let replay_checker = Arc::new(ReplayChecker::new(256, Duration::from_secs(60)));
let hello = make_valid_tls_client_hello(&[0xAB; 16], 7, 600, 0x51);
let seed_elapsed = run_replay_candidate_session(
Arc::clone(&replay_checker),
&hello,
"198.51.100.201:58001".parse().unwrap(),
true,
)
.await;
assert!(
seed_elapsed >= Duration::from_millis(40) && seed_elapsed < Duration::from_millis(250),
"seed replay-candidate run must honor masking timing budget without unbounded delay"
);
let replay_elapsed = run_replay_candidate_session(
Arc::clone(&replay_checker),
&hello,
"198.51.100.202:58002".parse().unwrap(),
false,
)
.await;
assert!(
replay_elapsed >= Duration::from_millis(40) && replay_elapsed < Duration::from_millis(250),
"replay rejection path must still satisfy masking timing budget without unbounded DB/CPU delay"
);
}

View File

@ -0,0 +1,288 @@
use super::*;
use crate::config::ProxyConfig;
use crate::stats::Stats;
use crate::transport::UpstreamManager;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
}
#[tokio::test]
async fn edge_mask_delay_bypassed_if_max_is_zero() {
let mut config = ProxyConfig::default();
config.censorship.server_hello_delay_min_ms = 10_000;
config.censorship.server_hello_delay_max_ms = 0;
let start = std::time::Instant::now();
maybe_apply_mask_reject_delay(&config).await;
assert!(start.elapsed() < Duration::from_millis(50));
}
#[test]
fn edge_beobachten_ttl_clamps_exactly_to_24_hours() {
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = 100_000;
let ttl = beobachten_ttl(&config);
assert_eq!(ttl.as_secs(), 24 * 60 * 60);
}
#[test]
fn edge_wrap_tls_application_record_empty_payload() {
let wrapped = wrap_tls_application_record(&[]);
assert_eq!(wrapped.len(), 5);
assert_eq!(wrapped[0], TLS_RECORD_APPLICATION);
assert_eq!(&wrapped[3..5], &[0, 0]);
}
#[tokio::test]
async fn boundary_user_data_quota_exact_match_rejects() {
let user = "quota-boundary-user";
let mut config = ProxyConfig::default();
config.access.user_data_quota.insert(user.to_string(), 1024);
let stats = Arc::new(Stats::new());
preload_user_quota(stats.as_ref(), user, 1024);
let ip_tracker = Arc::new(UserIpTracker::new());
let peer = "198.51.100.10:55000".parse().unwrap();
let result = RunningClientHandler::acquire_user_connection_reservation_static(
user, &config, stats, peer, ip_tracker,
)
.await;
assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
}
#[tokio::test]
async fn boundary_user_expiration_in_past_rejects() {
let user = "expired-boundary-user";
let mut config = ProxyConfig::default();
let expired_time = chrono::Utc::now() - chrono::Duration::milliseconds(1);
config
.access
.user_expirations
.insert(user.to_string(), expired_time);
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
let peer = "198.51.100.11:55000".parse().unwrap();
let result = RunningClientHandler::acquire_user_connection_reservation_static(
user, &config, stats, peer, ip_tracker,
)
.await;
assert!(matches!(result, Err(ProxyError::UserExpired { .. })));
}
#[tokio::test]
async fn blackhat_proxy_protocol_massive_garbage_rejected_quickly() {
let mut cfg = ProxyConfig::default();
cfg.server.proxy_protocol_header_timeout_ms = 300;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.12:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
true,
));
client_side.write_all(&vec![b'A'; 2000]).await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(2), handler)
.await
.unwrap()
.unwrap();
assert!(matches!(result, Err(ProxyError::InvalidProxyProtocol)));
assert_eq!(stats.get_connects_bad(), 1);
}
#[tokio::test]
async fn edge_tls_body_immediate_eof_triggers_masking_and_bad_connect() {
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = true;
cfg.general.beobachten_minutes = 1;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let beobachten = Arc::new(BeobachtenStore::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.13:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
beobachten.clone(),
false,
));
client_side
.write_all(&[0x16, 0x03, 0x01, 0x00, 100])
.await
.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handler)
.await
.unwrap();
assert_eq!(stats.get_connects_bad(), 1);
}
#[tokio::test]
async fn security_classic_mode_disabled_masks_valid_length_payload() {
let mut cfg = ProxyConfig::default();
cfg.general.modes.classic = false;
cfg.general.modes.secure = false;
cfg.censorship.mask = true;
let config = Arc::new(cfg);
let stats = Arc::new(Stats::new());
let (server_side, mut client_side) = duplex(4096);
let handler = tokio::spawn(handle_client_stream(
server_side,
"198.51.100.15:55000".parse().unwrap(),
config,
stats.clone(),
Arc::new(UpstreamManager::new(
vec![],
1,
1,
1,
1,
false,
stats.clone(),
)),
Arc::new(ReplayChecker::new(128, Duration::from_secs(60))),
Arc::new(BufferPool::new()),
Arc::new(SecureRandom::new()),
None,
Arc::new(RouteRuntimeController::new(RelayRouteMode::Direct)),
None,
Arc::new(UserIpTracker::new()),
Arc::new(BeobachtenStore::new()),
false,
));
client_side.write_all(&vec![0xEF; 64]).await.unwrap();
client_side.shutdown().await.unwrap();
let _ = tokio::time::timeout(Duration::from_secs(2), handler)
.await
.unwrap();
assert_eq!(stats.get_connects_bad(), 1);
}
#[tokio::test]
async fn concurrency_ip_tracker_strict_limit_one_rapid_churn() {
let user = "rapid-churn-user";
let mut config = ProxyConfig::default();
config
.access
.user_max_tcp_conns
.insert(user.to_string(), 10);
let stats = Arc::new(Stats::new());
let ip_tracker = Arc::new(UserIpTracker::new());
ip_tracker.set_user_limit(user, 1).await;
let peer = "198.51.100.16:55000".parse().unwrap();
for _ in 0..500 {
let reservation = RunningClientHandler::acquire_user_connection_reservation_static(
user,
&config,
stats.clone(),
peer,
ip_tracker.clone(),
)
.await
.unwrap();
reservation.release().await;
}
assert_eq!(stats.get_user_curr_connects(user), 0);
assert_eq!(ip_tracker.get_active_ip_count(user).await, 0);
}
#[tokio::test]
async fn quirk_read_with_progress_zero_length_buffer_returns_zero_immediately() {
let (mut server_side, _client_side) = duplex(4096);
let mut empty_buf = &mut [][..];
let result = tokio::time::timeout(
Duration::from_millis(50),
read_with_progress(&mut server_side, &mut empty_buf),
)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap().unwrap(), 0);
}
#[tokio::test]
async fn stress_read_with_progress_cancellation_safety() {
let (mut server_side, mut client_side) = duplex(4096);
client_side.write_all(b"12345").await.unwrap();
let mut buf = [0u8; 10];
let result = tokio::time::timeout(
Duration::from_millis(50),
read_with_progress(&mut server_side, &mut buf),
)
.await;
assert!(result.is_err());
client_side.write_all(b"67890").await.unwrap();
let mut buf2 = [0u8; 5];
server_side.read_exact(&mut buf2).await.unwrap();
assert_eq!(&buf2, b"67890");
}

View File

@ -7,6 +7,9 @@ use crate::protocol::tls;
use crate::proxy::handshake::HandshakeSuccess;
use crate::stream::{CryptoReader, CryptoWriter};
use crate::transport::proxy_protocol::ProxyProtocolV1Builder;
use rand::Rng;
use rand::SeedableRng;
use rand::rngs::StdRng;
use std::net::Ipv4Addr;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::{TcpListener, TcpStream};
@ -25,6 +28,220 @@ fn synthetic_local_addr_uses_configured_port_for_max() {
assert_eq!(addr.port(), u16::MAX);
}
#[test]
fn handshake_timeout_with_mask_grace_includes_mask_margin() {
let mut config = ProxyConfig::default();
config.timeouts.client_handshake = 2;
config.censorship.mask = false;
assert_eq!(
handshake_timeout_with_mask_grace(&config),
Duration::from_secs(2)
);
config.censorship.mask = true;
assert_eq!(
handshake_timeout_with_mask_grace(&config),
Duration::from_millis(2750),
"mask mode extends handshake timeout by 750 ms"
);
}
#[tokio::test]
async fn read_with_progress_reads_partial_buffers_before_eof() {
let data = vec![0xAA, 0xBB, 0xCC];
let mut reader = std::io::Cursor::new(data);
let mut buf = [0u8; 5];
let read = read_with_progress(&mut reader, &mut buf).await.unwrap();
assert_eq!(read, 3);
assert_eq!(&buf[..3], &[0xAA, 0xBB, 0xCC]);
}
#[test]
fn is_trusted_proxy_source_respects_cidr_list_and_empty_rejects_all() {
let peer: IpAddr = "10.10.10.10".parse().unwrap();
assert!(!is_trusted_proxy_source(peer, &[]));
let trusted = vec!["10.0.0.0/8".parse().unwrap()];
assert!(is_trusted_proxy_source(peer, &trusted));
let not_trusted = vec!["192.0.2.0/24".parse().unwrap()];
assert!(!is_trusted_proxy_source(peer, &not_trusted));
}
#[test]
fn is_trusted_proxy_source_accepts_cidr_zero_zero_as_global_cidr() {
let peer: IpAddr = "203.0.113.42".parse().unwrap();
let trust_all = vec!["0.0.0.0/0".parse().unwrap()];
assert!(is_trusted_proxy_source(peer, &trust_all));
let peer_v6: IpAddr = "2001:db8::1".parse().unwrap();
let trust_all_v6 = vec!["::/0".parse().unwrap()];
assert!(is_trusted_proxy_source(peer_v6, &trust_all_v6));
}
struct ErrorReader;
impl tokio::io::AsyncRead for ErrorReader {
fn poll_read(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
_buf: &mut tokio::io::ReadBuf<'_>,
) -> std::task::Poll<std::io::Result<()>> {
std::task::Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"fake error",
)))
}
}
#[tokio::test]
async fn read_with_progress_returns_error_from_failed_reader() {
let mut reader = ErrorReader;
let mut buf = [0u8; 8];
let err = read_with_progress(&mut reader, &mut buf).await.unwrap_err();
assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
}
#[test]
fn handshake_timeout_with_mask_grace_handles_maximum_values_without_overflow() {
let mut config = ProxyConfig::default();
config.timeouts.client_handshake = u64::MAX;
config.censorship.mask = true;
let timeout = handshake_timeout_with_mask_grace(&config);
assert!(timeout >= Duration::from_secs(u64::MAX));
}
#[tokio::test]
async fn read_with_progress_zero_length_buffer_returns_zero() {
let data = vec![1, 2, 3];
let mut reader = std::io::Cursor::new(data);
let mut buf = [];
let read = read_with_progress(&mut reader, &mut buf).await.unwrap();
assert_eq!(read, 0);
}
#[test]
fn handshake_timeout_without_mask_is_exact_base() {
let mut config = ProxyConfig::default();
config.timeouts.client_handshake = 7;
config.censorship.mask = false;
assert_eq!(
handshake_timeout_with_mask_grace(&config),
Duration::from_secs(7)
);
}
#[test]
fn handshake_timeout_mask_enabled_adds_750ms() {
let mut config = ProxyConfig::default();
config.timeouts.client_handshake = 3;
config.censorship.mask = true;
assert_eq!(
handshake_timeout_with_mask_grace(&config),
Duration::from_millis(3750)
);
}
#[tokio::test]
async fn read_with_progress_full_then_empty_transition() {
let data = vec![0x10, 0x20];
let mut cursor = std::io::Cursor::new(data);
let mut buf = [0u8; 2];
assert_eq!(read_with_progress(&mut cursor, &mut buf).await.unwrap(), 2);
assert_eq!(read_with_progress(&mut cursor, &mut buf).await.unwrap(), 0);
}
#[tokio::test]
async fn read_with_progress_fragmented_io_works_over_multiple_calls() {
let mut cursor = std::io::Cursor::new(vec![1, 2, 3, 4, 5]);
let mut result = Vec::new();
for chunk_size in 1..=5 {
let mut b = vec![0u8; chunk_size];
let n = read_with_progress(&mut cursor, &mut b).await.unwrap();
result.extend_from_slice(&b[..n]);
if n == 0 {
break;
}
}
assert_eq!(result, vec![1, 2, 3, 4, 5]);
}
#[tokio::test]
async fn read_with_progress_stress_randomized_chunk_sizes() {
for i in 0..128 {
let mut rng = StdRng::seed_from_u64(i as u64 + 1);
let mut input: Vec<u8> = (0..(i % 41)).map(|_| rng.next_u32() as u8).collect();
let mut cursor = std::io::Cursor::new(input.clone());
let mut collected = Vec::new();
while cursor.position() < cursor.get_ref().len() as u64 {
let chunk = 1 + (rng.next_u32() as usize % 8);
let mut b = vec![0u8; chunk];
let read = read_with_progress(&mut cursor, &mut b).await.unwrap();
collected.extend_from_slice(&b[..read]);
if read == 0 {
break;
}
}
assert_eq!(collected, input);
}
}
#[test]
fn is_trusted_proxy_source_boundary_narrow_ipv4() {
let matching = "172.16.0.1".parse().unwrap();
let not_matching = "172.15.255.255".parse().unwrap();
let cidr = vec!["172.16.0.0/12".parse().unwrap()];
assert!(is_trusted_proxy_source(matching, &cidr));
assert!(!is_trusted_proxy_source(not_matching, &cidr));
}
#[test]
fn is_trusted_proxy_source_rejects_out_of_family_ipv6_v4_cidr() {
let peer = "2001:db8::1".parse().unwrap();
let cidr = vec!["10.0.0.0/8".parse().unwrap()];
assert!(!is_trusted_proxy_source(peer, &cidr));
}
#[test]
fn wrap_tls_application_record_reserved_chunks_look_reasonable() {
let payload = vec![0xAA; 1 + (u16::MAX as usize) + 2];
let wrapped = wrap_tls_application_record(&payload);
assert!(wrapped.len() > payload.len());
assert!(wrapped.contains(&0x17));
}
#[test]
fn wrap_tls_application_record_roundtrip_size_check() {
let payload_len = 3000;
let payload = vec![0x55; payload_len];
let wrapped = wrap_tls_application_record(&payload);
let mut idx = 0;
let mut consumed = 0;
while idx + 5 <= wrapped.len() {
assert_eq!(wrapped[idx], 0x17);
let len = u16::from_be_bytes([wrapped[idx + 3], wrapped[idx + 4]]) as usize;
consumed += len;
idx += 5 + len;
if idx >= wrapped.len() {
break;
}
}
assert_eq!(consumed, payload_len);
}
fn make_crypto_reader<R>(reader: R) -> CryptoReader<R>
where
R: tokio::io::AsyncRead + Unpin,
@ -43,6 +260,11 @@ where
CryptoWriter::new(writer, AesCtr::new(&key, iv), 8 * 1024)
}
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
}
#[tokio::test]
async fn user_connection_reservation_drop_enqueues_cleanup_synchronously() {
let ip_tracker = Arc::new(crate::ip_tracker::UserIpTracker::new());
@ -2841,7 +3063,7 @@ async fn quota_rejection_does_not_reserve_ip_or_trigger_rollback() {
.insert("user".to_string(), 1024);
let stats = Stats::new();
stats.add_user_octets_from("user", 1024);
preload_user_quota(&stats, "user", 1024);
let ip_tracker = UserIpTracker::new();
let peer_addr: SocketAddr = "203.0.113.211:50001".parse().unwrap();

View File

@ -25,13 +25,26 @@ fn wrap_tls_application_record_oversized_payload_is_chunked_without_truncation()
let len = u16::from_be_bytes([record[offset + 3], record[offset + 4]]) as usize;
let body_start = offset + 5;
let body_end = body_start + len;
assert!(body_end <= record.len(), "declared TLS record length must be in-bounds");
assert!(
body_end <= record.len(),
"declared TLS record length must be in-bounds"
);
recovered.extend_from_slice(&record[body_start..body_end]);
offset = body_end;
frames += 1;
}
assert_eq!(offset, record.len(), "record parser must consume exact output size");
assert_eq!(frames, 2, "oversized payload should split into exactly two records");
assert_eq!(recovered, payload, "chunked records must preserve full payload");
assert_eq!(
offset,
record.len(),
"record parser must consume exact output size"
);
assert_eq!(
frames, 2,
"oversized payload should split into exactly two records"
);
assert_eq!(
recovered, payload,
"chunked records must preserve full payload"
);
}

View File

@ -773,8 +773,7 @@ fn anchored_open_nix_path_writes_expected_lines() {
"target/telemt-unknown-dc-anchored-open-ok-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let _ = fs::remove_file(&sanitized.resolved_path);
let mut first = open_unknown_dc_log_append_anchored(&sanitized)
@ -787,7 +786,10 @@ fn anchored_open_nix_path_writes_expected_lines() {
let content =
fs::read_to_string(&sanitized.resolved_path).expect("anchored log file must be readable");
let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect();
let lines: Vec<&str> = content
.lines()
.filter(|line| !line.trim().is_empty())
.collect();
assert_eq!(lines.len(), 2, "expected one line per anchored append call");
assert!(
lines.contains(&"dc_idx=31200") && lines.contains(&"dc_idx=31201"),
@ -811,8 +813,7 @@ fn anchored_open_parallel_appends_preserve_line_integrity() {
"target/telemt-unknown-dc-anchored-open-parallel-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let _ = fs::remove_file(&sanitized.resolved_path);
let mut workers = Vec::new();
@ -831,8 +832,15 @@ fn anchored_open_parallel_appends_preserve_line_integrity() {
let content =
fs::read_to_string(&sanitized.resolved_path).expect("parallel log file must be readable");
let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect();
assert_eq!(lines.len(), 64, "expected one complete line per worker append");
let lines: Vec<&str> = content
.lines()
.filter(|line| !line.trim().is_empty())
.collect();
assert_eq!(
lines.len(),
64,
"expected one complete line per worker append"
);
for line in lines {
assert!(
line.starts_with("dc_idx="),
@ -867,8 +875,7 @@ fn anchored_open_creates_private_0600_file_permissions() {
"target/telemt-unknown-dc-anchored-perms-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let _ = fs::remove_file(&sanitized.resolved_path);
let mut file = open_unknown_dc_log_append_anchored(&sanitized)
@ -905,8 +912,7 @@ fn anchored_open_rejects_existing_symlink_target() {
"target/telemt-unknown-dc-anchored-symlink-target-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let outside = std::env::temp_dir().join(format!(
"telemt-unknown-dc-anchored-symlink-outside-{}.log",
@ -943,8 +949,7 @@ fn anchored_open_high_contention_multi_write_preserves_complete_lines() {
"target/telemt-unknown-dc-anchored-contention-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let _ = fs::remove_file(&sanitized.resolved_path);
let workers = 24usize;
@ -970,7 +975,10 @@ fn anchored_open_high_contention_multi_write_preserves_complete_lines() {
let content = fs::read_to_string(&sanitized.resolved_path)
.expect("contention output file must be readable");
let lines: Vec<&str> = content.lines().filter(|line| !line.trim().is_empty()).collect();
let lines: Vec<&str> = content
.lines()
.filter(|line| !line.trim().is_empty())
.collect();
assert_eq!(
lines.len(),
workers * rounds,
@ -1014,8 +1022,7 @@ fn append_unknown_dc_line_returns_error_for_read_only_descriptor() {
"target/telemt-unknown-dc-append-ro-{}/unknown-dc.log",
std::process::id()
);
let sanitized =
sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
let sanitized = sanitize_unknown_dc_log_path(&rel_candidate).expect("candidate must sanitize");
fs::write(&sanitized.resolved_path, "seed\n").expect("seed file must be writable");
let mut readonly = std::fs::OpenOptions::new()

View File

@ -0,0 +1,719 @@
use super::*;
use crate::crypto::{AesCtr, sha256, sha256_hmac};
use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
// --- Helpers ---
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default();
cfg.access.users.clear();
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true;
cfg.general.modes.secure = true;
cfg.general.modes.classic = true;
cfg.general.modes.tls = true;
cfg
}
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
let session_id_len: usize = 32;
let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
fn make_valid_mtproto_handshake(
secret_hex: &str,
proto_tag: ProtoTag,
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode");
let mut handshake = [0x5Au8; HANDSHAKE_LEN];
for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
.iter_mut()
.enumerate()
{
*b = (idx as u8).wrapping_add(1);
}
let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut stream = AesCtr::new(&dec_key, dec_iv);
let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
let mut target_plain = [0u8; HANDSHAKE_LEN];
target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
handshake[idx] = target_plain[idx] ^ keystream[idx];
}
handshake
}
fn make_valid_tls_client_hello_with_alpn(
secret: &[u8],
timestamp: u32,
alpn_protocols: &[&[u8]],
) -> Vec<u8> {
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
body.push(32);
body.extend_from_slice(&[0x42u8; 32]);
body.extend_from_slice(&2u16.to_be_bytes());
body.extend_from_slice(&[0x13, 0x01]);
body.push(1);
body.push(0);
let mut ext_blob = Vec::new();
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
alpn_list.push(proto.len() as u8);
alpn_list.extend_from_slice(proto);
}
let mut alpn_data = Vec::new();
alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
alpn_data.extend_from_slice(&alpn_list);
ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&alpn_data);
}
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
body.extend_from_slice(&ext_blob);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut record = Vec::new();
record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&[0x03, 0x01]);
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
record.extend_from_slice(&handshake);
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &record);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
record
}
// --- Category 1: Edge Cases & Protocol Boundaries ---
#[tokio::test]
async fn tls_minimum_viable_length_boundary() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x11u8; 16];
let config = test_config_with_secret_hex("11111111111111111111111111111111");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.1:12345".parse().unwrap();
let min_len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1;
let mut exact_min_handshake = vec![0x42u8; min_len];
exact_min_handshake[min_len - 1] = 0;
exact_min_handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let digest = sha256_hmac(&secret, &exact_min_handshake);
exact_min_handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
let res = handle_tls_handshake(
&exact_min_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(
matches!(res, HandshakeResult::Success(_)),
"Exact minimum length TLS handshake must succeed"
);
let short_handshake = vec![0x42u8; min_len - 1];
let res_short = handle_tls_handshake(
&short_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(
matches!(res_short, HandshakeResult::BadClient { .. }),
"Handshake 1 byte shorter than minimum must fail closed"
);
}
#[tokio::test]
async fn mtproto_extreme_dc_index_serialization() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "22222222222222222222222222222222";
let config = test_config_with_secret_hex(secret_hex);
for (idx, extreme_dc) in [i16::MIN, i16::MAX, -1, 0].into_iter().enumerate() {
// Keep replay state independent per case so we validate dc_idx encoding,
// not duplicate-handshake rejection behavior.
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 0, 2, 2)), 12345 + idx as u16);
let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, extreme_dc);
let res = handle_mtproto_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
match res {
HandshakeResult::Success((_, _, success)) => {
assert_eq!(
success.dc_idx, extreme_dc,
"Extreme DC index {} must serialize/deserialize perfectly",
extreme_dc
);
}
_ => panic!(
"MTProto handshake with extreme DC index {} failed",
extreme_dc
),
}
}
}
#[tokio::test]
async fn alpn_strict_case_and_padding_rejection() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x33u8; 16];
let mut config = test_config_with_secret_hex("33333333333333333333333333333333");
config.censorship.alpn_enforce = true;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.3:12345".parse().unwrap();
let bad_alpns: &[&[u8]] = &[b"H2", b"h2\0", b" http/1.1", b"http/1.1\n"];
for bad_alpn in bad_alpns {
let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[*bad_alpn]);
let res = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(
matches!(res, HandshakeResult::BadClient { .. }),
"ALPN strict enforcement must reject {:?}",
bad_alpn
);
}
}
#[test]
fn ipv4_mapped_ipv6_bucketing_anomaly() {
let ipv4_mapped_1 = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc000, 0x0201));
let ipv4_mapped_2 = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc633, 0x6402));
let norm_1 = normalize_auth_probe_ip(ipv4_mapped_1);
let norm_2 = normalize_auth_probe_ip(ipv4_mapped_2);
assert_eq!(
norm_1, norm_2,
"IPv4-mapped IPv6 addresses must collapse into the same /64 bucket (::0)"
);
assert_eq!(
norm_1,
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)),
"The bucket must be exactly ::0"
);
}
// --- Category 2: Adversarial & Black Hat ---
#[tokio::test]
async fn mtproto_invalid_ciphertext_does_not_poison_replay_cache() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "55555555555555555555555555555555";
let config = test_config_with_secret_hex(secret_hex);
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.5:12345".parse().unwrap();
let valid_handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Secure, 1);
let mut invalid_handshake = valid_handshake;
invalid_handshake[SKIP_LEN + PREKEY_LEN + IV_LEN + 1] ^= 0xFF;
let res_invalid = handle_mtproto_handshake(
&invalid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(matches!(res_invalid, HandshakeResult::BadClient { .. }));
let res_valid = handle_mtproto_handshake(
&valid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(
matches!(res_valid, HandshakeResult::Success(_)),
"Invalid MTProto ciphertext must not poison the replay cache"
);
}
#[tokio::test]
async fn tls_invalid_session_does_not_poison_replay_cache() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x66u8; 16];
let config = test_config_with_secret_hex("66666666666666666666666666666666");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.6:12345".parse().unwrap();
let valid_handshake = make_valid_tls_handshake(&secret, 0);
let mut invalid_handshake = valid_handshake.clone();
let session_idx = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1;
invalid_handshake[session_idx] ^= 0xFF;
let res_invalid = handle_tls_handshake(
&invalid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(res_invalid, HandshakeResult::BadClient { .. }));
let res_valid = handle_tls_handshake(
&valid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(
matches!(res_valid, HandshakeResult::Success(_)),
"Invalid TLS payload must not poison the replay cache"
);
}
#[tokio::test]
async fn server_hello_delay_timing_neutrality_on_hmac_failure() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x77u8; 16];
let mut config = test_config_with_secret_hex("77777777777777777777777777777777");
config.censorship.server_hello_delay_min_ms = 50;
config.censorship.server_hello_delay_max_ms = 50;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.7:12345".parse().unwrap();
let mut invalid_handshake = make_valid_tls_handshake(&secret, 0);
invalid_handshake[tls::TLS_DIGEST_POS] ^= 0xFF;
let start = Instant::now();
let res = handle_tls_handshake(
&invalid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
let elapsed = start.elapsed();
assert!(matches!(res, HandshakeResult::BadClient { .. }));
assert!(
elapsed >= Duration::from_millis(45),
"Invalid HMAC must still incur the configured ServerHello delay to prevent timing side-channels"
);
}
#[tokio::test]
async fn server_hello_delay_inversion_resilience() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x88u8; 16];
let mut config = test_config_with_secret_hex("88888888888888888888888888888888");
config.censorship.server_hello_delay_min_ms = 100;
config.censorship.server_hello_delay_max_ms = 10;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.8:12345".parse().unwrap();
let valid_handshake = make_valid_tls_handshake(&secret, 0);
let start = Instant::now();
let res = handle_tls_handshake(
&valid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
let elapsed = start.elapsed();
assert!(matches!(res, HandshakeResult::Success(_)));
assert!(
elapsed >= Duration::from_millis(90),
"Delay logic must gracefully handle min > max inversions via max.max(min)"
);
}
#[tokio::test]
async fn mixed_valid_and_invalid_user_secrets_configuration() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let _warn_guard = warned_secrets_test_lock().lock().unwrap();
clear_warned_secrets_for_testing();
let mut config = ProxyConfig::default();
config.access.ignore_time_skew = true;
for i in 0..9 {
let bad_secret = if i % 2 == 0 { "badhex!" } else { "1122" };
config
.access
.users
.insert(format!("bad_user_{}", i), bad_secret.to_string());
}
let valid_secret_hex = "99999999999999999999999999999999";
config
.access
.users
.insert("good_user".to_string(), valid_secret_hex.to_string());
config.general.modes.secure = true;
config.general.modes.classic = true;
config.general.modes.tls = true;
let secret = [0x99u8; 16];
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.9:12345".parse().unwrap();
let valid_handshake = make_valid_tls_handshake(&secret, 0);
let res = handle_tls_handshake(
&valid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(
matches!(res, HandshakeResult::Success(_)),
"Proxy must gracefully skip invalid secrets and authenticate the valid one"
);
}
#[tokio::test]
async fn tls_emulation_fallback_when_cache_missing() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0xAAu8; 16];
let mut config = test_config_with_secret_hex("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa");
config.censorship.tls_emulation = true;
config.general.modes.tls = true;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.10:12345".parse().unwrap();
let valid_handshake = make_valid_tls_handshake(&secret, 0);
let res = handle_tls_handshake(
&valid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(
matches!(res, HandshakeResult::Success(_)),
"TLS emulation must gracefully fall back to standard ServerHello if cache is missing"
);
}
#[tokio::test]
async fn classic_mode_over_tls_transport_protocol_confusion() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb";
let mut config = test_config_with_secret_hex(secret_hex);
config.general.modes.classic = true;
config.general.modes.tls = true;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.11:12345".parse().unwrap();
let handshake = make_valid_mtproto_handshake(secret_hex, ProtoTag::Intermediate, 1);
let res = handle_mtproto_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
true,
None,
)
.await;
assert!(
matches!(res, HandshakeResult::Success(_)),
"Intermediate tag over TLS must succeed if classic mode is enabled, locking in cross-transport behavior"
);
}
#[test]
fn generate_tg_nonce_never_emits_reserved_bytes() {
let client_enc_key = [0xCCu8; 32];
let client_enc_iv = 123456789u128;
let rng = SecureRandom::new();
for _ in 0..10_000 {
let (nonce, _, _, _, _) = generate_tg_nonce(
ProtoTag::Secure,
1,
&client_enc_key,
client_enc_iv,
&rng,
false,
);
assert!(
!RESERVED_NONCE_FIRST_BYTES.contains(&nonce[0]),
"Nonce must never start with reserved bytes"
);
let first_four: [u8; 4] = [nonce[0], nonce[1], nonce[2], nonce[3]];
assert!(
!RESERVED_NONCE_BEGINNINGS.contains(&first_four),
"Nonce must never match reserved 4-byte beginnings"
);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn dashmap_concurrent_saturation_stress() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let ip_a: IpAddr = "192.0.2.13".parse().unwrap();
let ip_b: IpAddr = "198.51.100.13".parse().unwrap();
let mut tasks = Vec::new();
for i in 0..100 {
let target_ip = if i % 2 == 0 { ip_a } else { ip_b };
tasks.push(tokio::spawn(async move {
for _ in 0..50 {
auth_probe_record_failure(target_ip, Instant::now());
}
}));
}
for task in tasks {
task.await
.expect("Task panicked during concurrent DashMap stress");
}
assert!(
auth_probe_is_throttled_for_testing(ip_a),
"IP A must be throttled after concurrent stress"
);
assert!(
auth_probe_is_throttled_for_testing(ip_b),
"IP B must be throttled after concurrent stress"
);
}
#[test]
fn prototag_invalid_bytes_fail_closed() {
let invalid_tags: [[u8; 4]; 5] = [
[0, 0, 0, 0],
[0xFF, 0xFF, 0xFF, 0xFF],
[0xDE, 0xAD, 0xBE, 0xEF],
[0xDD, 0xDD, 0xDD, 0xDE],
[0x11, 0x22, 0x33, 0x44],
];
for tag in invalid_tags {
assert_eq!(
ProtoTag::from_bytes(tag),
None,
"Invalid ProtoTag bytes {:?} must fail closed",
tag
);
}
}
#[test]
fn auth_probe_eviction_hash_collision_stress() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let state = auth_probe_state_map();
let now = Instant::now();
for i in 0..10_000u32 {
let ip = IpAddr::V4(Ipv4Addr::new(10, 0, (i >> 8) as u8, (i & 0xFF) as u8));
auth_probe_record_failure_with_state(state, ip, now);
}
assert!(
state.len() <= AUTH_PROBE_TRACK_MAX_ENTRIES,
"Eviction logic must successfully bound the map size under heavy insertion stress"
);
}
#[test]
fn encrypt_tg_nonce_with_ciphers_advances_counter_correctly() {
let client_enc_key = [0xDDu8; 32];
let client_enc_iv = 987654321u128;
let rng = SecureRandom::new();
let (nonce, _, _, _, _) = generate_tg_nonce(
ProtoTag::Secure,
2,
&client_enc_key,
client_enc_iv,
&rng,
false,
);
let (_, mut returned_encryptor, _) = encrypt_tg_nonce_with_ciphers(&nonce);
let zeros = [0u8; 64];
let returned_keystream = returned_encryptor.encrypt(&zeros);
let enc_key_iv = &nonce[SKIP_LEN..SKIP_LEN + KEY_LEN + IV_LEN];
let mut expected_enc_key = [0u8; 32];
expected_enc_key.copy_from_slice(&enc_key_iv[..KEY_LEN]);
let mut expected_enc_iv_arr = [0u8; IV_LEN];
expected_enc_iv_arr.copy_from_slice(&enc_key_iv[KEY_LEN..]);
let expected_enc_iv = u128::from_be_bytes(expected_enc_iv_arr);
let mut manual_encryptor = AesCtr::new(&expected_enc_key, expected_enc_iv);
let mut manual_input = Vec::new();
manual_input.extend_from_slice(&nonce);
manual_input.extend_from_slice(&zeros);
let manual_output = manual_encryptor.encrypt(&manual_input);
assert_eq!(
returned_keystream,
&manual_output[64..128],
"encrypt_tg_nonce_with_ciphers must correctly advance the AES-CTR counter by exactly the nonce length"
);
}

View File

@ -0,0 +1,96 @@
use super::*;
use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[test]
fn adversarial_large_state_offsets_escape_first_scan_window() {
let _guard = auth_probe_test_guard();
let base = Instant::now();
let state_len = 65_536usize;
let scan_limit = 1_024usize;
let mut saw_offset_outside_first_window = false;
for i in 0..8_192u64 {
let ip = IpAddr::V4(Ipv4Addr::new(
((i >> 16) & 0xff) as u8,
((i >> 8) & 0xff) as u8,
(i & 0xff) as u8,
((i.wrapping_mul(131)) & 0xff) as u8,
));
let now = base + Duration::from_nanos(i);
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
if start >= scan_limit {
saw_offset_outside_first_window = true;
break;
}
}
assert!(
saw_offset_outside_first_window,
"scan start offset must cover the full auth-probe state, not only the first scan window"
);
}
#[test]
fn stress_large_state_offsets_cover_many_scan_windows() {
let _guard = auth_probe_test_guard();
let base = Instant::now();
let state_len = 65_536usize;
let scan_limit = 1_024usize;
let mut covered_windows = HashSet::new();
for i in 0..16_384u64 {
let ip = IpAddr::V4(Ipv4Addr::new(
((i >> 16) & 0xff) as u8,
((i >> 8) & 0xff) as u8,
(i & 0xff) as u8,
((i.wrapping_mul(17)) & 0xff) as u8,
));
let now = base + Duration::from_micros(i);
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
covered_windows.insert(start / scan_limit);
}
assert!(
covered_windows.len() >= 16,
"eviction scan must not collapse to a tiny hot zone; covered windows={} out of {}",
covered_windows.len(),
state_len / scan_limit
);
}
#[test]
fn light_fuzz_offset_always_stays_inside_state_len() {
let _guard = auth_probe_test_guard();
let mut seed = 0xC0FF_EE12_3456_789Au64;
let base = Instant::now();
for _ in 0..8_192usize {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let ip = IpAddr::V4(Ipv4Addr::new(
(seed >> 24) as u8,
(seed >> 16) as u8,
(seed >> 8) as u8,
seed as u8,
));
let state_len = ((seed >> 16) as usize % 200_000).saturating_add(1);
let scan_limit = ((seed >> 40) as usize % 2_048).saturating_add(1);
let now = base + Duration::from_nanos(seed & 0x0fff);
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
assert!(
start < state_len,
"scan offset must stay inside state length"
);
}
}

View File

@ -0,0 +1,99 @@
use super::*;
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[test]
fn edge_zero_state_len_yields_zero_start_offset() {
let _guard = auth_probe_test_guard();
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 44));
let now = Instant::now();
assert_eq!(
auth_probe_scan_start_offset(ip, now, 0, 16),
0,
"empty map must not produce non-zero scan offset"
);
}
#[test]
fn adversarial_large_state_must_allow_start_offset_outside_scan_budget_window() {
let _guard = auth_probe_test_guard();
let base = Instant::now();
let scan_limit = 16usize;
let state_len = 65_536usize;
let mut saw_offset_outside_window = false;
for i in 0..2048u32 {
let ip = IpAddr::V4(Ipv4Addr::new(
203,
((i >> 16) & 0xff) as u8,
((i >> 8) & 0xff) as u8,
(i & 0xff) as u8,
));
let now = base + Duration::from_micros(i as u64);
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
assert!(
start < state_len,
"start offset must stay within state length; start={start}, len={state_len}"
);
if start >= scan_limit {
saw_offset_outside_window = true;
break;
}
}
assert!(
saw_offset_outside_window,
"large-state eviction must sample beyond the first scan window"
);
}
#[test]
fn positive_state_smaller_than_scan_limit_caps_to_state_len() {
let _guard = auth_probe_test_guard();
let ip = IpAddr::V4(Ipv4Addr::new(192, 0, 2, 17));
let now = Instant::now();
for state_len in 1..32usize {
let start = auth_probe_scan_start_offset(ip, now, state_len, 64);
assert!(
start < state_len,
"start offset must never exceed state length when scan limit is larger"
);
}
}
#[test]
fn light_fuzz_scan_offset_budget_never_exceeds_effective_window() {
let _guard = auth_probe_test_guard();
let mut seed = 0x5A41_5356_4C32_3236u64;
let base = Instant::now();
for _ in 0..4096 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let ip = IpAddr::V4(Ipv4Addr::new(
(seed >> 24) as u8,
(seed >> 16) as u8,
(seed >> 8) as u8,
seed as u8,
));
let state_len = ((seed >> 8) as usize % 131_072).saturating_add(1);
let scan_limit = ((seed >> 32) as usize % 512).saturating_add(1);
let now = base + Duration::from_nanos(seed & 0xffff);
let start = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
assert!(
start < state_len,
"scan offset must stay inside state length"
);
}
}

View File

@ -0,0 +1,116 @@
use super::*;
use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr};
use std::time::{Duration, Instant};
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
#[test]
fn positive_same_ip_moving_time_yields_diverse_scan_offsets() {
let _guard = auth_probe_test_guard();
let ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 77));
let base = Instant::now();
let mut uniq = HashSet::new();
for i in 0..512u64 {
let now = base + Duration::from_nanos(i);
let offset = auth_probe_scan_start_offset(ip, now, 65_536, 16);
uniq.insert(offset);
}
assert!(
uniq.len() >= 256,
"offset randomization collapsed unexpectedly for same-ip moving-time samples (uniq={})",
uniq.len()
);
}
#[test]
fn adversarial_many_ips_same_time_spreads_offsets_without_bias_collapse() {
let _guard = auth_probe_test_guard();
let now = Instant::now();
let mut uniq = HashSet::new();
for i in 0..1024u32 {
let ip = IpAddr::V4(Ipv4Addr::new(
(i >> 16) as u8,
(i >> 8) as u8,
i as u8,
(255 - (i as u8)),
));
uniq.insert(auth_probe_scan_start_offset(ip, now, 65_536, 16));
}
assert!(
uniq.len() >= 512,
"scan offset distribution collapsed unexpectedly across adversarial peer set (uniq={})",
uniq.len()
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_failure_churn_under_saturation_remains_capped_and_live() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let start = Instant::now();
let mut workers = Vec::new();
for worker in 0..8u8 {
workers.push(tokio::spawn(async move {
for i in 0..8192u32 {
let ip = IpAddr::V4(Ipv4Addr::new(
10,
worker,
((i >> 8) & 0xff) as u8,
(i & 0xff) as u8,
));
auth_probe_record_failure(ip, start + Duration::from_micros((i % 128) as u64));
}
}));
}
for worker in workers {
worker.await.expect("saturation worker must not panic");
}
assert!(
auth_probe_state_map().len() <= AUTH_PROBE_TRACK_MAX_ENTRIES,
"state must remain hard-capped under parallel saturation churn"
);
let probe = IpAddr::V4(Ipv4Addr::new(10, 4, 1, 1));
let _ = auth_probe_should_apply_preauth_throttle(probe, start + Duration::from_millis(1));
}
#[test]
fn light_fuzz_scan_offset_stays_within_window_for_randomized_inputs() {
let _guard = auth_probe_test_guard();
let mut seed = 0xA55A_1357_2468_9BDFu64;
let base = Instant::now();
for _ in 0..8192 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let ip = IpAddr::V4(Ipv4Addr::new(
(seed >> 24) as u8,
(seed >> 16) as u8,
(seed >> 8) as u8,
seed as u8,
));
let state_len = ((seed >> 8) as usize % 200_000).saturating_add(1);
let scan_limit = ((seed >> 40) as usize % 1024).saturating_add(1);
let now = base + Duration::from_nanos(seed & 0x1fff);
let offset = auth_probe_scan_start_offset(ip, now, state_len, scan_limit);
assert!(
offset < state_len,
"scan offset must always remain inside state length"
);
}
}

View File

@ -0,0 +1,42 @@
use super::*;
fn handshake_source() -> &'static str {
include_str!("../handshake.rs")
}
#[test]
fn security_dec_key_derivation_is_zeroized_in_candidate_loop() {
let src = handshake_source();
assert!(
src.contains("let dec_key = Zeroizing::new(sha256(&dec_key_input));"),
"candidate-loop dec_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths"
);
}
#[test]
fn security_enc_key_derivation_is_zeroized_in_candidate_loop() {
let src = handshake_source();
assert!(
src.contains("let enc_key = Zeroizing::new(sha256(&enc_key_input));"),
"candidate-loop enc_key derivation must be wrapped in Zeroizing to clear secrets on early-continue paths"
);
}
#[test]
fn security_aes_ctr_initialization_uses_zeroizing_references() {
let src = handshake_source();
assert!(
src.contains("let mut decryptor = AesCtr::new(&dec_key, dec_iv);")
&& src.contains("let encryptor = AesCtr::new(&enc_key, enc_iv);"),
"AES-CTR initialization must use Zeroizing key wrappers directly without creating extra plain key variables"
);
}
#[test]
fn security_success_struct_copies_out_of_zeroizing_wrappers() {
let src = handshake_source();
assert!(
src.contains("dec_key: *dec_key,") && src.contains("enc_key: *enc_key,"),
"HandshakeSuccess construction must copy from Zeroizing wrappers so loop-local key material is dropped and zeroized"
);
}

View File

@ -0,0 +1,686 @@
use super::*;
use crate::crypto::{AesCtr, sha256, sha256_hmac};
use crate::protocol::constants::{ProtoTag, RESERVED_NONCE_BEGINNINGS, RESERVED_NONCE_FIRST_BYTES};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::collections::HashSet;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Barrier;
// --- Helpers ---
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default();
cfg.access.users.clear();
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true;
cfg.general.modes.secure = true;
cfg.general.modes.classic = true;
cfg.general.modes.tls = true;
cfg
}
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
let session_id_len: usize = 32;
let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
fn make_valid_mtproto_handshake(
secret_hex: &str,
proto_tag: ProtoTag,
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode");
let mut handshake = [0x5Au8; HANDSHAKE_LEN];
for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
.iter_mut()
.enumerate()
{
*b = (idx as u8).wrapping_add(1);
}
let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut stream = AesCtr::new(&dec_key, dec_iv);
let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
let mut target_plain = [0u8; HANDSHAKE_LEN];
target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
handshake[idx] = target_plain[idx] ^ keystream[idx];
}
handshake
}
fn make_valid_tls_client_hello_with_sni_and_alpn(
secret: &[u8],
timestamp: u32,
sni_host: &str,
alpn_protocols: &[&[u8]],
) -> Vec<u8> {
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
body.push(32);
body.extend_from_slice(&[0x42u8; 32]);
body.extend_from_slice(&2u16.to_be_bytes());
body.extend_from_slice(&[0x13, 0x01]);
body.push(1);
body.push(0);
let mut ext_blob = Vec::new();
let host_bytes = sni_host.as_bytes();
let mut sni_payload = Vec::new();
sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes());
sni_payload.push(0);
sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes());
sni_payload.extend_from_slice(host_bytes);
ext_blob.extend_from_slice(&0x0000u16.to_be_bytes());
ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&sni_payload);
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
alpn_list.push(proto.len() as u8);
alpn_list.extend_from_slice(proto);
}
let mut alpn_data = Vec::new();
alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
alpn_data.extend_from_slice(&alpn_list);
ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&alpn_data);
}
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
body.extend_from_slice(&ext_blob);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut record = Vec::new();
record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&[0x03, 0x01]);
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
record.extend_from_slice(&handshake);
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &record);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
record
}
// --- Category 1: Timing & Delay Invariants ---
#[tokio::test]
async fn server_hello_delay_bypassed_if_max_is_zero_despite_high_min() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x1Au8; 16];
let mut config = test_config_with_secret_hex("1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a1a");
config.censorship.server_hello_delay_min_ms = 5000;
config.censorship.server_hello_delay_max_ms = 0;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.101:12345".parse().unwrap();
let mut invalid_handshake = make_valid_tls_handshake(&secret, 0);
invalid_handshake[tls::TLS_DIGEST_POS] ^= 0xFF;
let fut = handle_tls_handshake(
&invalid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
);
// Deterministic assertion: with max_ms == 0 there must be no sleep path,
// so the handshake should complete promptly under a generous timeout budget.
let res = tokio::time::timeout(Duration::from_millis(250), fut)
.await
.expect("max_ms=0 should bypass artificial delay and complete quickly");
assert!(matches!(res, HandshakeResult::BadClient { .. }));
}
#[test]
fn auth_probe_backoff_extreme_fail_streak_clamps_safely() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let state = auth_probe_state_map();
let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 99));
let now = Instant::now();
state.insert(
peer_ip,
AuthProbeState {
fail_streak: u32::MAX - 1,
blocked_until: now,
last_seen: now,
},
);
auth_probe_record_failure_with_state(&state, peer_ip, now);
let updated = state.get(&peer_ip).unwrap();
assert_eq!(updated.fail_streak, u32::MAX);
let expected_blocked_until = now + Duration::from_millis(AUTH_PROBE_BACKOFF_MAX_MS);
assert_eq!(
updated.blocked_until, expected_blocked_until,
"Extreme fail streak must clamp cleanly to AUTH_PROBE_BACKOFF_MAX_MS"
);
}
#[test]
fn generate_tg_nonce_cryptographic_uniqueness_and_entropy() {
let client_enc_key = [0x2Bu8; 32];
let client_enc_iv = 1337u128;
let rng = SecureRandom::new();
let mut nonces = HashSet::new();
let mut total_set_bits = 0usize;
let iterations = 5_000;
for _ in 0..iterations {
let (nonce, _, _, _, _) = generate_tg_nonce(
ProtoTag::Secure,
2,
&client_enc_key,
client_enc_iv,
&rng,
false,
);
for byte in nonce.iter() {
total_set_bits += byte.count_ones() as usize;
}
assert!(
nonces.insert(nonce),
"generate_tg_nonce emitted a duplicate nonce! RNG is stuck."
);
}
let total_bits = iterations * HANDSHAKE_LEN * 8;
let ratio = (total_set_bits as f64) / (total_bits as f64);
assert!(
ratio > 0.48 && ratio < 0.52,
"Nonce entropy is degraded. Set bit ratio: {}",
ratio
);
}
#[tokio::test]
async fn mtproto_multi_user_decryption_isolation() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let mut config = ProxyConfig::default();
config.general.modes.secure = true;
config.access.ignore_time_skew = true;
config.access.users.insert(
"user_a".to_string(),
"11111111111111111111111111111111".to_string(),
);
config.access.users.insert(
"user_b".to_string(),
"22222222222222222222222222222222".to_string(),
);
let good_secret_hex = "33333333333333333333333333333333";
config
.access
.users
.insert("user_c".to_string(), good_secret_hex.to_string());
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.104:12345".parse().unwrap();
let valid_handshake = make_valid_mtproto_handshake(good_secret_hex, ProtoTag::Secure, 1);
let res = handle_mtproto_handshake(
&valid_handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
match res {
HandshakeResult::Success((_, _, success)) => {
assert_eq!(
success.user, "user_c",
"Decryption attempts on previous users must not corrupt the handshake buffer for the valid user"
);
}
_ => panic!(
"Multi-user MTProto handshake failed. Decryption buffer might be mutating in place."
),
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn invalid_secret_warning_lock_contention_and_bound() {
let _guard = warned_secrets_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_warned_secrets_for_testing();
let tasks = 50;
let iterations_per_task = 100;
let barrier = Arc::new(Barrier::new(tasks));
let mut handles = Vec::new();
for t in 0..tasks {
let b = barrier.clone();
handles.push(tokio::spawn(async move {
b.wait().await;
for i in 0..iterations_per_task {
let user_name = format!("contention_user_{}_{}", t, i);
warn_invalid_secret_once(&user_name, "invalid_hex", ACCESS_SECRET_BYTES, None);
}
}));
}
for handle in handles {
handle.await.unwrap();
}
let warned = INVALID_SECRET_WARNED.get().unwrap();
let guard = warned
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
assert_eq!(
guard.len(),
WARNED_SECRET_MAX_ENTRIES,
"Concurrent spam of invalid secrets must strictly bound the HashSet memory to WARNED_SECRET_MAX_ENTRIES"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn mtproto_strict_concurrent_replay_race_condition() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret_hex = "4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A4A";
let config = Arc::new(test_config_with_secret_hex(secret_hex));
let replay_checker = Arc::new(ReplayChecker::new(4096, Duration::from_secs(60)));
let valid_handshake = Arc::new(make_valid_mtproto_handshake(
secret_hex,
ProtoTag::Secure,
1,
));
let tasks = 100;
let barrier = Arc::new(Barrier::new(tasks));
let mut handles = Vec::new();
for i in 0..tasks {
let b = barrier.clone();
let cfg = config.clone();
let rc = replay_checker.clone();
let hs = valid_handshake.clone();
handles.push(tokio::spawn(async move {
let peer = SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(10, 0, 0, (i % 250) as u8)),
10000 + i as u16,
);
b.wait().await;
handle_mtproto_handshake(
&hs,
tokio::io::empty(),
tokio::io::sink(),
peer,
&cfg,
&rc,
false,
None,
)
.await
}));
}
let mut successes = 0;
let mut failures = 0;
for handle in handles {
match handle.await.unwrap() {
HandshakeResult::Success(_) => successes += 1,
HandshakeResult::BadClient { .. } => failures += 1,
_ => panic!("Unexpected error result in concurrent MTProto replay test"),
}
}
assert_eq!(
successes, 1,
"Replay cache race condition allowed multiple identical MTProto handshakes to succeed"
);
assert_eq!(
failures,
tasks - 1,
"Replay cache failed to forcefully reject concurrent duplicates"
);
}
#[tokio::test]
async fn tls_alpn_zero_length_protocol_handled_safely() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x5Bu8; 16];
let mut config = test_config_with_secret_hex("5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b5b");
config.censorship.alpn_enforce = true;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.107:12345".parse().unwrap();
let handshake =
make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b""]);
let res = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(
matches!(res, HandshakeResult::BadClient { .. }),
"0-length ALPN must be safely rejected without panicking"
);
}
#[tokio::test]
async fn tls_sni_massive_hostname_does_not_panic() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x6Cu8; 16];
let config = test_config_with_secret_hex("6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c6c");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.108:12345".parse().unwrap();
let massive_hostname = String::from_utf8(vec![b'a'; 65000]).unwrap();
let handshake =
make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, &massive_hostname, &[]);
let res = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(
matches!(
res,
HandshakeResult::Success(_) | HandshakeResult::BadClient { .. }
),
"Massive SNI hostname must be processed or ignored without stack overflow or panic"
);
}
#[tokio::test]
async fn tls_progressive_truncation_fuzzing_no_panics() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x7Du8; 16];
let config = test_config_with_secret_hex("7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d7d");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.109:12345".parse().unwrap();
let valid_handshake =
make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "example.com", &[b"h2"]);
let full_len = valid_handshake.len();
// Truncated corpus only: full_len is a valid baseline and should not be
// asserted as BadClient in a truncation-specific test.
for i in (0..full_len).rev() {
let truncated = &valid_handshake[..i];
let res = handle_tls_handshake(
truncated,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(
matches!(res, HandshakeResult::BadClient { .. }),
"Truncated TLS handshake at len {} must fail safely without panicking",
i
);
}
}
#[tokio::test]
async fn mtproto_pure_entropy_fuzzing_no_panics() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let config = test_config_with_secret_hex("8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e8e");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.110:12345".parse().unwrap();
let mut seeded = StdRng::seed_from_u64(0xDEADBEEFCAFE);
for _ in 0..10_000 {
let mut noise = [0u8; HANDSHAKE_LEN];
seeded.fill_bytes(&mut noise);
let res = handle_mtproto_handshake(
&noise,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(
matches!(res, HandshakeResult::BadClient { .. }),
"Pure entropy MTProto payload must fail closed and never panic"
);
}
}
#[test]
fn decode_user_secret_odd_length_hex_rejection() {
let _guard = warned_secrets_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_warned_secrets_for_testing();
let mut config = ProxyConfig::default();
config.access.users.clear();
config.access.users.insert(
"odd_user".to_string(),
"1234567890123456789012345678901".to_string(),
);
let decoded = decode_user_secrets(&config, None);
assert!(
decoded.is_empty(),
"Odd-length hex string must be gracefully rejected by hex::decode without unwrapping"
);
}
#[test]
fn saturation_grace_pre_existing_high_fail_streak_immediate_throttle() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let state = auth_probe_state_map();
let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 112));
let now = Instant::now();
let extreme_streak = AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS + 5;
state.insert(
peer_ip,
AuthProbeState {
fail_streak: extreme_streak,
blocked_until: now + Duration::from_secs(5),
last_seen: now,
},
);
{
let mut guard = auth_probe_saturation_state_lock();
*guard = Some(AuthProbeSaturationState {
fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
blocked_until: now + Duration::from_secs(5),
last_seen: now,
});
}
let is_throttled = auth_probe_should_apply_preauth_throttle(peer_ip, now);
assert!(
is_throttled,
"A peer with a pre-existing high fail streak must be immediately throttled when saturation begins, receiving no unearned grace period"
);
}
#[test]
fn auth_probe_saturation_note_resets_retention_window() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let base_time = Instant::now();
auth_probe_note_saturation(base_time);
let later = base_time + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS - 1);
auth_probe_note_saturation(later);
let check_time = base_time + Duration::from_secs(AUTH_PROBE_TRACK_RETENTION_SECS + 5);
// This call may return false if backoff has elapsed, but it must not clear
// the saturation state because `later` refreshed last_seen.
let _ = auth_probe_saturation_is_throttled_at_for_testing(check_time);
let guard = auth_probe_saturation_state_lock();
assert!(
guard.is_some(),
"Ongoing saturation notes must refresh last_seen so saturation state remains retained past the original window"
);
}
#[test]
fn mtproto_classic_tags_rejected_when_only_secure_mode_enabled() {
let mut config = ProxyConfig::default();
config.general.modes.classic = false;
config.general.modes.secure = true;
config.general.modes.tls = false;
assert!(!mode_enabled_for_proto(&config, ProtoTag::Abridged, false));
assert!(!mode_enabled_for_proto(
&config,
ProtoTag::Intermediate,
false
));
}
#[test]
fn mtproto_secure_tag_rejected_when_only_classic_mode_enabled() {
let mut config = ProxyConfig::default();
config.general.modes.classic = true;
config.general.modes.secure = false;
config.general.modes.tls = false;
assert!(!mode_enabled_for_proto(&config, ProtoTag::Secure, false));
}
#[test]
fn ipv6_localhost_and_unspecified_normalization() {
let localhost = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1));
let unspecified = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0));
let norm_local = normalize_auth_probe_ip(localhost);
let norm_unspec = normalize_auth_probe_ip(unspecified);
let expected_bucket = IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0));
assert_eq!(norm_local, expected_bucket);
assert_eq!(norm_unspec, expected_bucket);
}

View File

@ -0,0 +1,340 @@
use super::*;
use crate::crypto::{AesCtr, SecureRandom, sha256, sha256_hmac};
use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Barrier;
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn test_config_with_secret_hex(secret_hex: &str) -> ProxyConfig {
let mut cfg = ProxyConfig::default();
cfg.access.users.clear();
cfg.access
.users
.insert("user".to_string(), secret_hex.to_string());
cfg.access.ignore_time_skew = true;
cfg.general.modes.secure = true;
cfg.general.modes.classic = true;
cfg.general.modes.tls = true;
cfg
}
fn make_valid_tls_client_hello_with_alpn(
secret: &[u8],
timestamp: u32,
alpn_protocols: &[&[u8]],
) -> Vec<u8> {
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
body.push(32);
body.extend_from_slice(&[0x42u8; 32]);
body.extend_from_slice(&2u16.to_be_bytes());
body.extend_from_slice(&[0x13, 0x01]);
body.push(1);
body.push(0);
let mut ext_blob = Vec::new();
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
alpn_list.push(proto.len() as u8);
alpn_list.extend_from_slice(proto);
}
let mut alpn_data = Vec::new();
alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
alpn_data.extend_from_slice(&alpn_list);
ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&alpn_data);
}
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
body.extend_from_slice(&ext_blob);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut record = Vec::new();
record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&[0x03, 0x01]);
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
record.extend_from_slice(&handshake);
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &record);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
record
}
fn make_valid_mtproto_handshake(
secret_hex: &str,
proto_tag: ProtoTag,
dc_idx: i16,
) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode");
let mut handshake = [0x5Au8; HANDSHAKE_LEN];
for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
.iter_mut()
.enumerate()
{
*b = (idx as u8).wrapping_add(1);
}
let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut stream = AesCtr::new(&dec_key, dec_iv);
let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
let mut target_plain = [0u8; HANDSHAKE_LEN];
target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
handshake[idx] = target_plain[idx] ^ keystream[idx];
}
handshake
}
#[tokio::test]
async fn tls_alpn_reject_does_not_pollute_replay_cache() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let secret = [0x11u8; 16];
let mut config = test_config_with_secret_hex("11111111111111111111111111111111");
config.censorship.alpn_enforce = true;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.201:12345".parse().unwrap();
let handshake = make_valid_tls_client_hello_with_alpn(&secret, 0, &[b"h3"]);
let before = replay_checker.stats();
let res = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
let after = replay_checker.stats();
assert!(matches!(res, HandshakeResult::BadClient { .. }));
assert_eq!(
before.total_additions, after.total_additions,
"ALPN policy reject must not add TLS digest into replay cache"
);
}
#[tokio::test]
async fn tls_truncated_session_id_len_fails_closed_without_panic() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let config = test_config_with_secret_hex("33333333333333333333333333333333");
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "192.0.2.203:12345".parse().unwrap();
let min_len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1;
let mut malicious = vec![0x42u8; min_len];
malicious[min_len - 1] = u8::MAX;
let res = handle_tls_handshake(
&malicious,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(res, HandshakeResult::BadClient { .. }));
}
#[test]
fn auth_probe_eviction_identical_timestamps_keeps_map_bounded() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let state = auth_probe_state_map();
let same = Instant::now();
for i in 0..AUTH_PROBE_TRACK_MAX_ENTRIES {
let ip = IpAddr::V4(Ipv4Addr::new(10, 1, (i >> 8) as u8, (i & 0xFF) as u8));
state.insert(
ip,
AuthProbeState {
fail_streak: 7,
blocked_until: same,
last_seen: same,
},
);
}
let new_ip = IpAddr::V4(Ipv4Addr::new(192, 168, 21, 21));
auth_probe_record_failure_with_state(state, new_ip, same + Duration::from_millis(1));
assert_eq!(state.len(), AUTH_PROBE_TRACK_MAX_ENTRIES);
assert!(state.contains_key(&new_ip));
}
#[test]
fn clear_auth_probe_state_recovers_from_poisoned_saturation_lock() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let saturation = auth_probe_saturation_state();
let poison_thread = std::thread::spawn(move || {
let _hold = saturation
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
panic!("intentional poison for regression coverage");
});
let _ = poison_thread.join();
clear_auth_probe_state_for_testing();
let guard = auth_probe_saturation_state()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
assert!(guard.is_none());
}
#[tokio::test]
async fn mtproto_invalid_length_secret_is_ignored_and_valid_user_still_auths() {
let _probe_guard = auth_probe_test_guard();
let _warn_guard = warned_secrets_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
clear_auth_probe_state_for_testing();
clear_warned_secrets_for_testing();
let mut config = ProxyConfig::default();
config.general.modes.secure = true;
config.access.ignore_time_skew = true;
config.access.users.insert(
"short_user".to_string(),
"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string(),
);
let valid_secret_hex = "77777777777777777777777777777777";
config
.access
.users
.insert("good_user".to_string(), valid_secret_hex.to_string());
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let peer: SocketAddr = "192.0.2.207:12345".parse().unwrap();
let handshake = make_valid_mtproto_handshake(valid_secret_hex, ProtoTag::Secure, 1);
let res = handle_mtproto_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
false,
None,
)
.await;
assert!(matches!(res, HandshakeResult::Success(_)));
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn saturation_grace_exhaustion_under_concurrency_keeps_peer_throttled() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
let peer_ip = IpAddr::V4(Ipv4Addr::new(198, 51, 100, 80));
let now = Instant::now();
{
let mut guard = auth_probe_saturation_state()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner());
*guard = Some(AuthProbeSaturationState {
fail_streak: AUTH_PROBE_BACKOFF_START_FAILS,
blocked_until: now + Duration::from_secs(5),
last_seen: now,
});
}
let state = auth_probe_state_map();
state.insert(
peer_ip,
AuthProbeState {
fail_streak: AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS - 1,
blocked_until: now,
last_seen: now,
},
);
let tasks = 32;
let barrier = Arc::new(Barrier::new(tasks));
let mut handles = Vec::new();
for _ in 0..tasks {
let b = barrier.clone();
handles.push(tokio::spawn(async move {
b.wait().await;
auth_probe_record_failure(peer_ip, Instant::now());
}));
}
for handle in handles {
handle.await.unwrap();
}
let final_state = state.get(&peer_ip).expect("state must exist");
assert!(
final_state.fail_streak
>= AUTH_PROBE_BACKOFF_START_FAILS + AUTH_PROBE_SATURATION_GRACE_FAILS
);
assert!(auth_probe_should_apply_preauth_throttle(
peer_ip,
Instant::now()
));
}

View File

@ -956,6 +956,89 @@ async fn stress_tls_sni_preferred_user_hint_scales_to_large_user_set() {
}
}
#[tokio::test]
async fn tls_unknown_sni_drop_policy_returns_hard_error() {
let secret = [0x48u8; 16];
let mut config = test_config_with_secret_hex("48484848484848484848484848484848");
config.censorship.unknown_sni_action = UnknownSniAction::Drop;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.190:44326".parse().unwrap();
let handshake =
make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "unknown.example", &[b"h2"]);
let result = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(
result,
HandshakeResult::Error(ProxyError::UnknownTlsSni)
));
}
#[tokio::test]
async fn tls_unknown_sni_mask_policy_falls_back_to_bad_client() {
let secret = [0x49u8; 16];
let mut config = test_config_with_secret_hex("49494949494949494949494949494949");
config.censorship.unknown_sni_action = UnknownSniAction::Mask;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.191:44326".parse().unwrap();
let handshake =
make_valid_tls_client_hello_with_sni_and_alpn(&secret, 0, "unknown.example", &[b"h2"]);
let result = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::BadClient { .. }));
}
#[tokio::test]
async fn tls_missing_sni_keeps_legacy_auth_path() {
let secret = [0x4Au8; 16];
let mut config = test_config_with_secret_hex("4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a4a");
config.censorship.unknown_sni_action = UnknownSniAction::Drop;
let replay_checker = ReplayChecker::new(128, Duration::from_secs(60));
let rng = SecureRandom::new();
let peer: SocketAddr = "198.51.100.192:44326".parse().unwrap();
let handshake = make_valid_tls_handshake(&secret, 0);
let result = handle_tls_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer,
&config,
&replay_checker,
&rng,
None,
)
.await;
assert!(matches!(result, HandshakeResult::Success(_)));
}
#[tokio::test]
async fn alpn_enforce_rejects_unsupported_client_alpn() {
let secret = [0x33u8; 16];

View File

@ -0,0 +1,318 @@
use super::*;
use crate::crypto::{AesCtr, SecureRandom, sha256, sha256_hmac};
use crate::protocol::constants::{ProtoTag, TLS_RECORD_HANDSHAKE, TLS_VERSION};
use std::net::SocketAddr;
use std::time::{Duration, Instant};
fn auth_probe_test_guard() -> std::sync::MutexGuard<'static, ()> {
auth_probe_test_lock()
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
}
fn make_valid_mtproto_handshake(
secret_hex: &str,
proto_tag: ProtoTag,
dc_idx: i16,
salt: u8,
) -> [u8; HANDSHAKE_LEN] {
let secret = hex::decode(secret_hex).expect("secret hex must decode");
let mut handshake = [0x5Au8; HANDSHAKE_LEN];
for (idx, b) in handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN]
.iter_mut()
.enumerate()
{
*b = (idx as u8).wrapping_add(1).wrapping_add(salt);
}
let dec_prekey = &handshake[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
let dec_iv_bytes = &handshake[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
let mut dec_key_input = Vec::with_capacity(PREKEY_LEN + secret.len());
dec_key_input.extend_from_slice(dec_prekey);
dec_key_input.extend_from_slice(&secret);
let dec_key = sha256(&dec_key_input);
let mut dec_iv_arr = [0u8; IV_LEN];
dec_iv_arr.copy_from_slice(dec_iv_bytes);
let dec_iv = u128::from_be_bytes(dec_iv_arr);
let mut stream = AesCtr::new(&dec_key, dec_iv);
let keystream = stream.encrypt(&[0u8; HANDSHAKE_LEN]);
let mut target_plain = [0u8; HANDSHAKE_LEN];
target_plain[PROTO_TAG_POS..PROTO_TAG_POS + 4].copy_from_slice(&proto_tag.to_bytes());
target_plain[DC_IDX_POS..DC_IDX_POS + 2].copy_from_slice(&dc_idx.to_le_bytes());
for idx in PROTO_TAG_POS..HANDSHAKE_LEN {
handshake[idx] = target_plain[idx] ^ keystream[idx];
}
handshake
}
fn make_valid_tls_handshake(secret: &[u8], timestamp: u32) -> Vec<u8> {
let session_id_len: usize = 32;
let len = tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN + 1 + session_id_len;
let mut handshake = vec![0x42u8; len];
handshake[tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN] = session_id_len as u8;
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &handshake);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
handshake[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN]
.copy_from_slice(&digest);
handshake
}
fn make_valid_tls_client_hello_with_sni_and_alpn(
secret: &[u8],
timestamp: u32,
sni_host: &str,
alpn_protocols: &[&[u8]],
) -> Vec<u8> {
let mut body = Vec::new();
body.extend_from_slice(&TLS_VERSION);
body.extend_from_slice(&[0u8; 32]);
body.push(32);
body.extend_from_slice(&[0x42u8; 32]);
body.extend_from_slice(&2u16.to_be_bytes());
body.extend_from_slice(&[0x13, 0x01]);
body.push(1);
body.push(0);
let mut ext_blob = Vec::new();
let host_bytes = sni_host.as_bytes();
let mut sni_payload = Vec::new();
sni_payload.extend_from_slice(&((host_bytes.len() + 3) as u16).to_be_bytes());
sni_payload.push(0);
sni_payload.extend_from_slice(&(host_bytes.len() as u16).to_be_bytes());
sni_payload.extend_from_slice(host_bytes);
ext_blob.extend_from_slice(&0x0000u16.to_be_bytes());
ext_blob.extend_from_slice(&(sni_payload.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&sni_payload);
if !alpn_protocols.is_empty() {
let mut alpn_list = Vec::new();
for proto in alpn_protocols {
alpn_list.push(proto.len() as u8);
alpn_list.extend_from_slice(proto);
}
let mut alpn_data = Vec::new();
alpn_data.extend_from_slice(&(alpn_list.len() as u16).to_be_bytes());
alpn_data.extend_from_slice(&alpn_list);
ext_blob.extend_from_slice(&0x0010u16.to_be_bytes());
ext_blob.extend_from_slice(&(alpn_data.len() as u16).to_be_bytes());
ext_blob.extend_from_slice(&alpn_data);
}
body.extend_from_slice(&(ext_blob.len() as u16).to_be_bytes());
body.extend_from_slice(&ext_blob);
let mut handshake = Vec::new();
handshake.push(0x01);
let body_len = (body.len() as u32).to_be_bytes();
handshake.extend_from_slice(&body_len[1..4]);
handshake.extend_from_slice(&body);
let mut record = Vec::new();
record.push(TLS_RECORD_HANDSHAKE);
record.extend_from_slice(&[0x03, 0x01]);
record.extend_from_slice(&(handshake.len() as u16).to_be_bytes());
record.extend_from_slice(&handshake);
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].fill(0);
let computed = sha256_hmac(secret, &record);
let mut digest = computed;
let ts = timestamp.to_le_bytes();
for i in 0..4 {
digest[28 + i] ^= ts[i];
}
record[tls::TLS_DIGEST_POS..tls::TLS_DIGEST_POS + tls::TLS_DIGEST_LEN].copy_from_slice(&digest);
record
}
fn median_ns(samples: &mut [u128]) -> u128 {
samples.sort_unstable();
samples[samples.len() / 2]
}
#[tokio::test]
#[ignore = "manual benchmark: timing-sensitive and host-dependent"]
async fn mtproto_user_scan_timing_manual_benchmark() {
let _guard = auth_probe_test_guard();
clear_auth_probe_state_for_testing();
const DECOY_USERS: usize = 8_000;
const ITERATIONS: usize = 250;
let preferred_user = "target_user";
let target_secret_hex = "dededededededededededededededede";
let mut config = ProxyConfig::default();
config.general.modes.secure = true;
config.access.ignore_time_skew = true;
for i in 0..DECOY_USERS {
config.access.users.insert(
format!("decoy_{i}"),
"00000000000000000000000000000000".to_string(),
);
}
config
.access
.users
.insert(preferred_user.to_string(), target_secret_hex.to_string());
let replay_checker_preferred = ReplayChecker::new(65_536, Duration::from_secs(60));
let replay_checker_full_scan = ReplayChecker::new(65_536, Duration::from_secs(60));
let peer_a: SocketAddr = "192.0.2.241:12345".parse().unwrap();
let peer_b: SocketAddr = "192.0.2.242:12345".parse().unwrap();
let mut preferred_samples = Vec::with_capacity(ITERATIONS);
let mut full_scan_samples = Vec::with_capacity(ITERATIONS);
for i in 0..ITERATIONS {
let handshake = make_valid_mtproto_handshake(
target_secret_hex,
ProtoTag::Secure,
1 + i as i16,
(i % 251) as u8,
);
let started_preferred = Instant::now();
let preferred = handle_mtproto_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer_a,
&config,
&replay_checker_preferred,
false,
Some(preferred_user),
)
.await;
preferred_samples.push(started_preferred.elapsed().as_nanos());
assert!(matches!(preferred, HandshakeResult::Success(_)));
let started_scan = Instant::now();
let full_scan = handle_mtproto_handshake(
&handshake,
tokio::io::empty(),
tokio::io::sink(),
peer_b,
&config,
&replay_checker_full_scan,
false,
None,
)
.await;
full_scan_samples.push(started_scan.elapsed().as_nanos());
assert!(matches!(full_scan, HandshakeResult::Success(_)));
}
let preferred_median = median_ns(&mut preferred_samples);
let full_scan_median = median_ns(&mut full_scan_samples);
let ratio = if preferred_median == 0 {
0.0
} else {
full_scan_median as f64 / preferred_median as f64
};
println!(
"manual timing benchmark: decoys={DECOY_USERS}, iters={ITERATIONS}, preferred_median_ns={preferred_median}, full_scan_median_ns={full_scan_median}, ratio={ratio:.3}"
);
assert!(
full_scan_median >= preferred_median,
"full user scan should not be faster than preferred-user path in this benchmark"
);
}
#[tokio::test]
#[ignore = "manual benchmark: timing-sensitive and host-dependent"]
async fn tls_sni_preferred_vs_no_sni_fallback_manual_benchmark() {
let _guard = auth_probe_test_guard();
const DECOY_USERS: usize = 8_000;
const ITERATIONS: usize = 250;
let preferred_user = "user-b";
let target_secret_hex = "abababababababababababababababab";
let target_secret = [0xABu8; 16];
let mut config = ProxyConfig::default();
config.general.modes.tls = true;
config.access.ignore_time_skew = true;
for i in 0..DECOY_USERS {
config.access.users.insert(
format!("decoy_{i}"),
"00000000000000000000000000000000".to_string(),
);
}
config
.access
.users
.insert(preferred_user.to_string(), target_secret_hex.to_string());
let mut sni_samples = Vec::with_capacity(ITERATIONS);
let mut no_sni_samples = Vec::with_capacity(ITERATIONS);
for i in 0..ITERATIONS {
let with_sni = make_valid_tls_client_hello_with_sni_and_alpn(
&target_secret,
i as u32,
preferred_user,
&[b"h2"],
);
let no_sni = make_valid_tls_handshake(&target_secret, (i as u32).wrapping_add(10_000));
let started_sni = Instant::now();
let sni_secrets = decode_user_secrets(&config, Some(preferred_user));
let sni_result = tls::validate_tls_handshake_with_replay_window(
&with_sni,
&sni_secrets,
config.access.ignore_time_skew,
config.access.replay_window_secs,
);
sni_samples.push(started_sni.elapsed().as_nanos());
assert!(sni_result.is_some());
let started_no_sni = Instant::now();
let no_sni_secrets = decode_user_secrets(&config, None);
let no_sni_result = tls::validate_tls_handshake_with_replay_window(
&no_sni,
&no_sni_secrets,
config.access.ignore_time_skew,
config.access.replay_window_secs,
);
no_sni_samples.push(started_no_sni.elapsed().as_nanos());
assert!(no_sni_result.is_some());
}
let sni_median = median_ns(&mut sni_samples);
let no_sni_median = median_ns(&mut no_sni_samples);
let ratio = if sni_median == 0 {
0.0
} else {
no_sni_median as f64 / sni_median as f64
};
println!(
"manual tls benchmark: decoys={DECOY_USERS}, iters={ITERATIONS}, sni_median_ns={sni_median}, no_sni_median_ns={no_sni_median}, ratio_no_sni_over_sni={ratio:.3}"
);
}

View File

@ -493,9 +493,12 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
];
let mut meaningful_improvement_seen = false;
let mut baseline_sum = 0.0f64;
let mut hardened_sum = 0.0f64;
let mut pair_count = 0usize;
let mut informative_baseline_sum = 0.0f64;
let mut informative_hardened_sum = 0.0f64;
let mut informative_pair_count = 0usize;
let mut low_info_baseline_sum = 0.0f64;
let mut low_info_hardened_sum = 0.0f64;
let mut low_info_pair_count = 0usize;
let acc_quant_step = 1.0 / (2 * SAMPLE_COUNT) as f64;
let tolerated_pair_regression = acc_quant_step + 0.03;
@ -522,6 +525,16 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
hardened_acc <= baseline_acc + tolerated_pair_regression,
"normalization should not materially worsen informative pair: baseline={baseline_acc:.3} hardened={hardened_acc:.3} tolerated={tolerated_pair_regression:.3}"
);
informative_baseline_sum += baseline_acc;
informative_hardened_sum += hardened_acc;
informative_pair_count += 1;
} else {
// Low-information pairs (near-random baseline separability) are expected
// to exhibit quantized jitter at low sample counts; do not fold them into
// strict average-regression checks used for informative side-channel signal.
low_info_baseline_sum += baseline_acc;
low_info_hardened_sum += hardened_acc;
low_info_pair_count += 1;
}
println!(
@ -531,20 +544,30 @@ async fn timing_classifier_light_fuzz_pairwise_bucketed_accuracy_stays_bounded_u
if hardened_acc + 0.05 <= baseline_acc {
meaningful_improvement_seen = true;
}
baseline_sum += baseline_acc;
hardened_sum += hardened_acc;
pair_count += 1;
}
let baseline_avg = baseline_sum / pair_count as f64;
let hardened_avg = hardened_sum / pair_count as f64;
assert!(
informative_pair_count > 0,
"expected at least one informative pair for timing-separability guard"
);
let informative_baseline_avg = informative_baseline_sum / informative_pair_count as f64;
let informative_hardened_avg = informative_hardened_sum / informative_pair_count as f64;
assert!(
hardened_avg <= baseline_avg + 0.10,
"normalization should not materially increase average pairwise separability: baseline_avg={baseline_avg:.3} hardened_avg={hardened_avg:.3}"
informative_hardened_avg <= informative_baseline_avg + 0.10,
"normalization should not materially increase informative average separability: baseline_avg={informative_baseline_avg:.3} hardened_avg={informative_hardened_avg:.3}"
);
if low_info_pair_count > 0 {
let low_info_baseline_avg = low_info_baseline_sum / low_info_pair_count as f64;
let low_info_hardened_avg = low_info_hardened_sum / low_info_pair_count as f64;
assert!(
low_info_hardened_avg <= low_info_baseline_avg + 0.40,
"normalization low-info average drift exceeded jitter budget: baseline_avg={low_info_baseline_avg:.3} hardened_avg={low_info_hardened_avg:.3}"
);
}
// Optional signal only: do not require improvement on every run because
// noisy CI schedulers can flatten pairwise differences at low sample counts.
let _ = meaningful_improvement_seen;

View File

@ -0,0 +1,126 @@
use super::*;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use tokio::io::AsyncRead;
use tokio::time::{Duration, timeout};
struct EndlessReader {
produced: Arc<AtomicUsize>,
}
impl AsyncRead for EndlessReader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let len = buf.remaining().max(1);
let fill = vec![0xAA; len];
buf.put_slice(&fill);
self.produced.fetch_add(len, Ordering::Relaxed);
Poll::Ready(Ok(()))
}
}
#[test]
fn loop_guard_unspecified_bind_uses_interface_inventory() {
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
let resolved: SocketAddr = "192.168.44.10:443".parse().unwrap();
let interfaces = vec!["192.168.44.10".parse().unwrap()];
assert!(is_mask_target_local_listener_with_interfaces(
"mask.example",
443,
local,
Some(resolved),
&interfaces,
));
}
#[tokio::test]
async fn consume_client_data_stops_after_byte_cap_without_eof() {
let produced = Arc::new(AtomicUsize::new(0));
let reader = EndlessReader {
produced: Arc::clone(&produced),
};
let cap = 10_000usize;
consume_client_data(reader, cap).await;
let total = produced.load(Ordering::Relaxed);
assert!(
total >= cap,
"consume path must read at least up to cap before stopping"
);
assert!(
total <= cap + 8192,
"consume path must stop within one read chunk above cap"
);
}
#[test]
fn masking_beobachten_minutes_zero_fail_closes_to_minimum_ttl() {
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = 0;
let ttl = masking_beobachten_ttl(&config);
assert_eq!(ttl, std::time::Duration::from_secs(60));
}
#[test]
fn timing_normalization_zero_floor_safety_net_defaults_to_mask_timeout() {
let mut config = ProxyConfig::default();
config.censorship.mask_timing_normalization_enabled = true;
config.censorship.mask_timing_normalization_floor_ms = 0;
config.censorship.mask_timing_normalization_ceiling_ms = 0;
let budget = mask_outcome_target_budget(&config);
assert_eq!(
budget,
Duration::from_millis(0),
"zero floor/ceiling must produce zero extra normalization budget"
);
}
#[tokio::test]
async fn loop_guard_blocks_self_target_before_proxy_protocol_header_growth() {
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
timeout(Duration::from_millis(120), listener.accept())
.await
.is_ok()
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_proxy_protocol = 2;
let peer: SocketAddr = "203.0.113.251:55991".parse().unwrap();
let local_addr: SocketAddr = format!("0.0.0.0:{}", backend_addr.port()).parse().unwrap();
let beobachten = BeobachtenStore::new();
handle_bad_client(
tokio::io::empty(),
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
let accepted = accept_task.await.unwrap();
assert!(
!accepted,
"loop guard must fail closed before any recursive PROXY protocol amplification"
);
}

View File

@ -85,7 +85,10 @@ async fn aggressive_mode_shapes_backend_silent_non_eof_path() {
let legacy = capture_forwarded_len_with_mode(body_sent, false, false, false, 0).await;
let aggressive = capture_forwarded_len_with_mode(body_sent, false, true, false, 0).await;
assert!(legacy < floor, "legacy mode should keep timeout path unshaped");
assert!(
legacy < floor,
"legacy mode should keep timeout path unshaped"
);
assert!(
aggressive >= floor,
"aggressive mode must shape backend-silent non-EOF paths (aggressive={aggressive}, floor={floor})"

View File

@ -0,0 +1,16 @@
use super::*;
#[test]
fn detect_client_type_recognizes_extended_http_probe_verbs() {
assert_eq!(detect_client_type(b"CONNECT / HTTP/1.1\r\n"), "HTTP");
assert_eq!(detect_client_type(b"TRACE / HTTP/1.1\r\n"), "HTTP");
assert_eq!(detect_client_type(b"PATCH / HTTP/1.1\r\n"), "HTTP");
}
#[test]
fn detect_client_type_recognizes_fragmented_http_method_prefixes() {
assert_eq!(detect_client_type(b"CO"), "HTTP");
assert_eq!(detect_client_type(b"CON"), "HTTP");
assert_eq!(detect_client_type(b"TR"), "HTTP");
assert_eq!(detect_client_type(b"PAT"), "HTTP");
}

View File

@ -0,0 +1,126 @@
use super::*;
use crate::network::dns_overrides::install_entries;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant, timeout};
async fn run_connect_failure_case(
host: &str,
port: u16,
timing_normalization_enabled: bool,
peer: SocketAddr,
) -> Duration {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some(host.to_string());
config.censorship.mask_port = port;
config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled;
config.censorship.mask_timing_normalization_floor_ms = 120;
config.censorship.mask_timing_normalization_ceiling_ms = 120;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let probe = b"CONNECT example.org:443 HTTP/1.1\r\nHost: example.org\r\n\r\n";
let (mut client_writer, client_reader) = duplex(1024);
let (mut client_visible_reader, client_visible_writer) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
client_reader,
client_visible_writer,
probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
client_writer.shutdown().await.unwrap();
timeout(Duration::from_secs(4), task)
.await
.unwrap()
.unwrap();
let mut buf = [0u8; 1];
let n = timeout(Duration::from_secs(1), client_visible_reader.read(&mut buf))
.await
.unwrap()
.unwrap();
assert_eq!(
n, 0,
"connect-failure path must close client-visible writer"
);
started.elapsed()
}
#[tokio::test]
async fn connect_failure_refusal_close_behavior_matrix() {
let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() {
let peer: SocketAddr = format!("203.0.113.210:{}", 54100 + idx as u16)
.parse()
.unwrap();
let elapsed =
run_connect_failure_case("127.0.0.1", unused_port, timing_normalization_enabled, peer)
.await;
if timing_normalization_enabled {
assert!(
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250),
"normalized refusal path must honor configured timing envelope without stalling"
);
} else {
assert!(
elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150),
"non-normalized refusal path must honor baseline connect budget without stalling"
);
}
}
}
#[tokio::test]
async fn connect_failure_overridden_hostname_close_behavior_matrix() {
let temp_listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let unused_port = temp_listener.local_addr().unwrap().port();
drop(temp_listener);
// Make hostname resolution deterministic in tests so timing ceilings are meaningful.
install_entries(&[format!("mask.invalid:{}:127.0.0.1", unused_port)]).unwrap();
for (idx, timing_normalization_enabled) in [false, true].into_iter().enumerate() {
let peer: SocketAddr = format!("203.0.113.220:{}", 54200 + idx as u16)
.parse()
.unwrap();
let elapsed = run_connect_failure_case(
"mask.invalid",
unused_port,
timing_normalization_enabled,
peer,
)
.await;
if timing_normalization_enabled {
assert!(
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(250),
"normalized overridden-host path must honor configured timing envelope without stalling"
);
} else {
assert!(
elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(150),
"non-normalized overridden-host path must honor baseline connect budget without stalling"
);
}
}
install_entries(&[]).unwrap();
}

View File

@ -0,0 +1,88 @@
use super::*;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use tokio::io::{AsyncRead, ReadBuf};
struct OneByteThenStall {
sent: bool,
}
impl AsyncRead for OneByteThenStall {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if !self.sent {
self.sent = true;
buf.put_slice(&[0x42]);
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}
#[tokio::test]
async fn stalling_client_terminates_at_idle_not_relay_timeout() {
let reader = OneByteThenStall { sent: false };
let started = Instant::now();
let result = tokio::time::timeout(
MASK_RELAY_TIMEOUT,
consume_client_data(reader, MASK_BUFFER_SIZE * 4),
)
.await;
assert!(
result.is_ok(),
"consume_client_data should complete by per-read idle timeout, not hit relay timeout"
);
let elapsed = started.elapsed();
assert!(
elapsed >= (MASK_RELAY_IDLE_TIMEOUT / 2),
"consume_client_data returned too quickly for idle-timeout path: {elapsed:?}"
);
assert!(
elapsed < MASK_RELAY_TIMEOUT,
"consume_client_data waited full relay timeout ({elapsed:?}); \
per-read idle timeout is missing"
);
}
#[tokio::test]
async fn fast_reader_drains_to_eof() {
let data = vec![0xAAu8; 32 * 1024];
let reader = std::io::Cursor::new(data);
tokio::time::timeout(MASK_RELAY_TIMEOUT, consume_client_data(reader, usize::MAX))
.await
.expect("consume_client_data did not complete for fast EOF reader");
}
#[tokio::test]
async fn io_error_terminates_cleanly() {
struct ErrReader;
impl AsyncRead for ErrReader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::ConnectionReset,
"simulated reset",
)))
}
}
tokio::time::timeout(
MASK_RELAY_TIMEOUT,
consume_client_data(ErrReader, usize::MAX),
)
.await
.expect("consume_client_data did not return on I/O error");
}

View File

@ -0,0 +1,64 @@
use super::*;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use tokio::io::{AsyncRead, ReadBuf};
use tokio::task::JoinSet;
struct OneByteThenStall {
sent: bool,
}
impl AsyncRead for OneByteThenStall {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if !self.sent {
self.sent = true;
buf.put_slice(&[0xAA]);
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}
}
#[tokio::test]
async fn consume_stall_stress_finishes_within_idle_budget() {
let mut set = JoinSet::new();
let started = Instant::now();
for _ in 0..64 {
set.spawn(async {
tokio::time::timeout(
MASK_RELAY_TIMEOUT,
consume_client_data(OneByteThenStall { sent: false }, usize::MAX),
)
.await
.expect("consume_client_data exceeded relay timeout under stall load");
});
}
while let Some(res) = set.join_next().await {
res.unwrap();
}
// Under test constants idle=100ms, relay=200ms. 64 concurrent tasks stalling
// for 100ms should complete well under a strict 600ms boundary.
assert!(
started.elapsed() < MASK_RELAY_TIMEOUT * 3,
"stall stress batch completed too slowly; possible async executor starvation or head-of-line blocking"
);
}
#[tokio::test]
async fn consume_zero_cap_returns_immediately() {
let started = Instant::now();
consume_client_data(tokio::io::empty(), 0).await;
assert!(
started.elapsed() < MASK_RELAY_IDLE_TIMEOUT,
"zero byte cap must return immediately"
);
}

View File

@ -0,0 +1,225 @@
use super::*;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant, timeout};
fn make_self_target_config(
timing_normalization_enabled: bool,
floor_ms: u64,
ceiling_ms: u64,
beobachten_enabled: bool,
) -> ProxyConfig {
let mut config = ProxyConfig::default();
config.general.beobachten = beobachten_enabled;
config.general.beobachten_minutes = 5;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 443;
config.censorship.mask_timing_normalization_enabled = timing_normalization_enabled;
config.censorship.mask_timing_normalization_floor_ms = floor_ms;
config.censorship.mask_timing_normalization_ceiling_ms = ceiling_ms;
config
}
async fn run_self_target_refusal(
config: ProxyConfig,
peer: SocketAddr,
initial: &'static [u8],
) -> Duration {
let beobachten = BeobachtenStore::new();
let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr");
let (mut client, server) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
server,
tokio::io::sink(),
initial,
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
client
.shutdown()
.await
.expect("client shutdown must succeed");
timeout(Duration::from_secs(3), task)
.await
.expect("self-target refusal must complete in bounded time")
.expect("self-target refusal task must not panic");
started.elapsed()
}
#[tokio::test]
async fn positive_self_target_refusal_honors_normalization_floor() {
let config = make_self_target_config(true, 120, 120, false);
let peer: SocketAddr = "203.0.113.41:54041".parse().expect("valid peer");
let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await;
assert!(
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(260),
"normalized self-target refusal must stay within expected envelope"
);
}
#[tokio::test]
async fn negative_non_normalized_refusal_does_not_sleep_to_large_floor() {
let config = make_self_target_config(false, 240, 240, false);
let peer: SocketAddr = "203.0.113.42:54042".parse().expect("valid peer");
let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await;
assert!(
elapsed < Duration::from_millis(180),
"non-normalized path must not inherit normalization floor delays"
);
}
#[tokio::test]
async fn edge_ceiling_below_floor_uses_floor_fail_closed() {
let config = make_self_target_config(true, 140, 80, false);
let peer: SocketAddr = "203.0.113.43:54043".parse().expect("valid peer");
let elapsed = run_self_target_refusal(config, peer, b"GET / HTTP/1.1\r\n\r\n").await;
assert!(
elapsed >= Duration::from_millis(130) && elapsed < Duration::from_millis(280),
"ceiling<floor must clamp to floor to preserve deterministic normalization"
);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_blackhat_parallel_probes_remain_bounded_and_uniform() {
let workers = 24usize;
let mut tasks = Vec::with_capacity(workers);
for idx in 0..workers {
tasks.push(tokio::spawn(async move {
let cfg = make_self_target_config(true, 110, 140, false);
let peer: SocketAddr = format!("203.0.113.50:{}", 54100 + idx as u16)
.parse()
.expect("valid peer");
run_self_target_refusal(cfg, peer, b"GET /x HTTP/1.1\r\n\r\n").await
}));
}
let mut min = Duration::from_secs(60);
let mut max = Duration::from_millis(0);
for task in tasks {
let elapsed = task.await.expect("probe task must not panic");
if elapsed < min {
min = elapsed;
}
if elapsed > max {
max = elapsed;
}
assert!(
elapsed >= Duration::from_millis(100) && elapsed < Duration::from_millis(320),
"parallel probe latency must stay bounded under normalization"
);
}
assert!(
max.saturating_sub(min) <= Duration::from_millis(130),
"normalization should limit path variance across adversarial parallel probes"
);
}
#[tokio::test]
async fn integration_beobachten_records_probe_classification_on_refusal() {
let config = make_self_target_config(false, 0, 0, true);
let peer: SocketAddr = "198.51.100.71:55071".parse().expect("valid peer");
let local_addr: SocketAddr = "127.0.0.1:443".parse().expect("valid local addr");
let beobachten = BeobachtenStore::new();
let (mut client, server) = duplex(1024);
let task = tokio::spawn(async move {
handle_bad_client(
server,
tokio::io::sink(),
b"GET /classified HTTP/1.1\r\nHost: demo\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
beobachten.snapshot_text(Duration::from_secs(60))
});
client
.shutdown()
.await
.expect("client shutdown must succeed");
let snapshot = timeout(Duration::from_secs(3), task)
.await
.expect("integration task must complete")
.expect("integration task must not panic");
assert!(snapshot.contains("[HTTP]"));
assert!(snapshot.contains("198.51.100.71-1"));
}
#[tokio::test]
async fn light_fuzz_timing_configuration_matrix_is_bounded() {
let mut seed = 0xA17E_55AA_2026_0323u64;
for case in 0..48u64 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let enabled = (seed & 1) == 0;
let floor = (seed >> 8) % 180;
let ceiling = (seed >> 24) % 180;
let config = make_self_target_config(enabled, floor, ceiling, false);
let peer: SocketAddr = format!("203.0.113.90:{}", 56000 + (case as u16))
.parse()
.expect("valid peer");
let elapsed = run_self_target_refusal(config, peer, b"HEAD /h HTTP/1.1\r\n\r\n").await;
assert!(
elapsed < Duration::from_millis(420),
"fuzz case must stay bounded and never hang"
);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_high_fanout_self_target_refusal_no_deadlock_or_timeout() {
let workers = 64usize;
let mut tasks = Vec::with_capacity(workers);
for idx in 0..workers {
tasks.push(tokio::spawn(async move {
let config = make_self_target_config(false, 0, 0, false);
let peer: SocketAddr = format!("198.51.100.200:{}", 57000 + idx as u16)
.parse()
.expect("valid peer");
run_self_target_refusal(config, peer, b"GET /stress HTTP/1.1\r\n\r\n").await
}));
}
timeout(Duration::from_secs(5), async {
for task in tasks {
let elapsed = task.await.expect("stress task must not panic");
assert!(
elapsed < Duration::from_millis(260),
"stress refusal must remain bounded without normalization"
);
}
})
.await
.expect("high-fanout refusal workload must complete without deadlock");
}

View File

@ -0,0 +1,55 @@
use super::*;
use tokio::net::TcpListener;
use tokio::time::Duration;
#[tokio::test]
async fn http2_preface_is_forwarded_and_recorded_as_http() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".to_vec();
let accept_task = tokio::spawn({
let preface = preface.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut received = vec![0u8; preface.len()];
stream.read_exact(&mut received).await.unwrap();
assert_eq!(received, preface);
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = true;
config.general.beobachten_minutes = 1;
config.censorship.mask = true;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_unix_sock = None;
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "198.51.100.130:54130".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let (client_reader, _client_writer) = tokio::io::duplex(512);
let (_client_visible_reader, client_visible_writer) = tokio::io::duplex(512);
let beobachten = BeobachtenStore::new();
handle_bad_client(
client_reader,
client_visible_writer,
&preface,
peer,
local_addr,
&config,
&beobachten,
)
.await;
tokio::time::timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
let snapshot = beobachten.snapshot_text(Duration::from_secs(60));
assert!(snapshot.contains("[HTTP]"));
assert!(snapshot.contains("198.51.100.130-1"));
}

View File

@ -0,0 +1,92 @@
use super::*;
#[test]
fn full_http2_preface_classified_as_http_probe() {
let preface = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
assert!(
is_http_probe(preface),
"HTTP/2 connection preface must be classified as HTTP probe"
);
}
#[test]
fn partial_http2_preface_3_bytes_classified() {
assert!(
is_http_probe(b"PRI"),
"3-byte HTTP/2 preface prefix must be classified"
);
}
#[test]
fn partial_http2_preface_2_bytes_classified() {
assert!(
is_http_probe(b"PR"),
"2-byte HTTP/2 preface prefix must be classified"
);
}
#[test]
fn existing_http1_methods_unaffected() {
for prefix in [
b"GET / HTTP/1.1\r\n".as_ref(),
b"POST /api HTTP/1.1\r\n".as_ref(),
b"CONNECT example.com:443 HTTP/1.1\r\n".as_ref(),
b"TRACE / HTTP/1.1\r\n".as_ref(),
b"PATCH / HTTP/1.1\r\n".as_ref(),
] {
assert!(is_http_probe(prefix));
}
}
#[test]
fn non_http_data_not_classified() {
for data in [
b"\x16\x03\x01\x00\xf1".as_ref(),
b"SSH-2.0-OpenSSH_8.9\r\n".as_ref(),
b"\x00\x01\x02\x03".as_ref(),
b"".as_ref(),
b"P".as_ref(),
] {
assert!(!is_http_probe(data));
}
}
#[test]
fn light_fuzz_non_http_prefixes_not_misclassified() {
// Deterministic pseudo-fuzz to exercise classifier edges while avoiding
// known HTTP method and partial windows.
let mut x = 0x1234_5678u32;
for _ in 0..1024 {
x = x.wrapping_mul(1664525).wrapping_add(1013904223);
let len = 4 + ((x >> 8) as usize % 12);
let mut data = vec![0u8; len];
for byte in &mut data {
x = x.wrapping_mul(1664525).wrapping_add(1013904223);
*byte = (x & 0xFF) as u8;
}
if [
b"GET ".as_ref(),
b"POST".as_ref(),
b"HEAD".as_ref(),
b"PUT ".as_ref(),
b"DELETE".as_ref(),
b"OPTIONS".as_ref(),
b"CONNECT".as_ref(),
b"TRACE".as_ref(),
b"PATCH".as_ref(),
b"PRI ".as_ref(),
]
.iter()
.any(|m| data.starts_with(m))
{
continue;
}
assert!(
!is_http_probe(&data),
"non-http pseudo-fuzz input misclassified: {:?}",
&data[..data.len().min(8)]
);
}
}

View File

@ -0,0 +1,85 @@
use super::*;
#[test]
fn exact_four_byte_http_tokens_are_classified() {
for token in [
b"GET ".as_ref(),
b"POST".as_ref(),
b"HEAD".as_ref(),
b"PUT ".as_ref(),
b"PRI ".as_ref(),
] {
assert!(
is_http_probe(token),
"exact 4-byte token must be classified as HTTP probe: {:?}",
token
);
}
}
#[test]
fn exact_four_byte_non_http_tokens_are_not_classified() {
for token in [
b"GEX ".as_ref(),
b"POXT".as_ref(),
b"HEA/".as_ref(),
b"PU\0 ".as_ref(),
b"PRI/".as_ref(),
] {
assert!(
!is_http_probe(token),
"non-HTTP 4-byte token must not be classified: {:?}",
token
);
}
}
#[test]
fn detect_client_type_keeps_http_label_for_minimal_four_byte_http_prefixes() {
assert_eq!(detect_client_type(b"GET "), "HTTP");
assert_eq!(detect_client_type(b"PRI "), "HTTP");
}
#[test]
fn exact_long_http_tokens_are_classified() {
for token in [b"CONNECT".as_ref(), b"TRACE".as_ref(), b"PATCH".as_ref()] {
assert!(
is_http_probe(token),
"exact long HTTP token must be classified as HTTP probe: {:?}",
token
);
}
}
#[test]
fn detect_client_type_keeps_http_label_for_exact_long_http_tokens() {
assert_eq!(detect_client_type(b"CONNECT"), "HTTP");
assert_eq!(detect_client_type(b"TRACE"), "HTTP");
assert_eq!(detect_client_type(b"PATCH"), "HTTP");
}
#[test]
fn light_fuzz_four_byte_ascii_noise_not_misclassified() {
// Deterministic pseudo-fuzz over 4-byte printable ASCII inputs.
let mut x = 0xA17C_93E5u32;
for _ in 0..2048 {
let mut token = [0u8; 4];
for byte in &mut token {
x = x.wrapping_mul(1664525).wrapping_add(1013904223);
*byte = 32 + ((x & 0x3F) as u8); // printable ASCII subset
}
if [b"GET ", b"POST", b"HEAD", b"PUT ", b"PRI "]
.iter()
.any(|m| token.as_slice() == *m)
{
continue;
}
assert!(
!is_http_probe(&token),
"pseudo-fuzz noise misclassified as HTTP probe: {:?}",
token
);
}
}

View File

@ -0,0 +1,41 @@
#![cfg(unix)]
use super::*;
use std::sync::{Mutex, OnceLock};
use tokio::sync::Barrier;
fn interface_cache_test_lock() -> &'static Mutex<()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_parallel_cold_miss_performs_single_interface_refresh() {
let _guard = interface_cache_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
reset_local_interface_enumerations_for_tests();
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let workers = 32usize;
let barrier = std::sync::Arc::new(Barrier::new(workers));
let mut tasks = Vec::with_capacity(workers);
for _ in 0..workers {
let barrier = std::sync::Arc::clone(&barrier);
tasks.push(tokio::spawn(async move {
barrier.wait().await;
is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await
}));
}
for task in tasks {
let _ = task.await.expect("parallel cache task must not panic");
}
assert_eq!(
local_interface_enumerations_for_tests(),
1,
"parallel cold misses must coalesce into a single interface enumeration"
);
}

View File

@ -0,0 +1,51 @@
#![cfg(unix)]
use super::*;
#[test]
fn defense_in_depth_empty_refresh_preserves_previous_non_empty_interfaces() {
let previous = vec![
"192.168.100.7"
.parse::<IpAddr>()
.expect("must parse interface ip"),
];
let refreshed = Vec::new();
let next = choose_interface_snapshot(&previous, refreshed);
assert_eq!(
next, previous,
"empty refresh should preserve previous non-empty snapshot to avoid fail-open loop-guard regressions"
);
}
#[test]
fn defense_in_depth_non_empty_refresh_replaces_previous_snapshot() {
let previous = vec![
"192.168.100.7"
.parse::<IpAddr>()
.expect("must parse interface ip"),
];
let refreshed = vec![
"10.55.0.3"
.parse::<IpAddr>()
.expect("must parse refreshed interface ip"),
];
let next = choose_interface_snapshot(&previous, refreshed.clone());
assert_eq!(next, refreshed);
}
#[test]
fn defense_in_depth_empty_refresh_keeps_empty_when_no_previous_snapshot_exists() {
let previous = Vec::new();
let refreshed = Vec::new();
let next = choose_interface_snapshot(&previous, refreshed);
assert!(
next.is_empty(),
"empty refresh with no previous snapshot should remain empty"
);
}

View File

@ -0,0 +1,49 @@
#![cfg(unix)]
use super::*;
use std::sync::{Mutex, OnceLock};
fn interface_cache_test_lock() -> &'static Mutex<()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
}
#[tokio::test]
async fn tdd_repeated_local_listener_checks_do_not_repeat_interface_enumeration_within_window() {
let _guard = interface_cache_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
reset_local_interface_enumerations_for_tests();
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await;
let _ = is_mask_target_local_listener_async("127.0.0.1", 443, local_addr, None).await;
assert_eq!(
local_interface_enumerations_for_tests(),
1,
"interface enumeration must be cached across repeated bad-client checks"
);
}
#[tokio::test]
async fn tdd_non_local_port_short_circuit_does_not_enumerate_interfaces() {
let _guard = interface_cache_test_lock()
.lock()
.unwrap_or_else(|poison| poison.into_inner());
reset_local_interface_enumerations_for_tests();
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let is_local = is_mask_target_local_listener_async("127.0.0.1", 8443, local_addr, None).await;
assert!(
!is_local,
"different port must not be treated as local listener"
);
assert_eq!(
local_interface_enumerations_for_tests(),
0,
"port mismatch should bypass interface enumeration entirely"
);
}

View File

@ -0,0 +1,178 @@
use super::*;
use std::net::{SocketAddr, TcpListener as StdTcpListener};
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant};
fn closed_local_port() -> u16 {
let listener = StdTcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
drop(listener);
port
}
#[tokio::test]
#[ignore = "red-team expected-fail: offline mask target keeps bad-client socket alive before consume timeout boundary"]
async fn redteam_offline_target_should_drop_idle_client_early() {
let (client_read, mut client_write) = duplex(1024);
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = closed_local_port();
cfg.censorship.mask_timing_normalization_enabled = false;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer_addr: SocketAddr = "192.0.2.50:5000".parse().unwrap();
let beobachten = BeobachtenStore::new();
let handler = tokio::spawn(async move {
handle_bad_client(
client_read,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer_addr,
local_addr,
&cfg,
&beobachten,
)
.await;
});
tokio::time::sleep(Duration::from_millis(150)).await;
let write_res = client_write.write_all(b"probe-should-be-closed").await;
assert!(
write_res.is_err(),
"offline target path still keeps client writable before consume timeout"
);
handler.abort();
}
#[tokio::test]
#[ignore = "red-team expected-fail: proxy should mimic immediate RST-like close when target is offline"]
async fn redteam_offline_target_should_not_sleep_to_mask_refusal() {
let (client_read, mut client_write) = duplex(1024);
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = closed_local_port();
cfg.censorship.mask_timing_normalization_enabled = false;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer_addr: SocketAddr = "192.0.2.51:5000".parse().unwrap();
let beobachten = BeobachtenStore::new();
let started = Instant::now();
let handler = tokio::spawn(async move {
handle_bad_client(
client_read,
tokio::io::sink(),
b"\x16\x03\x01\x00\x05hello",
peer_addr,
local_addr,
&cfg,
&beobachten,
)
.await;
});
client_write.shutdown().await.unwrap();
let _ = handler.await;
let elapsed = started.elapsed();
assert!(
elapsed < Duration::from_millis(10),
"offline target path still applies coarse masking sleep and is fingerprintable"
);
}
#[tokio::test]
#[ignore = "red-team expected-fail: refusal path should remain below strict latency envelope under burst"]
async fn redteam_offline_refusal_burst_timing_spread_should_be_tight() {
let mut samples = Vec::new();
for i in 0..12u16 {
let (client_read, mut client_write) = duplex(1024);
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("127.0.0.1".to_string());
cfg.censorship.mask_port = closed_local_port();
cfg.censorship.mask_timing_normalization_enabled = false;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer_addr: SocketAddr = format!("192.0.2.52:{}", 5100 + i).parse().unwrap();
let beobachten = BeobachtenStore::new();
let started = Instant::now();
let handler = tokio::spawn(async move {
handle_bad_client(
client_read,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer_addr,
local_addr,
&cfg,
&beobachten,
)
.await;
});
client_write.shutdown().await.unwrap();
let _ = handler.await;
samples.push(started.elapsed());
}
let min = samples.iter().copied().min().unwrap_or_default();
let max = samples.iter().copied().max().unwrap_or_default();
let spread = max.saturating_sub(min);
assert!(
spread <= Duration::from_millis(5),
"offline refusal timing spread too wide for strict red-team envelope: {:?}",
spread
);
}
#[tokio::test]
#[ignore = "manual red-team: host resolver failure should complete without panic"]
async fn redteam_dns_resolution_failure_must_not_panic() {
let (client_read, mut client_write) = duplex(1024);
let mut cfg = ProxyConfig::default();
cfg.general.beobachten = false;
cfg.censorship.mask = true;
cfg.censorship.mask_unix_sock = None;
cfg.censorship.mask_host = Some("this.domain.definitely.does.not.exist.invalid".to_string());
cfg.censorship.mask_port = 443;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer_addr: SocketAddr = "192.0.2.99:5999".parse().unwrap();
let beobachten = BeobachtenStore::new();
let handler = tokio::spawn(async move {
handle_bad_client(
client_read,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer_addr,
local_addr,
&cfg,
&beobachten,
)
.await;
});
client_write.shutdown().await.unwrap();
let result = tokio::time::timeout(Duration::from_secs(2), handler).await;
assert!(
result.is_ok(),
"dns failure path stalled or panicked instead of terminating"
);
}

View File

@ -0,0 +1,51 @@
use super::*;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;
use tokio::io::AsyncWrite;
struct NeverWritable;
impl AsyncWrite for NeverWritable {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Poll::Pending
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Pending
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
}
#[tokio::test]
async fn shape_padding_returns_before_global_mask_timeout_on_blocked_writer() {
let mut writer = NeverWritable;
let started = Instant::now();
maybe_write_shape_padding(&mut writer, 1, true, 256, 4096, false, 0, false).await;
assert!(
started.elapsed() <= MASK_TIMEOUT + std::time::Duration::from_millis(30),
"shape padding blocked past timeout budget"
);
}
#[tokio::test]
async fn shape_padding_with_non_http_blur_disabled_at_cap_writes_nothing() {
let mut output = Vec::new();
{
let mut writer = tokio::io::BufWriter::new(&mut output);
maybe_write_shape_padding(&mut writer, 4096, true, 64, 4096, false, 128, false).await;
use tokio::io::AsyncWriteExt;
writer.flush().await.unwrap();
}
assert!(output.is_empty());
}

View File

@ -0,0 +1,283 @@
use super::*;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::time::{Duration, Instant, timeout};
const PROD_CAP_BYTES: usize = 5 * 1024 * 1024;
struct FinitePatternReader {
remaining: usize,
chunk: usize,
read_calls: Arc<AtomicUsize>,
}
impl FinitePatternReader {
fn new(total: usize, chunk: usize, read_calls: Arc<AtomicUsize>) -> Self {
Self {
remaining: total,
chunk,
read_calls,
}
}
}
impl AsyncRead for FinitePatternReader {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
self.read_calls.fetch_add(1, Ordering::Relaxed);
if self.remaining == 0 {
return Poll::Ready(Ok(()));
}
let take = self.remaining.min(self.chunk).min(buf.remaining());
if take == 0 {
return Poll::Ready(Ok(()));
}
let fill = vec![0x5Au8; take];
buf.put_slice(&fill);
self.remaining -= take;
Poll::Ready(Ok(()))
}
}
#[derive(Default)]
struct CountingWriter {
written: usize,
}
impl AsyncWrite for CountingWriter {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.written = self.written.saturating_add(buf.len());
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Poll::Ready(Ok(()))
}
}
struct NeverReadyReader;
impl AsyncRead for NeverReadyReader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Poll::Pending
}
}
struct BudgetProbeReader {
remaining: usize,
total_read: Arc<AtomicUsize>,
}
impl BudgetProbeReader {
fn new(total: usize, total_read: Arc<AtomicUsize>) -> Self {
Self {
remaining: total,
total_read,
}
}
}
impl AsyncRead for BudgetProbeReader {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
if self.remaining == 0 {
return Poll::Ready(Ok(()));
}
let take = self.remaining.min(buf.remaining());
if take == 0 {
return Poll::Ready(Ok(()));
}
let fill = vec![0xA5u8; take];
buf.put_slice(&fill);
self.remaining -= take;
self.total_read.fetch_add(take, Ordering::Relaxed);
Poll::Ready(Ok(()))
}
}
#[tokio::test]
async fn positive_copy_with_production_cap_stops_exactly_at_budget() {
let read_calls = Arc::new(AtomicUsize::new(0));
let mut reader = FinitePatternReader::new(PROD_CAP_BYTES + (256 * 1024), 4096, read_calls);
let mut writer = CountingWriter::default();
let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await;
assert_eq!(
outcome.total, PROD_CAP_BYTES,
"copy path must stop at explicit production cap"
);
assert_eq!(writer.written, PROD_CAP_BYTES);
assert!(
!outcome.ended_by_eof,
"byte-cap stop must not be misclassified as EOF"
);
}
#[tokio::test]
async fn negative_consume_with_zero_cap_performs_no_reads() {
let read_calls = Arc::new(AtomicUsize::new(0));
let reader = FinitePatternReader::new(1024, 64, Arc::clone(&read_calls));
consume_client_data_with_timeout_and_cap(reader, 0).await;
assert_eq!(
read_calls.load(Ordering::Relaxed),
0,
"zero cap must return before reading attacker-controlled bytes"
);
}
#[tokio::test]
async fn edge_copy_below_cap_reports_eof_without_overread() {
let read_calls = Arc::new(AtomicUsize::new(0));
let payload = 73 * 1024;
let mut reader = FinitePatternReader::new(payload, 3072, read_calls);
let mut writer = CountingWriter::default();
let outcome = copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await;
assert_eq!(outcome.total, payload);
assert_eq!(writer.written, payload);
assert!(
outcome.ended_by_eof,
"finite upstream below cap must terminate via EOF path"
);
}
#[tokio::test]
async fn adversarial_blackhat_never_ready_reader_is_bounded_by_timeout_guards() {
let started = Instant::now();
consume_client_data_with_timeout_and_cap(NeverReadyReader, PROD_CAP_BYTES).await;
assert!(
started.elapsed() < Duration::from_millis(350),
"never-ready reader must be bounded by idle/relay timeout protections"
);
}
#[tokio::test]
async fn integration_consume_path_honors_production_cap_for_large_payload() {
let read_calls = Arc::new(AtomicUsize::new(0));
let reader = FinitePatternReader::new(PROD_CAP_BYTES + (1024 * 1024), 8192, read_calls);
let bounded = timeout(
Duration::from_millis(350),
consume_client_data_with_timeout_and_cap(reader, PROD_CAP_BYTES),
)
.await;
assert!(
bounded.is_ok(),
"consume path with production cap must finish within bounded time"
);
}
#[tokio::test]
async fn adversarial_consume_path_never_reads_beyond_declared_byte_cap() {
let byte_cap = 5usize;
let total_read = Arc::new(AtomicUsize::new(0));
let reader = BudgetProbeReader::new(256 * 1024, Arc::clone(&total_read));
consume_client_data_with_timeout_and_cap(reader, byte_cap).await;
assert!(
total_read.load(Ordering::Relaxed) <= byte_cap,
"consume path must not read more than configured byte cap"
);
}
#[tokio::test]
async fn light_fuzz_cap_and_payload_matrix_preserves_min_budget_invariant() {
let mut seed = 0x1234_5678_9ABC_DEF0u64;
for _case in 0..96u32 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let cap = ((seed & 0x3ffff) as usize).saturating_add(1);
let payload = ((seed.rotate_left(11) & 0x7ffff) as usize).saturating_add(1);
let chunk = (((seed >> 5) & 0x1fff) as usize).saturating_add(1);
let read_calls = Arc::new(AtomicUsize::new(0));
let mut reader = FinitePatternReader::new(payload, chunk, read_calls);
let mut writer = CountingWriter::default();
let outcome = copy_with_idle_timeout(&mut reader, &mut writer, cap, true).await;
let expected = payload.min(cap);
assert_eq!(
outcome.total, expected,
"copy total must match min(payload, cap) under fuzzed inputs"
);
assert_eq!(writer.written, expected);
if payload <= cap {
assert!(outcome.ended_by_eof);
} else {
assert!(!outcome.ended_by_eof);
}
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_copy_tasks_with_production_cap_complete_without_leaks() {
let workers = 8usize;
let mut tasks = Vec::with_capacity(workers);
for idx in 0..workers {
tasks.push(tokio::spawn(async move {
let read_calls = Arc::new(AtomicUsize::new(0));
let mut reader = FinitePatternReader::new(
PROD_CAP_BYTES + (idx + 1) * 4096,
4096 + (idx * 257),
read_calls,
);
let mut writer = CountingWriter::default();
copy_with_idle_timeout(&mut reader, &mut writer, PROD_CAP_BYTES, true).await
}));
}
timeout(Duration::from_secs(3), async {
for task in tasks {
let outcome = task.await.expect("stress task must not panic");
assert_eq!(
outcome.total, PROD_CAP_BYTES,
"stress copy task must stay within production cap"
);
assert!(
!outcome.ended_by_eof,
"stress task should end due to cap, not EOF"
);
}
})
.await
.expect("stress suite must complete in bounded time");
}

View File

@ -0,0 +1,105 @@
use super::*;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex, sink};
use tokio::time::{Duration, timeout};
#[tokio::test]
async fn relay_to_mask_enforces_masking_session_byte_cap() {
let initial = vec![0x16, 0x03, 0x01, 0x00, 0x01];
let extra = vec![0xAB; 96 * 1024];
let (client_reader, mut client_writer) = duplex(128 * 1024);
let (mask_read, _mask_read_peer) = duplex(1024);
let (mut mask_observer, mask_write) = duplex(256 * 1024);
let initial_for_task = initial.clone();
let relay = tokio::spawn(async move {
relay_to_mask(
client_reader,
sink(),
mask_read,
mask_write,
&initial_for_task,
false,
512,
4096,
false,
0,
false,
32 * 1024,
)
.await;
});
client_writer.write_all(&extra).await.unwrap();
client_writer.shutdown().await.unwrap();
timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
let mut observed = Vec::new();
timeout(
Duration::from_secs(2),
mask_observer.read_to_end(&mut observed),
)
.await
.unwrap()
.unwrap();
// In this deterministic test, relay must stop exactly at the configured cap.
assert_eq!(
observed.len(),
initial.len() + (32 * 1024),
"masked relay must forward exactly up to the cap (observed={} initial={} cap={})",
observed.len(),
initial.len(),
32 * 1024
);
}
#[tokio::test]
async fn relay_to_mask_propagates_client_half_close_without_waiting_for_other_direction_timeout() {
let initial = b"GET /half-close HTTP/1.1\r\n".to_vec();
let (client_reader, mut client_writer) = duplex(8 * 1024);
let (mask_read, _mask_read_peer) = duplex(8 * 1024);
let (mut mask_observer, mask_write) = duplex(8 * 1024);
let initial_for_task = initial.clone();
let relay = tokio::spawn(async move {
relay_to_mask(
client_reader,
sink(),
mask_read,
mask_write,
&initial_for_task,
false,
512,
4096,
false,
0,
false,
32 * 1024,
)
.await;
});
client_writer.shutdown().await.unwrap();
let mut observed = Vec::new();
timeout(
Duration::from_millis(80),
mask_observer.read_to_end(&mut observed),
)
.await
.expect("mask backend write side should be half-closed promptly")
.unwrap();
assert_eq!(&observed[..initial.len()], initial.as_slice());
timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
}

View File

@ -0,0 +1,100 @@
use super::*;
use tokio::io::AsyncReadExt;
use tokio::time::{Duration, timeout};
async fn collect_padding(
total_sent: usize,
enabled: bool,
floor: usize,
cap: usize,
above_cap_blur: bool,
blur_max: usize,
aggressive: bool,
) -> Vec<u8> {
let (mut tx, mut rx) = tokio::io::duplex(256 * 1024);
maybe_write_shape_padding(
&mut tx,
total_sent,
enabled,
floor,
cap,
above_cap_blur,
blur_max,
aggressive,
)
.await;
drop(tx);
let mut output = Vec::new();
timeout(Duration::from_secs(1), rx.read_to_end(&mut output))
.await
.expect("reading padded output timed out")
.expect("failed reading padded output");
output
}
#[tokio::test]
async fn padding_output_is_not_all_zero() {
let output = collect_padding(1, true, 256, 4096, false, 0, false).await;
assert!(
output.len() >= 255,
"expected at least 255 padding bytes, got {}",
output.len()
);
let nonzero = output.iter().filter(|&&b| b != 0).count();
// In 255 bytes of uniform randomness, the expected number of zero bytes is ~1.
// A weak nonzero check can miss severe entropy collapse.
assert!(
nonzero >= 240,
"RNG output entropy collapsed, too many zero bytes: {} nonzero out of {}",
nonzero,
output.len(),
);
}
#[tokio::test]
async fn padding_reaches_first_bucket_boundary() {
let output = collect_padding(1, true, 64, 4096, false, 0, false).await;
assert_eq!(output.len(), 63);
}
#[tokio::test]
async fn disabled_padding_produces_no_output() {
let output = collect_padding(0, false, 256, 4096, false, 0, false).await;
assert!(output.is_empty());
}
#[tokio::test]
async fn at_cap_without_blur_produces_no_output() {
let output = collect_padding(4096, true, 64, 4096, false, 0, false).await;
assert!(output.is_empty());
}
#[tokio::test]
async fn above_cap_blur_is_positive_and_bounded_in_aggressive_mode() {
let output = collect_padding(4096, true, 64, 4096, true, 128, true).await;
assert!(!output.is_empty());
assert!(output.len() <= 128, "blur exceeded max: {}", output.len());
}
#[tokio::test]
async fn stress_padding_runs_are_not_constant_pattern() {
// Stress and sanity-check: repeated runs should not collapse to identical
// first 16 bytes across all samples.
let mut first_chunks = Vec::new();
for _ in 0..64 {
let out = collect_padding(1, true, 64, 4096, false, 0, false).await;
first_chunks.push(out[..16].to_vec());
}
let first = &first_chunks[0];
let all_same = first_chunks.iter().all(|chunk| chunk == first);
assert!(
!all_same,
"all stress samples had identical prefix, rng output appears degenerate"
);
}

View File

@ -1376,6 +1376,7 @@ async fn relay_to_mask_keeps_backend_to_client_flow_when_client_to_backend_stall
false,
0,
false,
5 * 1024 * 1024,
)
.await;
});
@ -1506,6 +1507,7 @@ async fn relay_to_mask_timeout_cancels_and_drops_all_io_endpoints() {
false,
0,
false,
5 * 1024 * 1024,
),
)
.await;

View File

@ -0,0 +1,330 @@
use super::*;
use std::net::SocketAddr;
use std::net::TcpListener as StdTcpListener;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::net::TcpListener;
use tokio::time::{Duration, Instant, timeout};
fn closed_local_port() -> u16 {
let listener = StdTcpListener::bind("127.0.0.1:0").unwrap();
let port = listener.local_addr().unwrap().port();
drop(listener);
port
}
#[tokio::test]
async fn self_target_detection_matches_literal_ipv4_listener() {
let local: SocketAddr = "198.51.100.40:443".parse().unwrap();
assert!(is_mask_target_local_listener_async("198.51.100.40", 443, local, None,).await);
}
#[tokio::test]
async fn self_target_detection_matches_bracketed_ipv6_listener() {
let local: SocketAddr = "[2001:db8::44]:8443".parse().unwrap();
assert!(is_mask_target_local_listener_async("[2001:db8::44]", 8443, local, None,).await);
}
#[tokio::test]
async fn self_target_detection_keeps_same_ip_different_port_forwardable() {
let local: SocketAddr = "203.0.113.44:443".parse().unwrap();
assert!(!is_mask_target_local_listener_async("203.0.113.44", 8443, local, None,).await);
}
#[tokio::test]
async fn self_target_detection_normalizes_ipv4_mapped_ipv6_literal() {
let local: SocketAddr = "127.0.0.1:443".parse().unwrap();
assert!(is_mask_target_local_listener_async("::ffff:127.0.0.1", 443, local, None,).await);
}
#[tokio::test]
async fn self_target_detection_unspecified_bind_blocks_loopback_target() {
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
assert!(is_mask_target_local_listener_async("127.0.0.1", 443, local, None,).await);
}
#[tokio::test]
async fn self_target_detection_unspecified_bind_keeps_remote_target_forwardable() {
let local: SocketAddr = "0.0.0.0:443".parse().unwrap();
let remote: SocketAddr = "198.51.100.44:443".parse().unwrap();
assert!(!is_mask_target_local_listener_async("mask.example", 443, local, Some(remote),).await);
}
#[tokio::test]
async fn self_target_fallback_refuses_recursive_loopback_connect() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let local_addr = listener.local_addr().unwrap();
let accept_task = tokio::spawn(async move {
timeout(Duration::from_millis(120), listener.accept())
.await
.is_ok()
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some(local_addr.ip().to_string());
config.censorship.mask_port = local_addr.port();
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.90:55090".parse().unwrap();
let beobachten = BeobachtenStore::new();
handle_bad_client(
tokio::io::empty(),
tokio::io::sink(),
b"GET /",
peer,
local_addr,
&config,
&beobachten,
)
.await;
let accepted = accept_task.await.unwrap();
assert!(
!accepted,
"self-target masking must fail closed without connecting to local listener"
);
}
#[tokio::test]
async fn same_ip_different_port_still_forwards_to_mask_backend() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let backend_addr = listener.local_addr().unwrap();
let probe = b"GET /".to_vec();
let accept_task = tokio::spawn({
let expected = probe.clone();
async move {
let (mut stream, _) = listener.accept().await.unwrap();
let mut got = vec![0u8; expected.len()];
stream.read_exact(&mut got).await.unwrap();
assert_eq!(got, expected);
}
});
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = backend_addr.port();
config.censorship.mask_proxy_protocol = 0;
let peer: SocketAddr = "203.0.113.91:55091".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
handle_bad_client(
tokio::io::empty(),
tokio::io::sink(),
&probe,
peer,
local_addr,
&config,
&beobachten,
)
.await;
timeout(Duration::from_secs(2), accept_task)
.await
.unwrap()
.unwrap();
}
#[test]
fn detect_client_type_http_boundary_get_and_post() {
assert_eq!(detect_client_type(b"GET "), "HTTP");
assert_eq!(detect_client_type(b"GET /"), "HTTP");
assert_eq!(detect_client_type(b"POST"), "HTTP");
assert_eq!(detect_client_type(b"POST "), "HTTP");
assert_eq!(detect_client_type(b"POSTX"), "HTTP");
}
#[test]
fn detect_client_type_tls_and_length_boundaries() {
assert_eq!(detect_client_type(b"\x16\x03\x01"), "port-scanner");
assert_eq!(detect_client_type(b"\x16\x03\x01\x00"), "TLS-scanner");
assert_eq!(detect_client_type(b"123456789"), "port-scanner");
assert_eq!(detect_client_type(b"1234567890"), "unknown");
}
#[test]
fn build_mask_proxy_header_v1_cross_family_falls_back_to_unknown() {
let peer: SocketAddr = "192.168.1.5:12345".parse().unwrap();
let local: SocketAddr = "[2001:db8::1]:443".parse().unwrap();
let header = build_mask_proxy_header(1, peer, local).unwrap();
assert_eq!(header, b"PROXY UNKNOWN\r\n");
}
#[test]
fn next_mask_shape_bucket_checked_mul_overflow_fails_closed() {
let floor = usize::MAX / 2 + 1;
let cap = usize::MAX;
let total = floor + 1;
assert_eq!(next_mask_shape_bucket(total, floor, cap), total);
}
#[tokio::test]
async fn self_target_reject_path_keeps_timing_budget() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 443;
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let peer: SocketAddr = "203.0.113.92:55092".parse().unwrap();
let beobachten = BeobachtenStore::new();
let (client, server) = duplex(1024);
drop(client);
let started = Instant::now();
handle_bad_client(
server,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(40) && elapsed < Duration::from_millis(250),
"self-target reject path must keep coarse timing budget without stalling"
);
}
#[tokio::test]
async fn relay_path_idle_timeout_eviction_remains_effective() {
let (client_read, mut client_write) = duplex(1024);
let (mask_read, mask_write) = duplex(1024);
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(10)).await;
client_write.write_all(b"a").await.unwrap();
tokio::time::sleep(Duration::from_millis(180)).await;
let _ = client_write.write_all(b"b").await;
});
let started = Instant::now();
relay_to_mask(
client_read,
tokio::io::sink(),
mask_read,
mask_write,
b"init",
false,
0,
0,
false,
0,
false,
5 * 1024 * 1024,
)
.await;
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(90) && elapsed < Duration::from_millis(180),
"idle-timeout eviction must occur before late trickle write"
);
}
#[tokio::test]
async fn offline_mask_target_refusal_respects_timing_normalization_budget() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = closed_local_port();
config.censorship.mask_timing_normalization_enabled = true;
config.censorship.mask_timing_normalization_floor_ms = 120;
config.censorship.mask_timing_normalization_ceiling_ms = 120;
let peer: SocketAddr = "203.0.113.93:55093".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let (mut client, server) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
server,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
client.shutdown().await.unwrap();
timeout(Duration::from_secs(2), task)
.await
.unwrap()
.unwrap();
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(110) && elapsed < Duration::from_millis(220),
"offline-refusal path must honor normalization budget without unbounded drift"
);
}
#[tokio::test]
async fn offline_mask_target_refusal_with_idle_client_is_bounded_by_consume_timeout() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = closed_local_port();
config.censorship.mask_timing_normalization_enabled = false;
let peer: SocketAddr = "203.0.113.94:55094".parse().unwrap();
let local_addr: SocketAddr = "127.0.0.1:443".parse().unwrap();
let beobachten = BeobachtenStore::new();
let (mut client, server) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
server,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
tokio::time::sleep(Duration::from_millis(120)).await;
client
.write_all(b"still-open-before-timeout")
.await
.expect("connection should still be open before consume timeout expires");
timeout(Duration::from_secs(2), task)
.await
.unwrap()
.unwrap();
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(190) && elapsed < Duration::from_millis(350),
"offline-refusal path must not retain idle client indefinitely"
);
}

View File

@ -43,6 +43,7 @@ async fn run_relay_case(
above_cap_blur,
above_cap_blur_max_bytes,
false,
5 * 1024 * 1024,
)
.await;
});

View File

@ -88,6 +88,7 @@ async fn relay_to_mask_applies_cap_clamped_padding_for_non_power_of_two_cap() {
false,
0,
false,
5 * 1024 * 1024,
)
.await;
});

View File

@ -0,0 +1,58 @@
#![cfg(unix)]
use super::*;
use tokio::io::{AsyncWriteExt, duplex};
use tokio::time::{Duration, Instant, timeout};
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn adversarial_delayed_interface_lookup_does_not_consume_outcome_floor_budget() {
let mut config = ProxyConfig::default();
config.general.beobachten = false;
config.censorship.mask = true;
config.censorship.mask_unix_sock = None;
config.censorship.mask_host = Some("127.0.0.1".to_string());
config.censorship.mask_port = 443;
config.censorship.mask_timing_normalization_enabled = true;
config.censorship.mask_timing_normalization_floor_ms = 120;
config.censorship.mask_timing_normalization_ceiling_ms = 120;
let peer: SocketAddr = "203.0.113.151:55151".parse().expect("valid peer");
let local_addr: SocketAddr = "0.0.0.0:443".parse().expect("valid local addr");
let beobachten = BeobachtenStore::new();
let refresh_lock = LOCAL_INTERFACE_REFRESH_LOCK.get_or_init(|| AsyncMutex::new(()));
let held_refresh_guard = refresh_lock.lock().await;
let (mut client, server) = duplex(1024);
let started = Instant::now();
let task = tokio::spawn(async move {
handle_bad_client(
server,
tokio::io::sink(),
b"GET / HTTP/1.1\r\n\r\n",
peer,
local_addr,
&config,
&beobachten,
)
.await;
});
tokio::time::sleep(Duration::from_millis(80)).await;
drop(held_refresh_guard);
client
.shutdown()
.await
.expect("client shutdown must succeed");
timeout(Duration::from_secs(2), task)
.await
.expect("task must finish in bounded time")
.expect("task must not panic");
let elapsed = started.elapsed();
assert!(
elapsed >= Duration::from_millis(180) && elapsed < Duration::from_millis(350),
"timing normalization floor must start after pre-outcome self-target checks"
);
}

View File

@ -0,0 +1,189 @@
use super::*;
use crate::crypto::AesCtr;
use bytes::Bytes;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::task::{Context, Poll};
use tokio::io::AsyncWrite;
struct CountedWriter {
write_calls: Arc<AtomicUsize>,
fail_writes: bool,
}
impl CountedWriter {
fn new(write_calls: Arc<AtomicUsize>, fail_writes: bool) -> Self {
Self {
write_calls,
fail_writes,
}
}
}
impl AsyncWrite for CountedWriter {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
this.write_calls.fetch_add(1, Ordering::Relaxed);
if this.fail_writes {
Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"forced write failure",
)))
} else {
Poll::Ready(Ok(buf.len()))
}
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
fn make_crypto_writer(inner: CountedWriter) -> CryptoWriter<CountedWriter> {
let key = [0u8; 32];
let iv = 0u128;
CryptoWriter::new(inner, AesCtr::new(&key, iv), 8 * 1024)
}
#[tokio::test]
async fn me_writer_write_fail_keeps_reserved_quota_and_tracks_fail_metrics() {
let stats = Stats::new();
let user = "middle-me-writer-no-rollback-user";
let user_stats = stats.get_or_create_user_stats_handle(user);
let write_calls = Arc::new(AtomicUsize::new(0));
let mut writer = make_crypto_writer(CountedWriter::new(write_calls.clone(), true));
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let payload = Bytes::from_static(&[0x11, 0x22, 0x33, 0x44, 0x55]);
let result = process_me_writer_response(
MeResponse::Data {
flags: 0,
data: payload.clone(),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
user,
Some(user_stats.as_ref()),
Some(64),
0,
&bytes_me2c,
11,
true,
false,
)
.await;
assert!(
matches!(result, Err(ProxyError::Io(_))),
"write failure must propagate as I/O error"
);
assert!(
write_calls.load(Ordering::Relaxed) > 0,
"writer must be attempted after successful quota reservation"
);
assert_eq!(
stats.get_user_quota_used(user),
payload.len() as u64,
"reserved quota must not roll back on write failure"
);
assert_eq!(
stats.get_quota_write_fail_bytes_total(),
payload.len() as u64,
"write-fail byte metric must include failed payload size"
);
assert_eq!(
stats.get_quota_write_fail_events_total(),
1,
"write-fail events metric must increment once"
);
assert_eq!(
stats.get_user_total_octets(user),
0,
"telemetry octets_to should not advance when write fails"
);
assert_eq!(
bytes_me2c.load(Ordering::Relaxed),
0,
"ME->C committed byte counter must not advance on write failure"
);
}
#[tokio::test]
async fn me_writer_pre_write_quota_reject_happens_before_writer_poll() {
let stats = Stats::new();
let user = "middle-me-writer-precheck-user";
let limit = 8u64;
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), limit);
let write_calls = Arc::new(AtomicUsize::new(0));
let mut writer = make_crypto_writer(CountedWriter::new(write_calls.clone(), false));
let mut frame_buf = Vec::new();
let bytes_me2c = AtomicU64::new(0);
let result = process_me_writer_response(
MeResponse::Data {
flags: 0,
data: Bytes::from_static(&[0xAA, 0xBB, 0xCC]),
},
&mut writer,
ProtoTag::Intermediate,
&SecureRandom::new(),
&mut frame_buf,
&stats,
user,
Some(user_stats.as_ref()),
Some(limit),
0,
&bytes_me2c,
12,
true,
false,
)
.await;
assert!(
matches!(result, Err(ProxyError::DataQuotaExceeded { .. })),
"pre-write quota rejection must return typed quota error"
);
assert_eq!(
write_calls.load(Ordering::Relaxed),
0,
"writer must not be polled when pre-write quota reservation fails"
);
assert_eq!(
stats.get_me_d2c_quota_reject_pre_write_total(),
1,
"pre-write quota reject metric must increment"
);
assert_eq!(
stats.get_user_quota_used(user),
limit,
"failed pre-write reservation must keep previous quota usage unchanged"
);
assert_eq!(
stats.get_quota_write_fail_bytes_total(),
0,
"write-fail bytes metric must stay unchanged on pre-write reject"
);
assert_eq!(
stats.get_quota_write_fail_events_total(),
0,
"write-fail events metric must stay unchanged on pre-write reject"
);
assert_eq!(bytes_me2c.load(Ordering::Relaxed), 0);
}

View File

@ -1,112 +0,0 @@
use super::*;
use crate::stats::Stats;
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::sync::Barrier;
use tokio::time::{Duration, timeout};
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn blackhat_campaign_saturation_quota_race_with_queue_pressure_stays_fail_closed() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!(
"middle-blackhat-held-{}-{idx}",
std::process::id()
)));
}
assert_eq!(
map.len(),
QUOTA_USER_LOCKS_MAX,
"precondition: bounded lock cache must be saturated"
);
let (tx, _rx) = mpsc::channel::<C2MeCommand>(1);
tx.send(C2MeCommand::Close)
.await
.expect("queue prefill should succeed");
let pressure_seq_before = relay_pressure_event_seq();
let pressure_errors = Arc::new(AtomicUsize::new(0));
let mut pressure_workers = Vec::new();
for _ in 0..16 {
let tx = tx.clone();
let pressure_errors = Arc::clone(&pressure_errors);
pressure_workers.push(tokio::spawn(async move {
if enqueue_c2me_command(&tx, C2MeCommand::Close).await.is_err() {
pressure_errors.fetch_add(1, Ordering::Relaxed);
}
}));
}
let stats = Arc::new(Stats::new());
let user = format!("middle-blackhat-quota-race-{}", std::process::id());
let gate = Arc::new(Barrier::new(16));
let mut quota_workers = Vec::new();
for _ in 0..16u8 {
let stats = Arc::clone(&stats);
let user = user.clone();
let gate = Arc::clone(&gate);
quota_workers.push(tokio::spawn(async move {
gate.wait().await;
let user_lock = quota_user_lock(&user);
let _quota_guard = user_lock.lock().await;
if quota_would_be_exceeded_for_user(&stats, &user, Some(1), 1) {
return false;
}
stats.add_user_octets_to(&user, 1);
true
}));
}
let mut ok_count = 0usize;
let mut denied_count = 0usize;
for worker in quota_workers {
let result = timeout(Duration::from_secs(2), worker)
.await
.expect("quota worker must finish")
.expect("quota worker must not panic");
if result {
ok_count += 1;
} else {
denied_count += 1;
}
}
for worker in pressure_workers {
timeout(Duration::from_secs(2), worker)
.await
.expect("pressure worker must finish")
.expect("pressure worker must not panic");
}
assert_eq!(
stats.get_user_total_octets(&user),
1,
"black-hat campaign must not overshoot same-user quota under saturation"
);
assert!(ok_count <= 1, "at most one quota contender may succeed");
assert!(
denied_count >= 15,
"all remaining contenders must be quota-denied"
);
let pressure_seq_after = relay_pressure_event_seq();
assert!(
pressure_seq_after > pressure_seq_before,
"queue pressure leg must trigger pressure accounting"
);
assert!(
pressure_errors.load(Ordering::Relaxed) >= 1,
"at least one pressure worker should fail from persistent backpressure"
);
drop(retained);
}

View File

@ -1,708 +0,0 @@
use super::*;
use crate::crypto::AesCtr;
use crate::crypto::SecureRandom;
use crate::stats::Stats;
use crate::stream::{BufferPool, PooledBuffer};
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::duplex;
use tokio::sync::mpsc;
use tokio::time::{Duration as TokioDuration, timeout};
fn make_pooled_payload(data: &[u8]) -> PooledBuffer {
let pool = Arc::new(BufferPool::with_config(data.len().max(1), 4));
let mut payload = pool.get();
payload.resize(data.len(), 0);
payload[..data.len()].copy_from_slice(data);
payload
}
#[tokio::test]
async fn write_client_payload_abridged_short_quickack_sets_flag_and_preserves_payload() {
let (mut read_side, write_side) = duplex(4096);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = vec![0xA1, 0xB2, 0xC3, 0xD4, 0x10, 0x20, 0x30, 0x40];
write_client_payload(
&mut writer,
ProtoTag::Abridged,
RPC_FLAG_QUICKACK,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("abridged quickack payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = vec![0u8; 1 + payload.len()];
read_side
.read_exact(&mut encrypted)
.await
.expect("must read serialized abridged frame");
let plaintext = decryptor.decrypt(&encrypted);
assert_eq!(plaintext[0], 0x80 | ((payload.len() / 4) as u8));
assert_eq!(&plaintext[1..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_abridged_extended_header_is_encoded_correctly() {
let (mut read_side, write_side) = duplex(16 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
// Boundary where abridged switches to extended length encoding.
let payload = vec![0x5Au8; 0x7f * 4];
write_client_payload(
&mut writer,
ProtoTag::Abridged,
RPC_FLAG_QUICKACK,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("extended abridged payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = vec![0u8; 4 + payload.len()];
read_side
.read_exact(&mut encrypted)
.await
.expect("must read serialized extended abridged frame");
let plaintext = decryptor.decrypt(&encrypted);
assert_eq!(plaintext[0], 0xff, "0x7f with quickack bit must be set");
assert_eq!(&plaintext[1..4], &[0x7f, 0x00, 0x00]);
assert_eq!(&plaintext[4..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_abridged_misaligned_is_rejected_fail_closed() {
let (_read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let err = write_client_payload(
&mut writer,
ProtoTag::Abridged,
0,
&[1, 2, 3],
&rng,
&mut frame_buf,
)
.await
.expect_err("misaligned abridged payload must be rejected");
let msg = format!("{err}");
assert!(
msg.contains("4-byte aligned"),
"error should explain alignment contract, got: {msg}"
);
}
#[tokio::test]
async fn write_client_payload_secure_misaligned_is_rejected_fail_closed() {
let (_read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let err = write_client_payload(
&mut writer,
ProtoTag::Secure,
0,
&[9, 8, 7, 6, 5],
&rng,
&mut frame_buf,
)
.await
.expect_err("misaligned secure payload must be rejected");
let msg = format!("{err}");
assert!(
msg.contains("Secure payload must be 4-byte aligned"),
"error should be explicit for fail-closed triage, got: {msg}"
);
}
#[tokio::test]
async fn write_client_payload_intermediate_quickack_sets_length_msb() {
let (mut read_side, write_side) = duplex(4096);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = b"hello-middle-relay";
write_client_payload(
&mut writer,
ProtoTag::Intermediate,
RPC_FLAG_QUICKACK,
payload,
&rng,
&mut frame_buf,
)
.await
.expect("intermediate quickack payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = vec![0u8; 4 + payload.len()];
read_side
.read_exact(&mut encrypted)
.await
.expect("must read intermediate frame");
let plaintext = decryptor.decrypt(&encrypted);
let mut len_bytes = [0u8; 4];
len_bytes.copy_from_slice(&plaintext[..4]);
let len_with_flags = u32::from_le_bytes(len_bytes);
assert_ne!(len_with_flags & 0x8000_0000, 0, "quickack bit must be set");
assert_eq!((len_with_flags & 0x7fff_ffff) as usize, payload.len());
assert_eq!(&plaintext[4..], payload);
}
#[tokio::test]
async fn write_client_payload_secure_quickack_prefix_and_padding_bounds_hold() {
let (mut read_side, write_side) = duplex(4096);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = vec![0x33u8; 100]; // 4-byte aligned as required by secure mode.
write_client_payload(
&mut writer,
ProtoTag::Secure,
RPC_FLAG_QUICKACK,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("secure quickack payload should serialize");
writer.flush().await.expect("flush must succeed");
// Secure mode adds 1..=3 bytes of randomized tail padding.
let mut encrypted_header = [0u8; 4];
read_side
.read_exact(&mut encrypted_header)
.await
.expect("must read secure header");
let decrypted_header = decryptor.decrypt(&encrypted_header);
let header: [u8; 4] = decrypted_header
.try_into()
.expect("decrypted secure header must be 4 bytes");
let wire_len_raw = u32::from_le_bytes(header);
assert_ne!(
wire_len_raw & 0x8000_0000,
0,
"secure quickack bit must be set"
);
let wire_len = (wire_len_raw & 0x7fff_ffff) as usize;
assert!(wire_len >= payload.len());
let padding_len = wire_len - payload.len();
assert!(
(1..=3).contains(&padding_len),
"secure writer must add bounded random tail padding, got {padding_len}"
);
let mut encrypted_body = vec![0u8; wire_len];
read_side
.read_exact(&mut encrypted_body)
.await
.expect("must read secure body");
let decrypted_body = decryptor.decrypt(&encrypted_body);
assert_eq!(&decrypted_body[..payload.len()], payload.as_slice());
}
#[tokio::test]
#[ignore = "heavy: allocates >64MiB to validate abridged too-large fail-closed branch"]
async fn write_client_payload_abridged_too_large_is_rejected_fail_closed() {
let (_read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
// Exactly one 4-byte word above the encodable 24-bit abridged length range.
let payload = vec![0x00u8; (1 << 24) * 4];
let err = write_client_payload(
&mut writer,
ProtoTag::Abridged,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect_err("oversized abridged payload must be rejected");
let msg = format!("{err}");
assert!(
msg.contains("Abridged frame too large"),
"error must clearly indicate oversize fail-close path, got: {msg}"
);
}
#[tokio::test]
async fn write_client_ack_intermediate_is_little_endian() {
let (mut read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
write_client_ack(&mut writer, ProtoTag::Intermediate, 0x11_22_33_44)
.await
.expect("ack serialization should succeed");
writer.flush().await.expect("flush must succeed");
let mut encrypted = [0u8; 4];
read_side
.read_exact(&mut encrypted)
.await
.expect("must read ack bytes");
let plain = decryptor.decrypt(&encrypted);
assert_eq!(plain.as_slice(), &0x11_22_33_44u32.to_le_bytes());
}
#[tokio::test]
async fn write_client_ack_abridged_is_big_endian() {
let (mut read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
write_client_ack(&mut writer, ProtoTag::Abridged, 0xDE_AD_BE_EF)
.await
.expect("ack serialization should succeed");
writer.flush().await.expect("flush must succeed");
let mut encrypted = [0u8; 4];
read_side
.read_exact(&mut encrypted)
.await
.expect("must read ack bytes");
let plain = decryptor.decrypt(&encrypted);
assert_eq!(plain.as_slice(), &0xDE_AD_BE_EFu32.to_be_bytes());
}
#[tokio::test]
async fn write_client_payload_abridged_short_boundary_0x7e_is_single_byte_header() {
let (mut read_side, write_side) = duplex(1024 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = vec![0xABu8; 0x7e * 4];
write_client_payload(
&mut writer,
ProtoTag::Abridged,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("boundary payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = vec![0u8; 1 + payload.len()];
read_side.read_exact(&mut encrypted).await.unwrap();
let plain = decryptor.decrypt(&encrypted);
assert_eq!(plain[0], 0x7e);
assert_eq!(&plain[1..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_abridged_extended_without_quickack_has_clean_prefix() {
let (mut read_side, write_side) = duplex(16 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = vec![0x42u8; 0x80 * 4];
write_client_payload(
&mut writer,
ProtoTag::Abridged,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("extended payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = vec![0u8; 4 + payload.len()];
read_side.read_exact(&mut encrypted).await.unwrap();
let plain = decryptor.decrypt(&encrypted);
assert_eq!(plain[0], 0x7f);
assert_eq!(&plain[1..4], &[0x80, 0x00, 0x00]);
assert_eq!(&plain[4..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_intermediate_zero_length_emits_header_only() {
let (mut read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
write_client_payload(
&mut writer,
ProtoTag::Intermediate,
0,
&[],
&rng,
&mut frame_buf,
)
.await
.expect("zero-length intermediate payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = [0u8; 4];
read_side.read_exact(&mut encrypted).await.unwrap();
let plain = decryptor.decrypt(&encrypted);
assert_eq!(plain.as_slice(), &[0, 0, 0, 0]);
}
#[tokio::test]
async fn write_client_payload_intermediate_ignores_unrelated_flags() {
let (mut read_side, write_side) = duplex(1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = [7u8; 12];
write_client_payload(
&mut writer,
ProtoTag::Intermediate,
0x4000_0000,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted = [0u8; 16];
read_side.read_exact(&mut encrypted).await.unwrap();
let plain = decryptor.decrypt(&encrypted);
let len = u32::from_le_bytes(plain[0..4].try_into().unwrap());
assert_eq!(len, payload.len() as u32, "only quickack bit may affect header");
assert_eq!(&plain[4..], payload.as_slice());
}
#[tokio::test]
async fn write_client_payload_secure_without_quickack_keeps_msb_clear() {
let (mut read_side, write_side) = duplex(4096);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = [0x1Du8; 64];
write_client_payload(
&mut writer,
ProtoTag::Secure,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted_header = [0u8; 4];
read_side.read_exact(&mut encrypted_header).await.unwrap();
let plain_header = decryptor.decrypt(&encrypted_header);
let h: [u8; 4] = plain_header.as_slice().try_into().unwrap();
let wire_len_raw = u32::from_le_bytes(h);
assert_eq!(wire_len_raw & 0x8000_0000, 0, "quickack bit must stay clear");
}
#[tokio::test]
async fn secure_padding_light_fuzz_distribution_has_multiple_outcomes() {
let (mut read_side, write_side) = duplex(256 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let payload = [0x55u8; 100];
let mut seen = [false; 4];
for _ in 0..96 {
write_client_payload(
&mut writer,
ProtoTag::Secure,
0,
&payload,
&rng,
&mut frame_buf,
)
.await
.expect("secure payload should serialize");
writer.flush().await.expect("flush must succeed");
let mut encrypted_header = [0u8; 4];
read_side.read_exact(&mut encrypted_header).await.unwrap();
let plain_header = decryptor.decrypt(&encrypted_header);
let h: [u8; 4] = plain_header.as_slice().try_into().unwrap();
let wire_len = (u32::from_le_bytes(h) & 0x7fff_ffff) as usize;
let padding_len = wire_len - payload.len();
assert!((1..=3).contains(&padding_len));
seen[padding_len] = true;
let mut encrypted_body = vec![0u8; wire_len];
read_side.read_exact(&mut encrypted_body).await.unwrap();
let _ = decryptor.decrypt(&encrypted_body);
}
let distinct = (1..=3).filter(|idx| seen[*idx]).count();
assert!(
distinct >= 2,
"padding generator should not collapse to a single outcome under campaign"
);
}
#[tokio::test]
async fn write_client_payload_mixed_proto_sequence_preserves_stream_sync() {
let (mut read_side, write_side) = duplex(128 * 1024);
let key = [0u8; 32];
let iv = 0u128;
let mut writer = CryptoWriter::new(write_side, AesCtr::new(&key, iv), 8 * 1024);
let mut decryptor = AesCtr::new(&key, iv);
let rng = SecureRandom::new();
let mut frame_buf = Vec::new();
let p1 = vec![1u8; 8];
let p2 = vec![2u8; 16];
let p3 = vec![3u8; 20];
write_client_payload(&mut writer, ProtoTag::Abridged, 0, &p1, &rng, &mut frame_buf)
.await
.unwrap();
write_client_payload(
&mut writer,
ProtoTag::Intermediate,
RPC_FLAG_QUICKACK,
&p2,
&rng,
&mut frame_buf,
)
.await
.unwrap();
write_client_payload(&mut writer, ProtoTag::Secure, 0, &p3, &rng, &mut frame_buf)
.await
.unwrap();
writer.flush().await.unwrap();
// Frame 1: abridged short.
let mut e1 = vec![0u8; 1 + p1.len()];
read_side.read_exact(&mut e1).await.unwrap();
let d1 = decryptor.decrypt(&e1);
assert_eq!(d1[0], (p1.len() / 4) as u8);
assert_eq!(&d1[1..], p1.as_slice());
// Frame 2: intermediate with quickack.
let mut e2 = vec![0u8; 4 + p2.len()];
read_side.read_exact(&mut e2).await.unwrap();
let d2 = decryptor.decrypt(&e2);
let l2 = u32::from_le_bytes(d2[0..4].try_into().unwrap());
assert_ne!(l2 & 0x8000_0000, 0);
assert_eq!((l2 & 0x7fff_ffff) as usize, p2.len());
assert_eq!(&d2[4..], p2.as_slice());
// Frame 3: secure with bounded tail.
let mut e3h = [0u8; 4];
read_side.read_exact(&mut e3h).await.unwrap();
let d3h = decryptor.decrypt(&e3h);
let l3 = (u32::from_le_bytes(d3h.as_slice().try_into().unwrap()) & 0x7fff_ffff) as usize;
assert!(l3 >= p3.len());
assert!((1..=3).contains(&(l3 - p3.len())));
let mut e3b = vec![0u8; l3];
read_side.read_exact(&mut e3b).await.unwrap();
let d3b = decryptor.decrypt(&e3b);
assert_eq!(&d3b[..p3.len()], p3.as_slice());
}
#[test]
fn should_yield_sender_boundary_matrix_blackhat() {
assert!(!should_yield_c2me_sender(0, false));
assert!(!should_yield_c2me_sender(0, true));
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET - 1, true));
assert!(!should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, false));
assert!(should_yield_c2me_sender(C2ME_SENDER_FAIRNESS_BUDGET, true));
assert!(should_yield_c2me_sender(
C2ME_SENDER_FAIRNESS_BUDGET.saturating_add(1024),
true
));
}
#[test]
fn should_yield_sender_light_fuzz_matches_oracle() {
let mut s: u64 = 0xD00D_BAAD_F00D_CAFE;
for _ in 0..5000 {
s ^= s << 7;
s ^= s >> 9;
s ^= s << 8;
let sent = (s as usize) & 0x1fff;
let backlog = (s & 1) != 0;
let expected = backlog && sent >= C2ME_SENDER_FAIRNESS_BUDGET;
assert_eq!(should_yield_c2me_sender(sent, backlog), expected);
}
}
#[test]
fn quota_would_be_exceeded_exact_remaining_one_byte() {
let stats = Stats::new();
let user = "quota-edge";
let quota = 100u64;
stats.add_user_octets_to(user, 99);
assert!(
!quota_would_be_exceeded_for_user(&stats, user, Some(quota), 1),
"exactly remaining budget should be allowed"
);
assert!(
quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2),
"one byte beyond remaining budget must be rejected"
);
}
#[test]
fn quota_would_be_exceeded_saturating_edge_remains_fail_closed() {
let stats = Stats::new();
let user = "quota-saturating-edge";
let quota = u64::MAX - 3;
stats.add_user_octets_to(user, u64::MAX - 4);
assert!(
quota_would_be_exceeded_for_user(&stats, user, Some(quota), 2),
"saturating arithmetic edge must stay fail-closed"
);
}
#[test]
fn quota_exceeded_boundary_is_inclusive() {
let stats = Stats::new();
let user = "quota-inclusive-boundary";
stats.add_user_octets_to(user, 50);
assert!(quota_exceeded_for_user(&stats, user, Some(50)));
assert!(!quota_exceeded_for_user(&stats, user, Some(51)));
}
#[tokio::test]
async fn enqueue_c2me_close_fast_path_succeeds_without_backpressure() {
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(4);
enqueue_c2me_command(&tx, C2MeCommand::Close)
.await
.expect("close should enqueue on fast path");
let recv = timeout(TokioDuration::from_millis(50), rx.recv())
.await
.expect("must receive close command")
.expect("close command should be present");
assert!(matches!(recv, C2MeCommand::Close));
}
#[tokio::test]
async fn enqueue_c2me_data_full_then_drain_preserves_order() {
let (tx, mut rx) = mpsc::channel::<C2MeCommand>(1);
tx.send(C2MeCommand::Data {
payload: make_pooled_payload(&[1]),
flags: 10,
})
.await
.unwrap();
let tx2 = tx.clone();
let producer = tokio::spawn(async move {
enqueue_c2me_command(
&tx2,
C2MeCommand::Data {
payload: make_pooled_payload(&[2, 2]),
flags: 20,
},
)
.await
});
tokio::time::sleep(TokioDuration::from_millis(10)).await;
let first = rx.recv().await.expect("first item should exist");
match first {
C2MeCommand::Data { payload, flags } => {
assert_eq!(payload.as_ref(), &[1]);
assert_eq!(flags, 10);
}
C2MeCommand::Close => panic!("unexpected close as first item"),
}
producer.await.unwrap().expect("producer should complete");
let second = timeout(TokioDuration::from_millis(100), rx.recv())
.await
.unwrap()
.expect("second item should exist");
match second {
C2MeCommand::Data { payload, flags } => {
assert_eq!(payload.as_ref(), &[2, 2]);
assert_eq!(flags, 20);
}
C2MeCommand::Close => panic!("unexpected close as second item"),
}
}

View File

@ -2,8 +2,8 @@ use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader};
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::sync::{Arc, Mutex, OnceLock};
use tokio::io::AsyncWriteExt;
use tokio::io::duplex;
use tokio::time::{Duration as TokioDuration, Instant as TokioInstant, timeout};
@ -48,18 +48,6 @@ fn make_idle_policy(soft_ms: u64, hard_ms: u64, grace_ms: u64) -> RelayClientIdl
}
}
fn idle_pressure_test_lock() -> &'static Mutex<()> {
static TEST_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
TEST_LOCK.get_or_init(|| Mutex::new(()))
}
fn acquire_idle_pressure_test_lock() -> std::sync::MutexGuard<'static, ()> {
match idle_pressure_test_lock().lock() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
}
}
#[tokio::test]
async fn idle_policy_soft_mark_then_hard_close_increments_reason_counters() {
let (reader, _writer) = duplex(1024);
@ -372,7 +360,7 @@ async fn stress_many_idle_sessions_fail_closed_without_hang() {
#[test]
fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -402,7 +390,7 @@ fn pressure_evicts_oldest_idle_candidate_with_deterministic_ordering() {
#[test]
fn pressure_does_not_evict_without_new_pressure_signal() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -421,7 +409,7 @@ fn pressure_does_not_evict_without_new_pressure_signal() {
#[test]
fn stress_pressure_eviction_preserves_fifo_across_many_candidates() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -457,7 +445,7 @@ fn stress_pressure_eviction_preserves_fifo_across_many_candidates() {
#[test]
fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -491,7 +479,7 @@ fn blackhat_single_pressure_event_must_not_evict_more_than_one_candidate() {
#[test]
fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -524,7 +512,7 @@ fn blackhat_pressure_counter_must_track_global_budget_not_per_session_cursor() {
#[test]
fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -543,7 +531,7 @@ fn blackhat_stale_pressure_before_idle_mark_must_not_trigger_eviction() {
#[test]
fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -575,7 +563,7 @@ fn blackhat_stale_pressure_must_not_evict_any_of_newly_marked_batch() {
#[test]
fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -601,7 +589,7 @@ fn blackhat_stale_pressure_seen_without_candidates_must_be_globally_invalidated(
#[test]
fn blackhat_stale_pressure_must_not_survive_candidate_churn() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Stats::new();
@ -621,7 +609,7 @@ fn blackhat_stale_pressure_must_not_survive_candidate_churn() {
#[test]
fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
{
@ -646,7 +634,7 @@ fn blackhat_pressure_seq_saturation_must_not_disable_future_pressure_accounting(
#[test]
fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
{
@ -673,7 +661,7 @@ fn blackhat_pressure_seq_saturation_must_not_break_multiple_distinct_events() {
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_race_single_pressure_event_allows_at_most_one_eviction_under_parallel_claims()
{
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Arc::new(Stats::new());
@ -738,7 +726,7 @@ async fn integration_race_single_pressure_event_allows_at_most_one_eviction_unde
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_race_burst_pressure_with_churn_preserves_empty_set_invalidation_and_budget() {
let _guard = acquire_idle_pressure_test_lock();
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let stats = Arc::new(Stats::new());

View File

@ -0,0 +1,62 @@
use super::*;
use std::panic::{AssertUnwindSafe, catch_unwind};
#[test]
fn blackhat_registry_poison_recovers_with_fail_closed_reset_and_pressure_accounting() {
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let _ = catch_unwind(AssertUnwindSafe(|| {
let registry = relay_idle_candidate_registry();
let mut guard = registry
.lock()
.expect("registry lock must be acquired before poison");
guard.by_conn_id.insert(
999,
RelayIdleCandidateMeta {
mark_order_seq: 1,
mark_pressure_seq: 0,
},
);
guard.ordered.insert((1, 999));
panic!("intentional poison for idle-registry recovery");
}));
// Helper lock must recover from poison, reset stale state, and continue.
assert!(mark_relay_idle_candidate(42));
assert_eq!(oldest_relay_idle_candidate(), Some(42));
let before = relay_pressure_event_seq();
note_relay_pressure_event();
let after = relay_pressure_event_seq();
assert!(
after > before,
"pressure accounting must still advance after poison"
);
clear_relay_idle_pressure_state_for_testing();
}
#[test]
fn clear_state_helper_must_reset_poisoned_registry_for_deterministic_fifo_tests() {
let _guard = relay_idle_pressure_test_scope();
clear_relay_idle_pressure_state_for_testing();
let _ = catch_unwind(AssertUnwindSafe(|| {
let registry = relay_idle_candidate_registry();
let _guard = registry
.lock()
.expect("registry lock must be acquired before poison");
panic!("intentional poison while lock held");
}));
clear_relay_idle_pressure_state_for_testing();
assert_eq!(oldest_relay_idle_candidate(), None);
assert_eq!(relay_pressure_event_seq(), 0);
assert!(mark_relay_idle_candidate(7));
assert_eq!(oldest_relay_idle_candidate(), Some(7));
clear_relay_idle_pressure_state_for_testing();
}

View File

@ -1,131 +0,0 @@
use super::*;
use dashmap::DashMap;
use std::sync::Arc;
#[test]
fn saturation_uses_stable_overflow_lock_without_cache_growth() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let prefix = format!("middle-quota-held-{}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
}
assert_eq!(map.len(), QUOTA_USER_LOCKS_MAX);
let user = format!("middle-quota-overflow-{}", std::process::id());
let first = quota_user_lock(&user);
let second = quota_user_lock(&user);
assert!(
Arc::ptr_eq(&first, &second),
"overflow user must get deterministic same lock while cache is saturated"
);
assert_eq!(
map.len(),
QUOTA_USER_LOCKS_MAX,
"overflow path must not grow bounded lock map"
);
assert!(
map.get(&user).is_none(),
"overflow user should stay outside bounded lock map under saturation"
);
drop(retained);
}
#[test]
fn overflow_striping_keeps_different_users_distributed() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let prefix = format!("middle-quota-dist-held-{}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
}
let a = quota_user_lock("middle-overflow-user-a");
let b = quota_user_lock("middle-overflow-user-b");
let c = quota_user_lock("middle-overflow-user-c");
let distinct = [
Arc::as_ptr(&a) as usize,
Arc::as_ptr(&b) as usize,
Arc::as_ptr(&c) as usize,
]
.iter()
.copied()
.collect::<std::collections::HashSet<_>>()
.len();
assert!(
distinct >= 2,
"striped overflow lock set should avoid collapsing all users to one lock"
);
drop(retained);
}
#[test]
fn reclaim_path_caches_new_user_after_stale_entries_drop() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let prefix = format!("middle-quota-reclaim-held-{}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
}
drop(retained);
let user = format!("middle-quota-reclaim-user-{}", std::process::id());
let got = quota_user_lock(&user);
assert!(map.get(&user).is_some());
assert!(
Arc::strong_count(&got) >= 2,
"after reclaim, lock should be held both by caller and map"
);
}
#[test]
fn overflow_path_same_user_is_stable_across_parallel_threads() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!(
"middle-quota-thread-held-{}-{idx}",
std::process::id()
)));
}
let user = format!("middle-quota-overflow-thread-user-{}", std::process::id());
let mut workers = Vec::new();
for _ in 0..32 {
let user = user.clone();
workers.push(std::thread::spawn(move || quota_user_lock(&user)));
}
let first = workers
.remove(0)
.join()
.expect("thread must return lock handle");
for worker in workers {
let got = worker.join().expect("thread must return lock handle");
assert!(
Arc::ptr_eq(&first, &got),
"same overflow user should resolve to one striped lock even under contention"
);
}
drop(retained);
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,372 @@
use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader};
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
use tokio::task::JoinSet;
use tokio::time::{Duration as TokioDuration, sleep};
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where
T: AsyncRead + Unpin + Send + 'static,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
let mut cipher = AesCtr::new(&key, iv);
cipher.encrypt(plaintext)
}
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
RelayForensicsState {
trace_id: 0xB200_0000 + conn_id,
conn_id,
user: format!("tiny-frame-debt-concurrency-user-{conn_id}"),
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
started_at,
bytes_c2me: 0,
bytes_me2c: Arc::new(AtomicU64::new(0)),
desync_all_full: false,
}
}
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
RelayClientIdlePolicy {
enabled: true,
soft_idle: Duration::from_millis(50),
hard_idle: Duration::from_millis(120),
grace_after_downstream_activity: Duration::from_secs(0),
legacy_frame_read_timeout: Duration::from_millis(50),
}
}
async fn read_once(
crypto_reader: &mut CryptoReader<tokio::io::DuplexStream>,
proto: ProtoTag,
forensics: &RelayForensicsState,
frame_counter: &mut u64,
idle_state: &mut RelayClientIdleState,
) -> Result<Option<(PooledBuffer, bool)>> {
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
read_client_payload_with_idle_policy(
crypto_reader,
proto,
1024,
&buffer_pool,
forensics,
frame_counter,
&stats,
&idle_policy,
idle_state,
&last_downstream_activity_ms,
forensics.started_at,
)
.await
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_pure_tiny_floods_all_fail_closed() {
let mut set = JoinSet::new();
for idx in 0..32u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(1000 + idx, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let flood_plaintext = vec![0u8; 1024];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = run_relay_test_step_timeout(
"tiny flood task",
read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
assert!(matches!(result, Err(ProxyError::Proxy(_))));
assert_eq!(frame_counter, 0);
});
}
while let Some(result) = set.join_next().await {
result.expect("parallel tiny flood worker must not panic");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_benign_tiny_burst_then_real_all_pass() {
let mut set = JoinSet::new();
for idx in 0..24u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(2048);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(2000 + idx, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let payload = [idx as u8, 2, 3, 4];
let mut plaintext = Vec::with_capacity(20);
for _ in 0..6 {
plaintext.push(0x00);
}
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
let result = run_relay_test_step_timeout(
"benign tiny burst read",
read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await
.expect("benign payload must parse")
.expect("benign payload must return frame");
assert_eq!(result.0.as_ref(), &payload);
assert_eq!(frame_counter, 1);
});
}
while let Some(result) = set.join_next().await {
result.expect("parallel benign worker must not panic");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn adversarial_lockstep_alternating_attack_under_jitter_closes() {
let mut set = JoinSet::new();
for idx in 0..12u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(3000 + idx, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(2000);
for n in 0..180u8 {
plaintext.push(0x00);
plaintext.push(0x01);
plaintext.extend_from_slice(&[n, n ^ 0x21, n ^ 0x42, n ^ 0x84]);
}
let encrypted = encrypt_for_reader(&plaintext);
let writer_task = tokio::spawn(async move {
for chunk in encrypted.chunks(17) {
writer.write_all(chunk).await.unwrap();
sleep(TokioDuration::from_millis(1)).await;
}
drop(writer);
});
let mut closed = false;
for _ in 0..220 {
let result = run_relay_test_step_timeout(
"alternating jitter read step",
read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
match result {
Ok(Some((_payload, _))) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected error in alternating jitter case: {other}"),
}
}
writer_task
.await
.expect("writer jitter task must not panic");
assert!(closed, "alternating attack must close before EOF");
});
}
while let Some(result) = set.join_next().await {
result.expect("alternating jitter worker must not panic");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_mixed_population_attackers_close_benign_survive() {
let mut set = JoinSet::new();
for idx in 0..20u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(4000 + idx, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
if idx % 2 == 0 {
let mut plaintext = Vec::with_capacity(1280);
for n in 0..140u8 {
plaintext.push(0x00);
plaintext.push(0x01);
plaintext.extend_from_slice(&[n, n, n, n]);
}
writer
.write_all(&encrypt_for_reader(&plaintext))
.await
.unwrap();
drop(writer);
let mut closed = false;
for _ in 0..200 {
match read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
)
.await
{
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected attacker error: {other}"),
}
}
assert!(closed, "attacker session must fail closed");
} else {
let payload = [1u8, 9, 8, 7];
let mut plaintext = Vec::new();
for _ in 0..4 {
plaintext.push(0x00);
}
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
writer
.write_all(&encrypt_for_reader(&plaintext))
.await
.unwrap();
let got = read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
)
.await
.expect("benign session must parse")
.expect("benign session must return a frame");
assert_eq!(got.0.as_ref(), &payload);
}
});
}
while let Some(result) = set.join_next().await {
result.expect("mixed-population worker must not panic");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn light_fuzz_parallel_patterns_no_hang_or_panic() {
let mut set = JoinSet::new();
for case in 0..40u64 {
set.spawn(async move {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(5000 + case, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut seed = 0x9E37_79B9u64 ^ (case << 8);
let mut plaintext = Vec::with_capacity(2048);
for _ in 0..256 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let is_tiny = (seed & 1) == 0;
if is_tiny {
plaintext.push(0x00);
} else {
plaintext.push(0x01);
plaintext.extend_from_slice(&[(seed >> 8) as u8, 2, 3, 4]);
}
}
writer
.write_all(&encrypt_for_reader(&plaintext))
.await
.unwrap();
drop(writer);
for _ in 0..320 {
let step = run_relay_test_step_timeout(
"fuzz case read step",
read_once(
&mut crypto_reader,
ProtoTag::Abridged,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
match step {
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => break,
Ok(None) => break,
Err(other) => panic!("unexpected fuzz case error: {other}"),
}
}
});
}
while let Some(result) = set.join_next().await {
result.expect("fuzz worker must not panic");
}
}

View File

@ -0,0 +1,435 @@
use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader, PooledBuffer};
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
use tokio::time::{Duration as TokioDuration, sleep};
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where
T: AsyncRead + Unpin + Send + 'static,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
let mut cipher = AesCtr::new(&key, iv);
cipher.encrypt(plaintext)
}
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
RelayForensicsState {
trace_id: 0xB300_0000 + conn_id,
conn_id,
user: format!("tiny-frame-debt-proto-chunk-user-{conn_id}"),
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
started_at,
bytes_c2me: 0,
bytes_me2c: Arc::new(AtomicU64::new(0)),
desync_all_full: false,
}
}
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
RelayClientIdlePolicy {
enabled: true,
soft_idle: Duration::from_millis(50),
hard_idle: Duration::from_millis(120),
grace_after_downstream_activity: Duration::from_secs(0),
legacy_frame_read_timeout: Duration::from_millis(50),
}
}
fn append_tiny_frame(plaintext: &mut Vec<u8>, proto: ProtoTag) {
match proto {
ProtoTag::Abridged => plaintext.push(0x00),
ProtoTag::Intermediate | ProtoTag::Secure => {
plaintext.extend_from_slice(&0u32.to_le_bytes())
}
}
}
fn append_real_frame(plaintext: &mut Vec<u8>, proto: ProtoTag, payload: [u8; 4]) {
match proto {
ProtoTag::Abridged => {
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
}
ProtoTag::Intermediate | ProtoTag::Secure => {
plaintext.extend_from_slice(&4u32.to_le_bytes());
plaintext.extend_from_slice(&payload);
}
}
}
async fn write_chunked_with_jitter(
writer: &mut tokio::io::DuplexStream,
bytes: &[u8],
mut seed: u64,
) {
let mut offset = 0usize;
while offset < bytes.len() {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let chunk_len = 1 + ((seed as usize) & 0x1f);
let end = (offset + chunk_len).min(bytes.len());
writer.write_all(&bytes[offset..end]).await.unwrap();
let delay_ms = ((seed >> 16) % 3) as u64;
if delay_ms > 0 {
sleep(TokioDuration::from_millis(delay_ms)).await;
}
offset = end;
}
}
async fn read_once_with_state(
crypto_reader: &mut CryptoReader<tokio::io::DuplexStream>,
proto: ProtoTag,
forensics: &RelayForensicsState,
frame_counter: &mut u64,
idle_state: &mut RelayClientIdleState,
) -> Result<Option<(PooledBuffer, bool)>> {
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
read_client_payload_with_idle_policy(
crypto_reader,
proto,
1024,
&buffer_pool,
forensics,
frame_counter,
&stats,
&idle_policy,
idle_state,
&last_downstream_activity_ms,
forensics.started_at,
)
.await
}
fn is_fail_closed_outcome(result: &Result<Option<(PooledBuffer, bool)>>) -> bool {
matches!(result, Err(ProxyError::Proxy(_)))
|| matches!(result, Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut)
}
#[tokio::test]
async fn intermediate_chunked_zero_flood_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6101, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(4 * 256);
for _ in 0..256 {
append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
}
let encrypted = encrypt_for_reader(&plaintext);
write_chunked_with_jitter(&mut writer, &encrypted, 0x1111_2222).await;
drop(writer);
let result = run_relay_test_step_timeout(
"intermediate flood read",
read_once_with_state(
&mut crypto_reader,
ProtoTag::Intermediate,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
assert!(
is_fail_closed_outcome(&result),
"zero-length flood must fail closed via debt guard or idle timeout"
);
assert_eq!(frame_counter, 0);
}
#[tokio::test]
async fn secure_chunked_zero_flood_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6102, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(4 * 256);
for _ in 0..256 {
append_tiny_frame(&mut plaintext, ProtoTag::Secure);
}
let encrypted = encrypt_for_reader(&plaintext);
write_chunked_with_jitter(&mut writer, &encrypted, 0x3333_4444).await;
drop(writer);
let result = run_relay_test_step_timeout(
"secure flood read",
read_once_with_state(
&mut crypto_reader,
ProtoTag::Secure,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
assert!(
is_fail_closed_outcome(&result),
"secure zero-length flood must fail closed via debt guard or idle timeout"
);
assert_eq!(frame_counter, 0);
}
#[tokio::test]
async fn intermediate_chunked_alternating_attack_closes_before_eof() {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6103, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(8 * 200);
for n in 0..180u8 {
append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
append_real_frame(
&mut plaintext,
ProtoTag::Intermediate,
[n, n ^ 1, n ^ 2, n ^ 3],
);
}
let encrypted = encrypt_for_reader(&plaintext);
let writer_task = tokio::spawn(async move {
write_chunked_with_jitter(&mut writer, &encrypted, 0x5555_6666).await;
drop(writer);
});
let mut closed = false;
for _ in 0..240 {
let step = run_relay_test_step_timeout(
"intermediate alternating read step",
read_once_with_state(
&mut crypto_reader,
ProtoTag::Intermediate,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
match step {
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected intermediate alternating error: {other}"),
}
}
writer_task
.await
.expect("intermediate writer task must not panic");
assert!(closed, "intermediate alternating attack must fail closed");
}
#[tokio::test]
async fn secure_chunked_alternating_attack_closes_before_eof() {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6104, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut plaintext = Vec::with_capacity(8 * 200);
for n in 0..180u8 {
append_tiny_frame(&mut plaintext, ProtoTag::Secure);
append_real_frame(&mut plaintext, ProtoTag::Secure, [n, n ^ 7, n ^ 11, n ^ 19]);
}
let encrypted = encrypt_for_reader(&plaintext);
let writer_task = tokio::spawn(async move {
write_chunked_with_jitter(&mut writer, &encrypted, 0x7777_8888).await;
drop(writer);
});
let mut closed = false;
for _ in 0..240 {
let step = run_relay_test_step_timeout(
"secure alternating read step",
read_once_with_state(
&mut crypto_reader,
ProtoTag::Secure,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
match step {
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected secure alternating error: {other}"),
}
}
writer_task
.await
.expect("secure writer task must not panic");
assert!(closed, "secure alternating attack must fail closed");
}
#[tokio::test]
async fn intermediate_chunked_safe_small_burst_still_returns_real_frame() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6105, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let payload = [9u8, 8, 7, 6];
let mut plaintext = Vec::new();
for _ in 0..7 {
append_tiny_frame(&mut plaintext, ProtoTag::Intermediate);
}
append_real_frame(&mut plaintext, ProtoTag::Intermediate, payload);
let encrypted = encrypt_for_reader(&plaintext);
write_chunked_with_jitter(&mut writer, &encrypted, 0xAAAA_BBBB).await;
let result = read_once_with_state(
&mut crypto_reader,
ProtoTag::Intermediate,
&forensics,
&mut frame_counter,
&mut idle_state,
)
.await
.expect("intermediate safe burst should parse")
.expect("intermediate safe burst should return a frame");
assert_eq!(result.0.as_ref(), &payload);
assert_eq!(frame_counter, 1);
}
#[tokio::test]
async fn secure_chunked_safe_small_burst_still_returns_real_frame() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6106, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let payload = [3u8, 1, 4, 1];
let mut plaintext = Vec::new();
for _ in 0..7 {
append_tiny_frame(&mut plaintext, ProtoTag::Secure);
}
append_real_frame(&mut plaintext, ProtoTag::Secure, payload);
let encrypted = encrypt_for_reader(&plaintext);
write_chunked_with_jitter(&mut writer, &encrypted, 0xCCCC_DDDD).await;
let result = read_once_with_state(
&mut crypto_reader,
ProtoTag::Secure,
&forensics,
&mut frame_counter,
&mut idle_state,
)
.await
.expect("secure safe burst should parse")
.expect("secure safe burst should return a frame");
assert_eq!(result.0.as_ref(), &payload);
assert_eq!(frame_counter, 1);
}
#[tokio::test]
async fn light_fuzz_proto_chunking_outcomes_are_bounded() {
let mut seed = 0xDEAD_BEEF_2026_0322u64;
for case in 0..48u64 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let proto = if (seed & 1) == 0 {
ProtoTag::Intermediate
} else {
ProtoTag::Secure
};
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let started = Instant::now();
let forensics = make_forensics(6200 + case, started);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(started);
let mut stream = Vec::new();
let mut local_seed = seed ^ case;
for _ in 0..220 {
local_seed ^= local_seed << 7;
local_seed ^= local_seed >> 9;
local_seed ^= local_seed << 8;
if (local_seed & 1) == 0 {
append_tiny_frame(&mut stream, proto);
} else {
let b = (local_seed >> 8) as u8;
append_real_frame(&mut stream, proto, [b, b ^ 0x12, b ^ 0x24, b ^ 0x48]);
}
}
let encrypted = encrypt_for_reader(&stream);
write_chunked_with_jitter(&mut writer, &encrypted, seed ^ 0x1234_5678).await;
drop(writer);
for _ in 0..260 {
let step = run_relay_test_step_timeout(
"fuzz proto read step",
read_once_with_state(
&mut crypto_reader,
proto,
&forensics,
&mut frame_counter,
&mut idle_state,
),
)
.await;
match step {
Ok(Some((_payload, _))) => {}
Err(ProxyError::Proxy(_)) => break,
Err(ProxyError::Io(e)) if e.kind() == std::io::ErrorKind::TimedOut => break,
Ok(None) => break,
Err(other) => panic!("unexpected proto chunking fuzz error: {other}"),
}
}
}
}

View File

@ -0,0 +1,804 @@
use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader};
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where
T: AsyncRead + Unpin + Send + 'static,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
let mut cipher = AesCtr::new(&key, iv);
cipher.encrypt(plaintext)
}
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
RelayForensicsState {
trace_id: 0xB100_0000 + conn_id,
conn_id,
user: format!("tiny-frame-debt-user-{conn_id}"),
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
started_at,
bytes_c2me: 0,
bytes_me2c: Arc::new(AtomicU64::new(0)),
desync_all_full: false,
}
}
fn make_enabled_idle_policy() -> RelayClientIdlePolicy {
RelayClientIdlePolicy {
enabled: true,
soft_idle: Duration::from_millis(50),
hard_idle: Duration::from_millis(120),
grace_after_downstream_activity: Duration::from_secs(0),
legacy_frame_read_timeout: Duration::from_millis(50),
}
}
async fn read_bounded(
crypto_reader: &mut CryptoReader<tokio::io::DuplexStream>,
proto_tag: ProtoTag,
buffer_pool: &Arc<BufferPool>,
forensics: &RelayForensicsState,
frame_counter: &mut u64,
stats: &Stats,
idle_policy: &RelayClientIdlePolicy,
idle_state: &mut RelayClientIdleState,
last_downstream_activity_ms: &AtomicU64,
session_started_at: Instant,
) -> Result<Option<(PooledBuffer, bool)>> {
run_relay_test_step_timeout(
"tiny-frame debt read step",
read_client_payload_with_idle_policy(
crypto_reader,
proto_tag,
1024,
buffer_pool,
forensics,
frame_counter,
stats,
idle_policy,
idle_state,
last_downstream_activity_ms,
session_started_at,
),
)
.await
}
fn simulate_tiny_debt_pattern(pattern: &[bool], max_steps: usize) -> (Option<usize>, u32, usize) {
let mut debt = 0u32;
let mut reals = 0usize;
for (idx, is_tiny) in pattern.iter().copied().take(max_steps).enumerate() {
if is_tiny {
debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY);
if debt >= TINY_FRAME_DEBT_LIMIT {
return (Some(idx + 1), debt, reals);
}
} else {
reals = reals.saturating_add(1);
debt = debt.saturating_sub(1);
}
}
(None, debt, reals)
}
#[test]
fn tiny_frame_debt_constants_match_security_budget_expectations() {
assert_eq!(TINY_FRAME_DEBT_PER_TINY, 8);
assert_eq!(TINY_FRAME_DEBT_LIMIT, 512);
}
#[test]
fn relay_client_idle_state_initial_debt_is_zero() {
let state = RelayClientIdleState::new(Instant::now());
assert_eq!(state.tiny_frame_debt, 0);
}
#[test]
fn on_client_frame_does_not_reset_tiny_frame_debt() {
let now = Instant::now();
let mut state = RelayClientIdleState::new(now);
state.tiny_frame_debt = 77;
state.on_client_frame(now);
assert_eq!(state.tiny_frame_debt, 77);
}
#[test]
fn tiny_frame_debt_increment_is_saturating() {
let mut debt = u32::MAX - 1;
debt = debt.saturating_add(TINY_FRAME_DEBT_PER_TINY);
assert_eq!(debt, u32::MAX);
}
#[test]
fn tiny_frame_debt_decrement_is_saturating() {
let mut debt = 0u32;
debt = debt.saturating_sub(1);
assert_eq!(debt, 0);
}
#[test]
fn consecutive_tiny_frames_close_exactly_at_threshold() {
let max_tiny_without_close = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize;
let pattern = vec![true; max_tiny_without_close];
let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert_eq!(closed_at, Some(max_tiny_without_close));
}
#[test]
fn one_less_than_threshold_tiny_frames_do_not_close() {
let tiny_count = (TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) as usize - 1;
let pattern = vec![true; tiny_count];
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert_eq!(closed_at, None);
assert!(debt < TINY_FRAME_DEBT_LIMIT);
}
#[test]
fn alternating_one_to_one_closes_with_bounded_real_frame_count() {
let mut pattern = Vec::with_capacity(512);
for _ in 0..256 {
pattern.push(true);
pattern.push(false);
}
let (closed_at, _, reals) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert!(closed_at.is_some());
assert!(
reals <= 80,
"expected bounded real frames before close, got {reals}"
);
}
#[test]
fn alternating_one_to_eight_is_stable_for_long_runs() {
let mut pattern = Vec::with_capacity(9 * 5000);
for _ in 0..5000 {
pattern.push(true);
for _ in 0..8 {
pattern.push(false);
}
}
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert_eq!(closed_at, None);
assert!(debt <= TINY_FRAME_DEBT_PER_TINY);
}
#[test]
fn alternating_one_to_seven_eventually_closes() {
let mut pattern = Vec::with_capacity(8 * 2000);
for _ in 0..2000 {
pattern.push(true);
for _ in 0..7 {
pattern.push(false);
}
}
let (closed_at, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert!(
closed_at.is_some(),
"1:7 tiny-to-real must eventually close"
);
}
#[test]
fn two_tiny_one_real_closes_faster_than_one_to_one() {
let mut one_to_one = Vec::with_capacity(512);
for _ in 0..256 {
one_to_one.push(true);
one_to_one.push(false);
}
let mut two_to_one = Vec::with_capacity(768);
for _ in 0..256 {
two_to_one.push(true);
two_to_one.push(true);
two_to_one.push(false);
}
let (a_close, _, _) = simulate_tiny_debt_pattern(&one_to_one, one_to_one.len());
let (b_close, _, _) = simulate_tiny_debt_pattern(&two_to_one, two_to_one.len());
assert!(a_close.is_some() && b_close.is_some());
assert!(b_close.unwrap_or(usize::MAX) < a_close.unwrap_or(0));
}
#[test]
fn burst_then_drain_can_recover_without_close() {
let burst_tiny = ((TINY_FRAME_DEBT_LIMIT / TINY_FRAME_DEBT_PER_TINY) / 2) as usize;
let mut pattern = Vec::with_capacity(burst_tiny + 600);
for _ in 0..burst_tiny {
pattern.push(true);
}
pattern.extend(std::iter::repeat_n(false, 600));
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert_eq!(closed_at, None);
assert_eq!(debt, 0);
}
#[test]
fn light_fuzz_tiny_frame_debt_model_stays_within_bounds() {
let mut seed = 0xA5A5_91C3_2026_0322u64;
for _case in 0..128 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let len = 512 + ((seed as usize) & 0x3ff);
let mut pattern = Vec::with_capacity(len);
let mut local_seed = seed;
for _ in 0..len {
local_seed ^= local_seed << 7;
local_seed ^= local_seed >> 9;
local_seed ^= local_seed << 8;
pattern.push((local_seed & 1) == 0);
}
let (closed_at, debt, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
if closed_at.is_none() {
assert!(debt < TINY_FRAME_DEBT_LIMIT);
}
assert!(debt <= u32::MAX);
}
}
#[test]
fn stress_many_independent_simulations_keep_isolated_debt_state() {
for idx in 0..2048usize {
let mut pattern = Vec::with_capacity(64);
for j in 0..64usize {
pattern.push(((idx ^ j) & 3) == 0);
}
let (_closed_at, debt, _reals) = simulate_tiny_debt_pattern(&pattern, pattern.len());
assert!(debt <= TINY_FRAME_DEBT_LIMIT.saturating_add(TINY_FRAME_DEBT_PER_TINY));
}
}
#[tokio::test]
async fn idle_policy_enabled_intermediate_zero_length_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(11, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let flood_plaintext = vec![0u8; 4 * 256];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Intermediate,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(matches!(result, Err(ProxyError::Proxy(_))));
}
#[tokio::test]
async fn idle_policy_enabled_secure_zero_length_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(12, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let flood_plaintext = vec![0u8; 4 * 256];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Secure,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(matches!(result, Err(ProxyError::Proxy(_))));
}
#[tokio::test]
async fn intermediate_alternating_zero_and_real_eventually_closes() {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(13, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let mut plaintext = Vec::with_capacity(3000);
for idx in 0..160u8 {
plaintext.extend_from_slice(&0u32.to_le_bytes());
plaintext.extend_from_slice(&4u32.to_le_bytes());
plaintext.extend_from_slice(&[idx, idx ^ 0x11, idx ^ 0x22, idx ^ 0x33]);
}
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
drop(writer);
let mut closed = false;
for _ in 0..220 {
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Intermediate,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
match result {
Ok(Some(_)) => {}
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Ok(None) => break,
Err(other) => panic!("unexpected error while probing alternating close: {other}"),
}
}
assert!(closed, "intermediate alternating attack must fail closed");
}
#[tokio::test]
async fn small_tiny_burst_followed_by_real_frame_does_not_spuriously_close() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(14, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let mut plaintext = Vec::with_capacity(64);
for _ in 0..8 {
plaintext.push(0x00);
}
plaintext.push(0x01);
plaintext.extend_from_slice(&[1, 2, 3, 4]);
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
let first = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
match first {
Ok(Some((payload, _))) => assert_eq!(payload.as_ref(), &[1, 2, 3, 4]),
Err(e) => panic!("unexpected close after small tiny burst: {e}"),
Ok(None) => panic!("unexpected EOF before real frame"),
}
}
#[tokio::test]
async fn idle_policy_enabled_zero_length_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(1, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let flood_plaintext = vec![0u8; 1024];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer
.write_all(&flood_encrypted)
.await
.expect("zero-length flood bytes must be writable");
drop(writer);
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(
matches!(result, Err(ProxyError::Proxy(_))),
"idle policy enabled must fail closed for pure zero-length flood"
);
}
#[tokio::test]
async fn idle_policy_enabled_alternating_tiny_real_eventually_closes() {
let (reader, mut writer) = duplex(8192);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(2, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let mut plaintext = Vec::with_capacity(256 * 6);
for idx in 0..=255u8 {
plaintext.push(0x00);
plaintext.push(0x01);
plaintext.extend_from_slice(&[idx, idx ^ 0x55, idx ^ 0xAA, 0x11]);
}
let encrypted = encrypt_for_reader(&plaintext);
writer
.write_all(&encrypted)
.await
.expect("alternating flood bytes must be writable");
drop(writer);
let mut saw_proxy_close = false;
for _ in 0..300 {
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
match result {
Ok(Some((_payload, _quickack))) => {}
Err(ProxyError::Proxy(_)) => {
saw_proxy_close = true;
break;
}
Err(ProxyError::Io(e)) => panic!("unexpected IO error before close: {e}"),
Ok(None) => panic!("unexpected EOF before debt-based closure"),
Err(other) => panic!("unexpected error before close: {other}"),
}
}
assert!(
saw_proxy_close,
"alternating tiny/real sequence must eventually fail closed"
);
}
#[tokio::test]
async fn enabled_idle_policy_valid_nonzero_frame_still_passes() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(3, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let payload = [7u8, 8, 9, 10];
let mut plaintext = Vec::with_capacity(1 + payload.len());
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
let encrypted = encrypt_for_reader(&plaintext);
writer
.write_all(&encrypted)
.await
.expect("nonzero frame must be writable");
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await
.expect("valid frame should decode")
.expect("valid frame should return payload");
assert_eq!(result.0.as_ref(), &payload);
assert!(!result.1);
assert_eq!(frame_counter, 1);
}
#[tokio::test]
async fn abridged_quickack_tiny_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(21, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let flood_plaintext = vec![0x80u8; 256];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(
matches!(result, Err(ProxyError::Proxy(_))),
"quickack-marked zero-length flood must fail closed"
);
}
#[tokio::test]
async fn abridged_extended_zero_len_flood_is_fail_closed() {
let (reader, mut writer) = duplex(4096);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(22, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let mut flood_plaintext = Vec::with_capacity(4 * 256);
for _ in 0..256 {
flood_plaintext.extend_from_slice(&[0x7f, 0x00, 0x00, 0x00]);
}
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer.write_all(&flood_encrypted).await.unwrap();
drop(writer);
let result = read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await;
assert!(
matches!(result, Err(ProxyError::Proxy(_))),
"extended zero-length abridged flood must fail closed"
);
}
#[tokio::test]
async fn one_to_eight_abridged_wire_pattern_survives_without_false_positive_close() {
let mut plaintext = Vec::with_capacity(9 * 300);
for idx in 0..300usize {
plaintext.push(0x00);
for _ in 0..8 {
let b = idx as u8;
plaintext.push(0x01);
plaintext.extend_from_slice(&[b, b ^ 0x11, b ^ 0x22, b ^ 0x33]);
}
}
// Keep the test single-task and deterministic: make duplex capacity larger than the
// generated ciphertext so write_all cannot block waiting for a concurrent reader.
let duplex_capacity = plaintext.len().saturating_add(1024);
let (reader, mut writer) = duplex(duplex_capacity);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(23, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
let encrypted = encrypt_for_reader(&plaintext);
writer.write_all(&encrypted).await.unwrap();
drop(writer);
let mut closed = false;
for _ in 0..3000 {
match read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await
{
Ok(Some(_)) => {}
Ok(None) => break,
Err(ProxyError::Proxy(_)) => {
closed = true;
break;
}
Err(other) => panic!("unexpected error in 1:8 wire test: {other}"),
}
}
assert!(
!closed,
"wire-level 1:8 tiny-to-real pattern should not trigger debt close"
);
}
#[tokio::test]
async fn deterministic_light_fuzz_abridged_wire_behavior_matches_model() {
let mut seed = 0xD1CE_BAAD_2026_0322u64;
for case_idx in 0..32u64 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let events = 300 + ((seed as usize) & 0xff);
let mut pattern = Vec::with_capacity(events);
let mut local = seed;
for _ in 0..events {
local ^= local << 7;
local ^= local >> 9;
local ^= local << 8;
pattern.push((local & 0x03) == 0);
}
let mut plaintext = Vec::with_capacity(events * 6);
for (idx, tiny) in pattern.iter().copied().enumerate() {
if tiny {
plaintext.push(0x00);
} else {
let b = (idx as u8) ^ (case_idx as u8);
plaintext.push(0x01);
plaintext.extend_from_slice(&[b, b ^ 0x1F, b ^ 0x7A, b ^ 0xC3]);
}
}
let (reader, mut writer) = duplex(16 * 1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(500 + case_idx, session_started_at);
let mut frame_counter = 0u64;
let mut idle_state = RelayClientIdleState::new(session_started_at);
let idle_policy = make_enabled_idle_policy();
let last_downstream_activity_ms = AtomicU64::new(0);
writer
.write_all(&encrypt_for_reader(&plaintext))
.await
.unwrap();
drop(writer);
let (expected_close, _, _) = simulate_tiny_debt_pattern(&pattern, pattern.len());
let mut observed_close = false;
for _ in 0..(events + 8) {
match read_bounded(
&mut crypto_reader,
ProtoTag::Abridged,
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
&idle_policy,
&mut idle_state,
&last_downstream_activity_ms,
session_started_at,
)
.await
{
Ok(Some(_)) => {}
Ok(None) => break,
Err(ProxyError::Proxy(_)) => {
observed_close = true;
break;
}
Err(other) => panic!("unexpected fuzz error: {other}"),
}
}
assert_eq!(
observed_close,
expected_close.is_some(),
"wire parser behavior must match debt model for case {case_idx}"
);
}
}

View File

@ -0,0 +1,121 @@
use super::*;
use crate::crypto::AesCtr;
use crate::stats::Stats;
use crate::stream::{BufferPool, CryptoReader};
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWriteExt, duplex};
fn make_crypto_reader<T>(reader: T) -> CryptoReader<T>
where
T: AsyncRead + Unpin + Send + 'static,
{
let key = [0u8; 32];
let iv = 0u128;
CryptoReader::new(reader, AesCtr::new(&key, iv))
}
fn encrypt_for_reader(plaintext: &[u8]) -> Vec<u8> {
let key = [0u8; 32];
let iv = 0u128;
let mut cipher = AesCtr::new(&key, iv);
cipher.encrypt(plaintext)
}
fn make_forensics(conn_id: u64, started_at: Instant) -> RelayForensicsState {
RelayForensicsState {
trace_id: 0xB000_0000 + conn_id,
conn_id,
user: format!("zero-len-test-user-{conn_id}"),
peer: "127.0.0.1:50000".parse().expect("peer parse must succeed"),
peer_hash: hash_ip("127.0.0.1".parse().expect("ip parse must succeed")),
started_at,
bytes_c2me: 0,
bytes_me2c: Arc::new(AtomicU64::new(0)),
desync_all_full: false,
}
}
#[tokio::test]
async fn adversarial_legacy_zero_length_flood_is_fail_closed() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(1, session_started_at);
let mut frame_counter = 0u64;
let flood_plaintext = vec![0u8; 128];
let flood_encrypted = encrypt_for_reader(&flood_plaintext);
writer
.write_all(&flood_encrypted)
.await
.expect("zero-length flood bytes must be writable");
drop(writer);
let result = read_client_payload_legacy(
&mut crypto_reader,
ProtoTag::Abridged,
1024,
Duration::from_millis(30),
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
)
.await;
match result {
Err(ProxyError::Proxy(msg)) => {
assert!(
msg.contains("Excessive zero-length"),
"legacy mode must close flood with explicit zero-length reason, got: {msg}"
);
}
Ok(None) => panic!("legacy zero-length flood must not be accepted as EOF"),
Ok(Some(_)) => panic!("legacy zero-length flood must not produce a data frame"),
Err(err) => panic!("legacy zero-length flood must be a Proxy error, got: {err}"),
}
}
#[tokio::test]
async fn business_abridged_nonzero_frame_still_passes() {
let (reader, mut writer) = duplex(1024);
let mut crypto_reader = make_crypto_reader(reader);
let buffer_pool = Arc::new(BufferPool::new());
let stats = Stats::new();
let session_started_at = Instant::now();
let forensics = make_forensics(2, session_started_at);
let mut frame_counter = 0u64;
let payload = [1u8, 2, 3, 4];
let mut plaintext = Vec::with_capacity(1 + payload.len());
plaintext.push(0x01);
plaintext.extend_from_slice(&payload);
let encrypted = encrypt_for_reader(&plaintext);
writer
.write_all(&encrypted)
.await
.expect("nonzero abridged frame must be writable");
let result = read_client_payload_legacy(
&mut crypto_reader,
ProtoTag::Abridged,
1024,
Duration::from_millis(30),
&buffer_pool,
&forensics,
&mut frame_counter,
&stats,
)
.await
.expect("valid abridged frame should decode")
.expect("valid abridged frame should return payload");
assert_eq!(result.0.as_ref(), &payload);
assert!(!result.1, "quickack flag must remain false");
assert_eq!(frame_counter, 1);
}

View File

@ -78,7 +78,8 @@ async fn relay_hol_blocking_prevention_regression() {
async fn relay_quota_mid_session_cutoff() {
let stats = Arc::new(Stats::new());
let user = "quota-mid-user";
let quota = 5000;
let quota = 5000u64;
let c2s_buf_size = 1024usize;
let (client_peer, relay_client) = duplex(8192);
let (relay_server, server_peer) = duplex(8192);
@ -93,7 +94,7 @@ async fn relay_quota_mid_session_cutoff() {
client_writer,
server_reader,
server_writer,
1024,
c2s_buf_size,
1024,
user,
Arc::clone(&stats),
@ -120,9 +121,25 @@ async fn relay_quota_mid_session_cutoff() {
other => panic!("Expected DataQuotaExceeded error, got: {:?}", other),
}
let mut small_buf = [0u8; 1];
let n = sp_reader.read(&mut small_buf).await.unwrap();
assert_eq!(n, 0, "Server must see EOF after quota reached");
let mut overshoot_bytes = 0usize;
let mut buf = [0u8; 256];
loop {
match timeout(Duration::from_millis(20), sp_reader.read(&mut buf)).await {
Ok(Ok(0)) => break,
Ok(Ok(n)) => overshoot_bytes = overshoot_bytes.saturating_add(n),
Ok(Err(e)) => panic!("server read must not fail after relay cutoff: {e}"),
Err(_) => break,
}
}
assert!(
overshoot_bytes <= c2s_buf_size,
"post-write cutoff may leak at most one C->S chunk after boundary, got {overshoot_bytes}"
);
assert!(
stats.get_user_quota_used(user) <= quota.saturating_add(c2s_buf_size as u64),
"accounted quota must remain bounded by one in-flight chunk overshoot"
);
}
#[tokio::test]

View File

@ -0,0 +1,243 @@
use super::*;
use std::collections::VecDeque;
use std::io;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::time::Instant;
struct ScriptedWriter {
scripted_writes: Arc<Mutex<VecDeque<usize>>>,
write_calls: Arc<AtomicUsize>,
}
impl ScriptedWriter {
fn new(script: &[usize], write_calls: Arc<AtomicUsize>) -> Self {
Self {
scripted_writes: Arc::new(Mutex::new(script.iter().copied().collect())),
write_calls,
}
}
}
impl AsyncWrite for ScriptedWriter {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
this.write_calls.fetch_add(1, Ordering::Relaxed);
let planned = this
.scripted_writes
.lock()
.unwrap_or_else(|poisoned| poisoned.into_inner())
.pop_front()
.unwrap_or(buf.len());
Poll::Ready(Ok(planned.min(buf.len())))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}
fn make_stats_io_with_script(
user: &str,
quota_limit: u64,
precharged_quota: u64,
script: &[usize],
) -> (
StatsIo<ScriptedWriter>,
Arc<Stats>,
Arc<AtomicUsize>,
Arc<AtomicBool>,
) {
let stats = Arc::new(Stats::new());
if precharged_quota > 0 {
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), precharged_quota);
}
let write_calls = Arc::new(AtomicUsize::new(0));
let quota_exceeded = Arc::new(AtomicBool::new(false));
let io = StatsIo::new(
ScriptedWriter::new(script, write_calls.clone()),
Arc::new(SharedCounters::new()),
stats.clone(),
user.to_string(),
Some(quota_limit),
quota_exceeded.clone(),
Instant::now(),
);
(io, stats, write_calls, quota_exceeded)
}
#[tokio::test]
async fn direct_partial_write_charges_only_committed_bytes_without_double_charge() {
let user = "direct-partial-charge-user";
let (mut io, stats, write_calls, quota_exceeded) =
make_stats_io_with_script(user, 1_048_576, 0, &[8 * 1024, 8 * 1024, 48 * 1024]);
let payload = vec![0xAB; 64 * 1024];
let n1 = io
.write(&payload)
.await
.expect("first partial write must succeed");
let n2 = io
.write(&payload)
.await
.expect("second partial write must succeed");
let n3 = io.write(&payload).await.expect("tail write must succeed");
assert_eq!(n1, 8 * 1024);
assert_eq!(n2, 8 * 1024);
assert_eq!(n3, 48 * 1024);
assert_eq!(write_calls.load(Ordering::Relaxed), 3);
assert_eq!(
stats.get_user_quota_used(user),
(n1 + n2 + n3) as u64,
"quota accounting must follow committed bytes only"
);
assert_eq!(
stats.get_user_total_octets(user),
(n1 + n2 + n3) as u64,
"telemetry octets should match committed bytes on successful writes"
);
assert!(
!quota_exceeded.load(Ordering::Acquire),
"quota flag should stay false under large remaining budget"
);
}
#[tokio::test]
async fn direct_hybrid_branch_selection_matches_contract() {
let near_limit = 256 * 1024u64;
let near_remaining = 32 * 1024u64;
let (mut near_io, _stats, _calls, _flag) = make_stats_io_with_script(
"direct-near-limit-hard-check-user",
near_limit,
near_limit - near_remaining,
&[4 * 1024],
);
let near_payload = vec![0x11; 4 * 1024];
let near_written = near_io
.write(&near_payload)
.await
.expect("near-limit write must succeed");
assert_eq!(near_written, 4 * 1024);
assert_eq!(
near_io.quota_bytes_since_check, 0,
"near-limit branch must go through immediate hard check"
);
let (mut far_small_io, _stats, _calls, _flag) =
make_stats_io_with_script("direct-far-small-amortized-user", 1_048_576, 0, &[4 * 1024]);
let far_small_payload = vec![0x22; 4 * 1024];
let far_small_written = far_small_io
.write(&far_small_payload)
.await
.expect("small far-from-limit write must succeed");
assert_eq!(far_small_written, 4 * 1024);
assert_eq!(
far_small_io.quota_bytes_since_check,
4 * 1024,
"small far-from-limit write must go through amortized path"
);
let (mut far_large_io, _stats, _calls, _flag) = make_stats_io_with_script(
"direct-far-large-hard-check-user",
1_048_576,
0,
&[32 * 1024],
);
let far_large_payload = vec![0x33; 32 * 1024];
let far_large_written = far_large_io
.write(&far_large_payload)
.await
.expect("large write must succeed");
assert_eq!(far_large_written, 32 * 1024);
assert_eq!(
far_large_io.quota_bytes_since_check, 0,
"large write must force immediate hard check even far from limit"
);
}
#[tokio::test]
async fn remaining_before_zero_rejects_without_calling_inner_writer() {
let user = "direct-zero-remaining-user";
let limit = 8u64;
let (mut io, stats, write_calls, quota_exceeded) =
make_stats_io_with_script(user, limit, limit, &[1]);
let err = io
.write(&[0x44])
.await
.expect_err("write must fail when remaining quota is zero");
assert!(
is_quota_io_error(&err),
"zero-remaining gate must return typed quota I/O error"
);
assert_eq!(
write_calls.load(Ordering::Relaxed),
0,
"inner poll_write must not be called when remaining quota is zero"
);
assert!(
quota_exceeded.load(Ordering::Acquire),
"zero-remaining gate must set exceeded flag"
);
assert_eq!(stats.get_user_quota_used(user), limit);
}
#[tokio::test]
async fn exceeded_flag_blocks_following_poll_before_inner_write() {
let user = "direct-exceeded-visibility-user";
let (mut io, stats, write_calls, quota_exceeded) =
make_stats_io_with_script(user, 1, 0, &[1, 1]);
let first = io
.write(&[0x55])
.await
.expect("first byte should consume remaining quota");
assert_eq!(first, 1);
assert!(
quota_exceeded.load(Ordering::Acquire),
"hard check should store quota_exceeded after boundary hit"
);
let second = io
.write(&[0x66])
.await
.expect_err("next write must be rejected by early exceeded gate");
assert!(
is_quota_io_error(&second),
"following write must fail with typed quota error"
);
assert_eq!(
write_calls.load(Ordering::Relaxed),
1,
"second write must be cut before touching inner writer"
);
assert_eq!(stats.get_user_quota_used(user), 1);
}
#[test]
fn adaptive_interval_clamp_matches_contract() {
assert_eq!(quota_adaptive_interval_bytes(0), 4 * 1024);
assert_eq!(quota_adaptive_interval_bytes(2 * 1024), 4 * 1024);
assert_eq!(quota_adaptive_interval_bytes(32 * 1024), 16 * 1024);
assert_eq!(quota_adaptive_interval_bytes(256 * 1024), 64 * 1024);
assert!(should_immediate_quota_check(32 * 1024, 4 * 1024));
assert!(should_immediate_quota_check(1_048_576, 32 * 1024));
assert!(!should_immediate_quota_check(1_048_576, 4 * 1024));
}

View File

@ -29,6 +29,11 @@ async fn read_available<R: AsyncRead + Unpin>(reader: &mut R, budget: Duration)
total
}
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
}
#[tokio::test]
async fn integration_full_duplex_exact_budget_then_hard_cutoff() {
let stats = Arc::new(Stats::new());
@ -102,14 +107,14 @@ async fn integration_full_duplex_exact_budget_then_hard_cutoff() {
relay_result,
Err(ProxyError::DataQuotaExceeded { ref user }) if user == "quota-full-duplex-boundary-user"
));
assert!(stats.get_user_total_octets(user) <= 10);
assert!(stats.get_user_quota_used(user) <= 10);
}
#[tokio::test]
async fn negative_preloaded_quota_blocks_both_directions_immediately() {
let stats = Arc::new(Stats::new());
let user = "quota-preloaded-cutoff-user";
stats.add_user_octets_from(user, 5);
preload_user_quota(stats.as_ref(), user, 5);
let (mut client_peer, relay_client) = duplex(2048);
let (relay_server, mut server_peer) = duplex(2048);
@ -154,7 +159,7 @@ async fn negative_preloaded_quota_blocks_both_directions_immediately() {
relay_result,
Err(ProxyError::DataQuotaExceeded { .. })
));
assert!(stats.get_user_total_octets(user) <= 5);
assert!(stats.get_user_quota_used(user) <= 5);
}
#[tokio::test]
@ -212,7 +217,7 @@ async fn edge_quota_one_bidirectional_race_allows_at_most_one_forwarded_octet()
relay_result,
Err(ProxyError::DataQuotaExceeded { .. })
));
assert!(stats.get_user_total_octets(user) <= 1);
assert!(stats.get_user_quota_used(user) <= 1);
}
#[tokio::test]
@ -277,7 +282,7 @@ async fn adversarial_blackhat_alternating_fragmented_jitter_never_overshoots_glo
delivered_to_server + delivered_to_client <= quota as usize,
"combined forwarded bytes must never exceed configured quota"
);
assert!(stats.get_user_total_octets(user) <= quota);
assert!(stats.get_user_quota_used(user) <= quota);
}
#[tokio::test]
@ -356,7 +361,7 @@ async fn light_fuzz_randomized_schedule_preserves_quota_and_forwarded_byte_invar
"fuzz case {case}: forwarded bytes must not exceed quota"
);
assert!(
stats.get_user_total_octets(&user) <= quota,
stats.get_user_quota_used(&user) <= quota,
"fuzz case {case}: accounted bytes must not exceed quota"
);
}
@ -451,7 +456,7 @@ async fn stress_multi_relay_same_user_mixed_direction_jitter_respects_global_quo
}
assert!(
stats.get_user_total_octets(user) <= quota,
stats.get_user_quota_used(user) <= quota,
"global per-user quota must hold under concurrent mixed-direction relay stress"
);
assert!(

View File

@ -0,0 +1,399 @@
use super::relay_bidirectional;
use crate::error::ProxyError;
use crate::stats::Stats;
use crate::stream::BufferPool;
use rand::rngs::StdRng;
use rand::{RngExt, SeedableRng};
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::time::{Duration, timeout};
async fn read_available<R: tokio::io::AsyncRead + Unpin>(
reader: &mut R,
budget: Duration,
) -> usize {
let start = tokio::time::Instant::now();
let mut total = 0usize;
let mut buf = [0u8; 128];
loop {
let elapsed = start.elapsed();
if elapsed >= budget {
break;
}
let remaining = budget.saturating_sub(elapsed);
match timeout(remaining, reader.read(&mut buf)).await {
Ok(Ok(0)) => break,
Ok(Ok(n)) => total = total.saturating_add(n),
Ok(Err(_)) | Err(_) => break,
}
}
total
}
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
}
#[tokio::test]
async fn positive_quota_path_forwards_both_directions_within_limit() {
let stats = Arc::new(Stats::new());
let user = "quota-extended-positive-user";
let (mut client_peer, relay_client) = duplex(2048);
let (relay_server, mut server_peer) = duplex(2048);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
256,
256,
user,
Arc::clone(&stats),
Some(16),
Arc::new(BufferPool::new()),
));
client_peer
.write_all(&[0xAA, 0xBB, 0xCC, 0xDD])
.await
.unwrap();
server_peer.read_exact(&mut [0u8; 4]).await.unwrap();
server_peer
.write_all(&[0x11, 0x22, 0x33, 0x44])
.await
.unwrap();
client_peer.read_exact(&mut [0u8; 4]).await.unwrap();
drop(client_peer);
drop(server_peer);
let relay_result = timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
assert!(relay_result.is_ok());
assert!(stats.get_user_quota_used(user) <= 16);
}
#[tokio::test]
async fn negative_preloaded_quota_forbids_any_forwarding() {
let stats = Arc::new(Stats::new());
let user = "quota-extended-negative-user";
preload_user_quota(stats.as_ref(), user, 8);
let (mut client_peer, relay_client) = duplex(1024);
let (relay_server, mut server_peer) = duplex(1024);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
128,
128,
user,
Arc::clone(&stats),
Some(8),
Arc::new(BufferPool::new()),
));
client_peer.write_all(&[0xAA]).await.unwrap();
server_peer.write_all(&[0xBB]).await.unwrap();
assert_eq!(
read_available(&mut server_peer, Duration::from_millis(120)).await,
0
);
assert_eq!(
read_available(&mut client_peer, Duration::from_millis(120)).await,
0
);
let relay_result = timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
assert!(matches!(
relay_result,
Err(ProxyError::DataQuotaExceeded { .. })
));
assert!(stats.get_user_quota_used(user) <= 8);
}
#[tokio::test]
async fn edge_quota_one_ensures_at_most_one_byte_across_directions() {
let stats = Arc::new(Stats::new());
let user = "quota-extended-edge-user";
let (mut client_peer, relay_client) = duplex(1024);
let (relay_server, mut server_peer) = duplex(1024);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
128,
128,
user,
Arc::clone(&stats),
Some(1),
Arc::new(BufferPool::new()),
));
let _ = tokio::join!(
client_peer.write_all(&[0xFE]),
server_peer.write_all(&[0xEF]),
);
let mut buf = [0u8; 1];
let delivered_s2c = timeout(Duration::from_millis(120), client_peer.read(&mut buf))
.await
.unwrap()
.unwrap_or(0);
let delivered_c2s = timeout(Duration::from_millis(120), server_peer.read(&mut buf))
.await
.unwrap()
.unwrap_or(0);
assert!(delivered_s2c + delivered_c2s <= 1);
let relay_result = timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
assert!(matches!(
relay_result,
Err(ProxyError::DataQuotaExceeded { .. })
));
}
#[tokio::test]
async fn adversarial_blackhat_alternating_jitter_does_not_overshoot_quota() {
let stats = Arc::new(Stats::new());
let user = "quota-extended-blackhat-user";
let quota = 24u64;
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
256,
256,
user,
Arc::clone(&stats),
Some(quota),
Arc::new(BufferPool::new()),
));
let mut total_forwarded = 0usize;
for i in 0..256usize {
if relay.is_finished() {
break;
}
if (i & 1) == 0 {
let _ = client_peer.write_all(&[(i as u8) ^ 0x57]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) = timeout(Duration::from_millis(6), server_peer.read(&mut one)).await {
total_forwarded += n;
}
} else {
let _ = server_peer.write_all(&[(i as u8) ^ 0xA8]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) = timeout(Duration::from_millis(6), client_peer.read(&mut one)).await {
total_forwarded += n;
}
}
tokio::time::sleep(Duration::from_millis(((i % 3) + 1) as u64)).await;
}
let relay_result = timeout(Duration::from_secs(3), relay)
.await
.unwrap()
.unwrap();
assert!(matches!(
relay_result,
Err(ProxyError::DataQuotaExceeded { .. })
));
assert!(total_forwarded <= quota as usize);
assert!(stats.get_user_quota_used(user) <= quota);
}
#[tokio::test]
async fn light_fuzz_random_quota_schedule_preserves_quota_invariants() {
let mut rng = StdRng::seed_from_u64(0xBEEF_C0DE);
for case in 0..32u64 {
let stats = Arc::new(Stats::new());
let user = format!("quota-extended-fuzz-{case}");
let quota = rng.random_range(1u64..=35u64);
let (mut client_peer, relay_client) = duplex(4096);
let (relay_server, mut server_peer) = duplex(4096);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_user = user.clone();
let relay_stats = Arc::clone(&stats);
let relay = tokio::spawn(async move {
relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
256,
256,
&relay_user,
Arc::clone(&relay_stats),
Some(quota),
Arc::new(BufferPool::new()),
)
.await
});
let mut total_forwarded = 0usize;
for _ in 0..96usize {
if relay.is_finished() {
break;
}
if rng.random::<bool>() {
let _ = client_peer.write_all(&[rng.random::<u8>()]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) =
timeout(Duration::from_millis(4), server_peer.read(&mut one)).await
{
total_forwarded += n;
}
} else {
let _ = server_peer.write_all(&[rng.random::<u8>()]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) =
timeout(Duration::from_millis(4), client_peer.read(&mut one)).await
{
total_forwarded += n;
}
}
}
drop(client_peer);
drop(server_peer);
let relay_result = timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
assert!(
relay_result.is_ok()
|| matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))
);
assert!(total_forwarded <= quota as usize);
assert!(stats.get_user_quota_used(&user) <= quota);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_parallel_relays_for_one_user_obey_global_quota() {
let stats = Arc::new(Stats::new());
let user = "quota-extended-stress-user".to_string();
let quota = 64u64;
let mut tasks = Vec::new();
for worker in 0..4u8 {
let stats = Arc::clone(&stats);
let user = user.clone();
tasks.push(tokio::spawn(async move {
let (mut client_peer, relay_client) = duplex(2048);
let (relay_server, mut server_peer) = duplex(2048);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay_user = user.clone();
let relay_stats = Arc::clone(&stats);
let relay = tokio::spawn(async move {
relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
128,
128,
&relay_user,
Arc::clone(&relay_stats),
Some(quota),
Arc::new(BufferPool::new()),
)
.await
});
let mut total = 0usize;
for step in 0..64u8 {
if relay.is_finished() {
break;
}
if (step as usize + worker as usize) % 2 == 0 {
let _ = client_peer.write_all(&[(step ^ 0x5A)]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) =
timeout(Duration::from_millis(6), server_peer.read(&mut one)).await
{
total += n;
}
} else {
let _ = server_peer.write_all(&[(step ^ 0xA5)]).await;
let mut one = [0u8; 1];
if let Ok(Ok(n)) =
timeout(Duration::from_millis(6), client_peer.read(&mut one)).await
{
total += n;
}
}
tokio::time::sleep(Duration::from_millis(1)).await;
}
drop(client_peer);
drop(server_peer);
let relay_result = timeout(Duration::from_secs(2), relay)
.await
.unwrap()
.unwrap();
assert!(
relay_result.is_ok()
|| matches!(relay_result, Err(ProxyError::DataQuotaExceeded { .. }))
);
total
}));
}
let mut delivered = 0usize;
for task in tasks {
delivered += task.await.unwrap();
}
assert!(stats.get_user_quota_used(&user) <= quota);
assert!(delivered <= quota as usize);
}

View File

@ -1,438 +0,0 @@
use super::*;
use crate::error::ProxyError;
use crate::stats::Stats;
use crate::stream::BufferPool;
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};
use tokio::sync::Barrier;
use tokio::time::Instant;
#[test]
fn quota_lock_same_user_returns_same_arc_instance() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let a = quota_user_lock("quota-lock-same-user");
let b = quota_user_lock("quota-lock-same-user");
assert!(Arc::ptr_eq(&a, &b));
}
#[test]
fn quota_lock_parallel_same_user_reuses_single_lock() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let user = "quota-lock-parallel-same";
let mut handles = Vec::new();
for _ in 0..64 {
handles.push(std::thread::spawn(move || quota_user_lock(user)));
}
let first = handles
.remove(0)
.join()
.expect("thread must return lock handle");
for handle in handles {
let got = handle.join().expect("thread must return lock handle");
assert!(Arc::ptr_eq(&first, &got));
}
}
#[test]
fn quota_lock_unique_users_materialize_distinct_entries() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let base = format!("quota-lock-distinct-{}", std::process::id());
let users: Vec<String> = (0..(QUOTA_USER_LOCKS_MAX / 2))
.map(|idx| format!("{base}-{idx}"))
.collect();
for user in &users {
let _ = quota_user_lock(user);
}
for user in &users {
assert!(
map.get(user).is_some(),
"lock cache must contain entry for {user}"
);
}
}
#[test]
fn quota_lock_unique_churn_stress_keeps_all_inserted_keys_addressable() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let base = format!("quota-lock-churn-{}", std::process::id());
for idx in 0..(QUOTA_USER_LOCKS_MAX + 256) {
let _ = quota_user_lock(&format!("{base}-{idx}"));
}
assert!(
map.len() <= QUOTA_USER_LOCKS_MAX,
"quota lock cache must stay bounded under unique-user churn"
);
}
#[test]
fn quota_lock_saturation_returns_stable_overflow_lock_without_cache_growth() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let prefix = format!("quota-held-{}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
}
assert_eq!(
map.len(),
QUOTA_USER_LOCKS_MAX,
"cache must be saturated for overflow check"
);
let overflow_user = format!("quota-overflow-{}", std::process::id());
let overflow_a = quota_user_lock(&overflow_user);
let overflow_b = quota_user_lock(&overflow_user);
assert_eq!(
map.len(),
QUOTA_USER_LOCKS_MAX,
"overflow path must not grow lock cache"
);
assert!(
map.get(&overflow_user).is_none(),
"overflow user lock must stay outside bounded cache under saturation"
);
assert!(
Arc::ptr_eq(&overflow_a, &overflow_b),
"overflow user must receive stable striped overflow lock while saturated"
);
drop(retained);
}
#[test]
fn quota_lock_reclaims_unreferenced_entries_before_ephemeral_fallback() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
// Saturate with retained strong references first so parallel tests cannot
// reclaim our fixture entries before we validate the reclaim path.
let prefix = format!("quota-reclaim-drop-{}", std::process::id());
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("{prefix}-{idx}")));
}
drop(retained);
let overflow_user = format!("quota-reclaim-overflow-{}", std::process::id());
let overflow = quota_user_lock(&overflow_user);
assert!(
map.get(&overflow_user).is_some(),
"after reclaiming stale entries, overflow user should become cacheable"
);
assert!(
Arc::strong_count(&overflow) >= 2,
"cacheable overflow lock should be held by both map and caller"
);
}
#[test]
fn quota_lock_saturated_same_user_must_not_return_distinct_locks() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!(
"quota-saturated-held-{}-{idx}",
std::process::id()
)));
}
let overflow_user = format!("quota-saturated-same-user-{}", std::process::id());
let a = quota_user_lock(&overflow_user);
let b = quota_user_lock(&overflow_user);
assert!(
Arc::ptr_eq(&a, &b),
"same user must not receive distinct locks under saturation because that enables quota race bypass"
);
drop(retained);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn quota_lock_saturation_concurrent_same_user_never_overshoots_quota() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!(
"quota-saturated-race-held-{}-{idx}",
std::process::id()
)));
}
let stats = Arc::new(Stats::new());
let user = format!("quota-saturated-race-user-{}", std::process::id());
let gate = Arc::new(Barrier::new(2));
let worker = |label: u8, stats: Arc<Stats>, user: String, gate: Arc<Barrier>| {
tokio::spawn(async move {
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user,
Some(1),
quota_exceeded,
Instant::now(),
);
gate.wait().await;
io.write_all(&[label]).await
})
};
let one = worker(0x11, Arc::clone(&stats), user.clone(), Arc::clone(&gate));
let two = worker(0x22, Arc::clone(&stats), user.clone(), Arc::clone(&gate));
let _ = tokio::time::timeout(Duration::from_secs(2), async {
let _ = one.await.expect("task one must not panic");
let _ = two.await.expect("task two must not panic");
})
.await
.expect("quota race workers must complete");
assert!(
stats.get_user_total_octets(&user) <= 1,
"saturated lock path must never overshoot quota for same user"
);
drop(retained);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn quota_lock_saturation_stress_same_user_never_overshoots_quota() {
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!(
"quota-saturated-stress-held-{}-{idx}",
std::process::id()
)));
}
for round in 0..128u32 {
let stats = Arc::new(Stats::new());
let user = format!("quota-saturated-stress-user-{}-{round}", std::process::id());
let gate = Arc::new(Barrier::new(2));
let one = {
let stats = Arc::clone(&stats);
let user = user.clone();
let gate = Arc::clone(&gate);
tokio::spawn(async move {
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user,
Some(1),
quota_exceeded,
Instant::now(),
);
gate.wait().await;
io.write_all(&[0x31]).await
})
};
let two = {
let stats = Arc::clone(&stats);
let user = user.clone();
let gate = Arc::clone(&gate);
tokio::spawn(async move {
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user,
Some(1),
quota_exceeded,
Instant::now(),
);
gate.wait().await;
io.write_all(&[0x32]).await
})
};
let _ = one.await.expect("stress task one must not panic");
let _ = two.await.expect("stress task two must not panic");
assert!(
stats.get_user_total_octets(&user) <= 1,
"round {round}: saturated path must not overshoot quota"
);
}
drop(retained);
}
#[test]
fn quota_error_classifier_accepts_internal_quota_sentinel_only() {
let err = quota_io_error();
assert!(is_quota_io_error(&err));
}
#[test]
fn quota_error_classifier_rejects_plain_permission_denied() {
let err = std::io::Error::new(std::io::ErrorKind::PermissionDenied, "permission denied");
assert!(!is_quota_io_error(&err));
}
#[test]
fn quota_lock_test_scope_recovers_after_guard_poison() {
let poison_result = std::thread::spawn(|| {
let _guard = super::quota_user_lock_test_scope();
panic!("intentional test-only guard poison");
})
.join();
assert!(poison_result.is_err(), "poison setup thread must panic");
let _guard = super::quota_user_lock_test_scope();
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let a = quota_user_lock("quota-lock-poison-recovery-user");
let b = quota_user_lock("quota-lock-poison-recovery-user");
assert!(Arc::ptr_eq(&a, &b));
}
#[tokio::test]
async fn quota_lock_integration_zero_quota_cuts_off_without_forwarding() {
let stats = Arc::new(Stats::new());
let user = "quota-zero-user";
let (mut client_peer, relay_client) = duplex(2048);
let (relay_server, mut server_peer) = duplex(2048);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
512,
512,
user,
Arc::clone(&stats),
Some(0),
Arc::new(BufferPool::new()),
));
client_peer
.write_all(b"x")
.await
.expect("client write must succeed");
let mut probe = [0u8; 1];
let forwarded =
tokio::time::timeout(Duration::from_millis(80), server_peer.read(&mut probe)).await;
if let Ok(Ok(n)) = forwarded {
assert_eq!(n, 0, "zero quota path must not forward payload bytes");
}
let result = tokio::time::timeout(Duration::from_secs(2), relay)
.await
.expect("relay must terminate under zero quota")
.expect("relay task must not panic");
assert!(matches!(result, Err(ProxyError::DataQuotaExceeded { .. })));
}
#[tokio::test]
async fn quota_lock_integration_no_quota_relays_both_directions_under_burst() {
let stats = Arc::new(Stats::new());
let (mut client_peer, relay_client) = duplex(8192);
let (relay_server, mut server_peer) = duplex(8192);
let (client_reader, client_writer) = tokio::io::split(relay_client);
let (server_reader, server_writer) = tokio::io::split(relay_server);
let relay = tokio::spawn(relay_bidirectional(
client_reader,
client_writer,
server_reader,
server_writer,
1024,
1024,
"quota-none-burst-user",
Arc::clone(&stats),
None,
Arc::new(BufferPool::new()),
));
let c2s = vec![0xA5; 2048];
let s2c = vec![0x5A; 1536];
client_peer
.write_all(&c2s)
.await
.expect("client burst write must succeed");
let mut got_c2s = vec![0u8; c2s.len()];
server_peer
.read_exact(&mut got_c2s)
.await
.expect("server must receive c2s burst");
assert_eq!(got_c2s, c2s);
server_peer
.write_all(&s2c)
.await
.expect("server burst write must succeed");
let mut got_s2c = vec![0u8; s2c.len()];
client_peer
.read_exact(&mut got_s2c)
.await
.expect("client must receive s2c burst");
assert_eq!(got_s2c, s2c);
drop(client_peer);
drop(server_peer);
let done = tokio::time::timeout(Duration::from_secs(2), relay)
.await
.expect("relay must terminate after peers close")
.expect("relay task must not panic");
assert!(done.is_ok());
}

View File

@ -32,6 +32,7 @@ async fn drain_available<R: AsyncRead + Unpin>(reader: &mut R, out: &mut Vec<u8>
#[tokio::test]
async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget() {
let mut rng = StdRng::seed_from_u64(0xC0DE_CAFE_D15C_F00D);
const MAX_INPUT_CHUNK: usize = 12;
for case in 0..64u64 {
let stats = Arc::new(Stats::new());
@ -92,12 +93,12 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget()
assert_is_prefix(&recv_at_server, &sent_c2s, "C->S");
assert_is_prefix(&recv_at_client, &sent_s2c, "S->C");
assert!(
recv_at_server.len() + recv_at_client.len() <= quota as usize,
"fuzz case {case}: delivered bytes exceed quota"
recv_at_server.len() + recv_at_client.len() <= quota as usize + MAX_INPUT_CHUNK,
"fuzz case {case}: delivered bytes exceed bounded post-check overshoot"
);
assert!(
stats.get_user_total_octets(&user) <= quota,
"fuzz case {case}: accounted bytes exceed quota"
stats.get_user_quota_used(&user) <= quota + MAX_INPUT_CHUNK as u64,
"fuzz case {case}: accounted bytes exceed bounded post-check overshoot"
);
}
@ -117,8 +118,8 @@ async fn model_fuzz_bidirectional_schedule_preserves_prefixes_and_quota_budget()
assert_is_prefix(&recv_at_server, &sent_c2s, "C->S final");
assert_is_prefix(&recv_at_client, &sent_s2c, "S->C final");
assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize);
assert!(stats.get_user_total_octets(&user) <= quota);
assert!(recv_at_server.len() + recv_at_client.len() <= quota as usize + MAX_INPUT_CHUNK);
assert!(stats.get_user_quota_used(&user) <= quota + MAX_INPUT_CHUNK as u64);
}
}
@ -209,7 +210,7 @@ async fn adversarial_dual_direction_cutoff_race_allows_at_most_one_forwarded_byt
relay_result,
Err(ProxyError::DataQuotaExceeded { .. })
));
assert!(stats.get_user_total_octets(user) <= 1);
assert!(stats.get_user_quota_used(user) <= 1);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
@ -217,9 +218,12 @@ async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_mode
let stats = Arc::new(Stats::new());
let user = "quota-model-stress-user";
let quota = 96u64;
const WORKERS: usize = 6;
const MAX_WORKER_CHUNK: u64 = 10;
let max_parallel_post_write_overshoot = WORKERS as u64 * MAX_WORKER_CHUNK;
let mut workers = Vec::new();
for worker_id in 0..6u64 {
for worker_id in 0..WORKERS as u64 {
let stats = Arc::clone(&stats);
let user = user.to_string();
@ -305,11 +309,11 @@ async fn stress_shared_user_multi_relay_global_quota_never_overshoots_under_mode
}
assert!(
stats.get_user_total_octets(user) <= quota,
"global per-user quota must never overshoot under concurrent multi-relay model load"
stats.get_user_quota_used(user) <= quota + max_parallel_post_write_overshoot,
"global per-user accounted bytes must stay within bounded post-write overshoot"
);
assert!(
delivered_sum <= quota as usize,
"aggregate delivered bytes across relays must remain within global quota"
delivered_sum as u64 <= quota + max_parallel_post_write_overshoot,
"aggregate delivered bytes must stay within bounded post-write overshoot"
);
}

View File

@ -19,13 +19,22 @@ async fn read_available<R: AsyncRead + Unpin>(reader: &mut R, budget_ms: u64) ->
total
}
fn preload_user_quota(stats: &Stats, user: &str, bytes: u64) {
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.quota_charge_post_write(user_stats.as_ref(), bytes);
}
#[tokio::test]
async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_accounting() {
let stats = Arc::new(Stats::new());
let user = "quota-overflow-regression-client-chunk";
let quota = 10u64;
let preloaded = 9u64;
let attempted_chunk = [0x11, 0x22, 0x33, 0x44];
let max_post_write_overshoot = attempted_chunk.len() as u64;
// Leave only 1 byte remaining under quota.
stats.add_user_octets_from(user, 9);
preload_user_quota(stats.as_ref(), user, preloaded);
let (mut client_peer, relay_client) = duplex(2048);
let (relay_server, mut server_peer) = duplex(2048);
@ -41,15 +50,12 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_
512,
user,
Arc::clone(&stats),
Some(10),
Some(quota),
Arc::new(BufferPool::new()),
));
// Single chunk attempts to cross remaining budget (4 > 1).
client_peer
.write_all(&[0x11, 0x22, 0x33, 0x44])
.await
.unwrap();
client_peer.write_all(&attempted_chunk).await.unwrap();
client_peer.shutdown().await.unwrap();
let forwarded = read_available(&mut server_peer, 60).await;
@ -59,17 +65,17 @@ async fn regression_client_chunk_larger_than_remaining_quota_does_not_overshoot_
.expect("relay must terminate after quota overflow attempt")
.expect("relay task must not panic");
assert_eq!(
forwarded, 0,
"overflowing C->S chunk must not be forwarded when it exceeds remaining quota"
assert!(
forwarded <= attempted_chunk.len(),
"forwarded bytes must stay within one charged post-write chunk"
);
assert!(matches!(
relay_result,
Err(ProxyError::DataQuotaExceeded { .. })
));
assert!(
stats.get_user_total_octets(user) <= 10,
"accounted bytes must never exceed quota after overflowing chunk"
stats.get_user_quota_used(user) <= quota + max_post_write_overshoot,
"accounted bytes must stay within bounded post-write overshoot"
);
}
@ -79,7 +85,7 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of
let user = "quota-overflow-regression-boundary";
// Leave exactly 4 bytes remaining.
stats.add_user_octets_from(user, 6);
preload_user_quota(stats.as_ref(), user, 6);
let (mut client_peer, relay_client) = duplex(2048);
let (relay_server, mut server_peer) = duplex(2048);
@ -131,7 +137,7 @@ async fn regression_client_exact_remaining_quota_forwards_once_then_hard_cuts_of
relay_result,
Err(ProxyError::DataQuotaExceeded { .. })
));
assert!(stats.get_user_total_octets(user) <= 10);
assert!(stats.get_user_quota_used(user) <= 10);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
@ -139,9 +145,12 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() {
let stats = Arc::new(Stats::new());
let user = "quota-overflow-regression-stress";
let quota = 12u64;
const WORKERS: usize = 4;
const BURST_LEN: usize = 64;
let max_parallel_post_write_overshoot = (WORKERS * BURST_LEN) as u64;
let mut handles = Vec::new();
for _ in 0..4usize {
for _ in 0..WORKERS {
let stats = Arc::clone(&stats);
let user = user.to_string();
@ -170,7 +179,7 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() {
});
// Aggressive sender tries to overflow shared user quota.
let burst = vec![0x5Au8; 64];
let burst = vec![0x5Au8; BURST_LEN];
let _ = client_peer.write_all(&burst).await;
let _ = client_peer.shutdown().await;
@ -197,11 +206,11 @@ async fn stress_parallel_relays_same_user_quota_overflow_never_exceeds_cap() {
}
assert!(
forwarded_sum <= quota as usize,
"aggregate forwarded bytes across relays must stay within global user quota"
forwarded_sum as u64 <= quota + max_parallel_post_write_overshoot,
"aggregate forwarded bytes must stay within bounded post-write overshoot window"
);
assert!(
stats.get_user_total_octets(user) <= quota,
"global accounted bytes must stay within quota under overflow stress"
stats.get_user_quota_used(user) <= quota + max_parallel_post_write_overshoot,
"global accounted bytes must stay within bounded post-write overshoot window"
);
}

View File

@ -1,294 +0,0 @@
use super::*;
use crate::stats::Stats;
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::Barrier;
use tokio::time::{Duration, timeout};
fn saturate_lock_cache() -> Vec<Arc<std::sync::Mutex<()>>> {
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("quota-liveness-saturated-{idx}")));
}
retained
}
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
#[tokio::test]
async fn positive_writer_progresses_after_contention_release_without_external_wake() {
let _guard = quota_test_guard();
let _retained = saturate_lock_cache();
let user = "quota-liveness-writer-positive";
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold user quota lock before write");
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user.to_string(),
Some(1024),
quota_exceeded,
tokio::time::Instant::now(),
);
let writer = tokio::spawn(async move { io.write_all(&[0x11]).await });
// Let the initial deferred wake fire while contention is still active.
tokio::time::sleep(Duration::from_millis(4)).await;
drop(held_guard);
let completed = timeout(Duration::from_millis(250), writer)
.await
.expect("writer must be re-polled and complete after lock release")
.expect("writer task must not panic");
assert!(completed.is_ok(), "writer must complete after lock release");
}
#[tokio::test]
async fn edge_reader_progresses_after_contention_release_without_external_wake() {
let _guard = quota_test_guard();
let _retained = saturate_lock_cache();
let user = "quota-liveness-reader-edge";
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold user quota lock before read");
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::empty(),
counters,
Arc::clone(&stats),
user.to_string(),
Some(1024),
quota_exceeded,
tokio::time::Instant::now(),
);
let reader = tokio::spawn(async move {
let mut one = [0u8; 1];
io.read(&mut one).await
});
tokio::time::sleep(Duration::from_millis(4)).await;
drop(held_guard);
let completed = timeout(Duration::from_millis(250), reader)
.await
.expect("reader must be re-polled and complete after lock release")
.expect("reader task must not panic");
assert!(completed.is_ok(), "reader must complete after lock release");
}
#[tokio::test]
async fn adversarial_early_deferred_wake_consumption_does_not_deadlock_writer() {
let _guard = quota_test_guard();
let _retained = saturate_lock_cache();
let user = "quota-liveness-adversarial";
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold user quota lock before adversarial write");
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user.to_string(),
Some(1024),
quota_exceeded,
tokio::time::Instant::now(),
);
let writer = tokio::spawn(async move { io.write_all(&[0x22]).await });
// Force multiple scheduler rounds while lock remains held so the first
// deferred wake has already been consumed under contention.
for _ in 0..32 {
tokio::task::yield_now().await;
}
drop(held_guard);
let completed = timeout(Duration::from_millis(300), writer)
.await
.expect("writer must not stay parked forever after release")
.expect("writer task must not panic");
assert!(completed.is_ok());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn integration_parallel_waiters_resume_after_single_release_event() {
let _guard = quota_test_guard();
let _retained = saturate_lock_cache();
let user = format!("quota-liveness-integration-{}", std::process::id());
let stats = Arc::new(Stats::new());
let barrier = Arc::new(Barrier::new(13));
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold user quota lock before launching waiters");
let mut waiters = Vec::new();
for _ in 0..12 {
let stats = Arc::clone(&stats);
let user = user.clone();
let barrier = Arc::clone(&barrier);
waiters.push(tokio::spawn(async move {
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
stats,
user,
Some(4096),
quota_exceeded,
tokio::time::Instant::now(),
);
barrier.wait().await;
io.write_all(&[0x33]).await
}));
}
barrier.wait().await;
tokio::time::sleep(Duration::from_millis(4)).await;
drop(held_guard);
timeout(Duration::from_secs(1), async {
for waiter in waiters {
let outcome = waiter.await.expect("waiter must not panic");
assert!(
outcome.is_ok(),
"waiter must resume and complete after release"
);
}
})
.await
.expect("all waiters must complete in bounded time");
}
#[tokio::test]
async fn light_fuzz_release_timing_matrix_preserves_liveness() {
let _guard = quota_test_guard();
let _retained = saturate_lock_cache();
let stats = Arc::new(Stats::new());
let mut seed = 0xD1CE_F00D_0123_4567u64;
for round in 0..64u32 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
let delay_ms = 1 + (seed & 0x7) as u64;
let user = format!("quota-liveness-fuzz-{}-{round}", std::process::id());
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold user quota lock in fuzz round");
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user,
Some(2048),
quota_exceeded,
tokio::time::Instant::now(),
);
let writer = tokio::spawn(async move { io.write_all(&[0x44]).await });
tokio::time::sleep(Duration::from_millis(delay_ms)).await;
drop(held_guard);
let done = timeout(Duration::from_millis(300), writer)
.await
.expect("fuzz round writer must complete")
.expect("fuzz writer task must not panic");
assert!(
done.is_ok(),
"fuzz round writer must not stall after release"
);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_repeated_contention_cycles_remain_live() {
let _guard = quota_test_guard();
let _retained = saturate_lock_cache();
let stats = Arc::new(Stats::new());
for cycle in 0..40u32 {
let user = format!("quota-liveness-stress-{}-{cycle}", std::process::id());
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold lock before stress cycle");
let mut tasks = Vec::new();
for _ in 0..6 {
let stats = Arc::clone(&stats);
let user = user.clone();
tasks.push(tokio::spawn(async move {
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
stats,
user,
Some(2048),
quota_exceeded,
tokio::time::Instant::now(),
);
io.write_all(&[0x55]).await
}));
}
tokio::task::yield_now().await;
drop(held_guard);
timeout(Duration::from_millis(700), async {
for task in tasks {
let outcome = task.await.expect("stress task must not panic");
assert!(outcome.is_ok(), "stress writer must complete");
}
})
.await
.expect("stress cycle must finish in bounded time");
}
}

View File

@ -1,310 +0,0 @@
use super::*;
use crate::stats::Stats;
use dashmap::DashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::task::{Context, Waker};
use tokio::io::{AsyncWriteExt, ReadBuf};
use tokio::time::{Duration, timeout};
#[derive(Default)]
struct WakeCounter {
wakes: AtomicUsize,
}
impl std::task::Wake for WakeCounter {
fn wake(self: Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wakes.fetch_add(1, Ordering::Relaxed);
}
}
fn quota_test_guard() -> impl Drop {
super::quota_user_lock_test_scope()
}
fn saturate_quota_user_locks() -> Vec<Arc<std::sync::Mutex<()>>> {
let map = QUOTA_USER_LOCKS.get_or_init(DashMap::new);
map.clear();
let mut retained = Vec::with_capacity(QUOTA_USER_LOCKS_MAX);
for idx in 0..QUOTA_USER_LOCKS_MAX {
retained.push(quota_user_lock(&format!("quota-waker-saturate-{idx}")));
}
retained
}
#[tokio::test]
async fn positive_contended_writer_emits_deferred_wake_for_liveness() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let stats = Arc::new(Stats::new());
let user = "quota-waker-positive-user";
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold overflow lock before polling writer");
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user.to_string(),
Some(1024),
quota_exceeded,
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let pending = Pin::new(&mut io).poll_write(&mut cx, &[0xA1]);
assert!(pending.is_pending());
timeout(Duration::from_millis(100), async {
loop {
if wake_counter.wakes.load(Ordering::Relaxed) >= 1 {
break;
}
tokio::task::yield_now().await;
}
})
.await
.expect("contended writer must receive deferred wake");
drop(held_guard);
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xA2]);
assert!(
ready.is_ready(),
"writer must progress after contention release"
);
}
#[tokio::test]
async fn adversarial_blackhat_writer_contention_does_not_create_waker_storm() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let stats = Arc::new(Stats::new());
let user = "quota-waker-blackhat-writer";
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold overflow lock before polling writer");
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
Arc::clone(&stats),
user.to_string(),
Some(1024),
quota_exceeded,
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
for _ in 0..512 {
let poll = Pin::new(&mut io).poll_write(&mut cx, &[0xBE]);
assert!(
poll.is_pending(),
"writer must stay pending while lock is held"
);
tokio::task::yield_now().await;
}
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
assert!(
wakes <= 128,
"pending writer retries must not trigger wake storm; observed wakes={wakes}"
);
drop(held_guard);
let ready = Pin::new(&mut io).poll_write(&mut cx, &[0xEF]);
assert!(ready.is_ready());
}
#[tokio::test]
async fn edge_read_path_contention_keeps_wake_budget_bounded() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let stats = Arc::new(Stats::new());
let user = "quota-waker-read-edge";
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold overflow lock before polling reader");
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::empty(),
counters,
Arc::clone(&stats),
user.to_string(),
Some(1024),
quota_exceeded,
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let mut storage = [0u8; 1];
for _ in 0..512 {
let mut buf = ReadBuf::new(&mut storage);
let poll = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
assert!(poll.is_pending());
tokio::task::yield_now().await;
}
let wakes = wake_counter.wakes.load(Ordering::Relaxed);
assert!(
wakes <= 128,
"pending reader retries must not trigger wake storm; observed wakes={wakes}"
);
drop(held_guard);
let mut buf = ReadBuf::new(&mut storage);
let ready = Pin::new(&mut io).poll_read(&mut cx, &mut buf);
assert!(ready.is_ready());
}
#[tokio::test]
async fn light_fuzz_mixed_poll_schedule_under_contention_stays_bounded() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let stats = Arc::new(Stats::new());
let user = "quota-waker-fuzz-user";
let lock = quota_user_lock(user);
let held_guard = lock
.try_lock()
.expect("test must hold overflow lock before fuzz polling");
let counters_w = Arc::new(SharedCounters::new());
let mut writer_io = StatsIo::new(
tokio::io::sink(),
counters_w,
Arc::clone(&stats),
user.to_string(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let counters_r = Arc::new(SharedCounters::new());
let mut reader_io = StatsIo::new(
tokio::io::empty(),
counters_r,
Arc::clone(&stats),
user.to_string(),
Some(1024),
Arc::new(AtomicBool::new(false)),
tokio::time::Instant::now(),
);
let wake_counter = Arc::new(WakeCounter::default());
let waker = Waker::from(Arc::clone(&wake_counter));
let mut cx = Context::from_waker(&waker);
let mut seed = 0xBADC_0FFE_EE11_2211u64;
let mut storage = [0u8; 1];
for _ in 0..1024 {
seed ^= seed << 7;
seed ^= seed >> 9;
seed ^= seed << 8;
if (seed & 1) == 0 {
let poll = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x44]);
assert!(poll.is_pending());
} else {
let mut buf = ReadBuf::new(&mut storage);
let poll = Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf);
assert!(poll.is_pending());
}
tokio::task::yield_now().await;
}
assert!(
wake_counter.wakes.load(Ordering::Relaxed) <= 192,
"mixed contention fuzz must keep deferred wake count tightly bounded"
);
drop(held_guard);
let ready_w = Pin::new(&mut writer_io).poll_write(&mut cx, &[0x55]);
assert!(ready_w.is_ready());
let mut buf = ReadBuf::new(&mut storage);
let ready_r = Pin::new(&mut reader_io).poll_read(&mut cx, &mut buf);
assert!(ready_r.is_ready());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[ignore = "red-team detector: reveals possible starvation if deferred wake fires before contention release"]
async fn stress_many_contended_writers_complete_after_release() {
let _guard = quota_test_guard();
let _retained = saturate_quota_user_locks();
let user = "quota-waker-stress-user".to_string();
let stats = Arc::new(Stats::new());
let lock = quota_user_lock(&user);
let held_guard = lock
.try_lock()
.expect("test must hold overflow lock before launching contended tasks");
let mut tasks = Vec::new();
for _ in 0..32 {
let stats = Arc::clone(&stats);
let user = user.clone();
tasks.push(tokio::spawn(async move {
let counters = Arc::new(SharedCounters::new());
let quota_exceeded = Arc::new(AtomicBool::new(false));
let mut io = StatsIo::new(
tokio::io::sink(),
counters,
stats,
user,
Some(2048),
quota_exceeded,
tokio::time::Instant::now(),
);
io.write_all(&[0xAA]).await
}));
}
for _ in 0..8 {
tokio::task::yield_now().await;
}
drop(held_guard);
timeout(Duration::from_secs(2), async {
for task in tasks {
let result = task.await.expect("stress task must not panic");
assert!(result.is_ok(), "task must complete after lock release");
}
})
.await
.expect("all contended writer tasks must finish in bounded time after release");
}

File diff suppressed because it is too large Load Diff

View File

@ -238,10 +238,12 @@ pub struct Stats {
me_inline_recovery_total: AtomicU64,
ip_reservation_rollback_tcp_limit_total: AtomicU64,
ip_reservation_rollback_quota_limit_total: AtomicU64,
quota_write_fail_bytes_total: AtomicU64,
quota_write_fail_events_total: AtomicU64,
telemetry_core_enabled: AtomicBool,
telemetry_user_enabled: AtomicBool,
telemetry_me_level: AtomicU8,
user_stats: DashMap<String, UserStats>,
user_stats: DashMap<String, Arc<UserStats>>,
user_stats_last_cleanup_epoch_secs: AtomicU64,
start_time: parking_lot::RwLock<Option<Instant>>,
}
@ -254,9 +256,51 @@ pub struct UserStats {
pub octets_to_client: AtomicU64,
pub msgs_from_client: AtomicU64,
pub msgs_to_client: AtomicU64,
/// Total bytes charged against per-user quota admission.
///
/// This counter is the single source of truth for quota enforcement and
/// intentionally tracks attempted traffic, not guaranteed delivery.
pub quota_used: AtomicU64,
pub last_seen_epoch_secs: AtomicU64,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuotaReserveError {
LimitExceeded,
Contended,
}
impl UserStats {
#[inline]
pub fn quota_used(&self) -> u64 {
self.quota_used.load(Ordering::Relaxed)
}
/// Attempts one CAS reservation step against the quota counter.
///
/// Callers control retry/yield policy. This primitive intentionally does
/// not block or sleep so both sync poll paths and async paths can wrap it
/// with their own contention strategy.
#[inline]
pub fn quota_try_reserve(&self, bytes: u64, limit: u64) -> Result<u64, QuotaReserveError> {
let current = self.quota_used.load(Ordering::Relaxed);
if bytes > limit.saturating_sub(current) {
return Err(QuotaReserveError::LimitExceeded);
}
let next = current.saturating_add(bytes);
match self.quota_used.compare_exchange_weak(
current,
next,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => Ok(next),
Err(_) => Err(QuotaReserveError::Contended),
}
}
}
impl Stats {
pub fn new() -> Self {
let stats = Self::default();
@ -316,6 +360,74 @@ impl Stats {
.store(Self::now_epoch_secs(), Ordering::Relaxed);
}
pub(crate) fn get_or_create_user_stats_handle(&self, user: &str) -> Arc<UserStats> {
self.maybe_cleanup_user_stats();
if let Some(existing) = self.user_stats.get(user) {
let handle = Arc::clone(existing.value());
Self::touch_user_stats(handle.as_ref());
return handle;
}
let entry = self.user_stats.entry(user.to_string()).or_default();
if entry.last_seen_epoch_secs.load(Ordering::Relaxed) == 0 {
Self::touch_user_stats(entry.value().as_ref());
}
Arc::clone(entry.value())
}
#[inline]
pub(crate) fn add_user_octets_from_handle(&self, user_stats: &UserStats, bytes: u64) {
if !self.telemetry_user_enabled() {
return;
}
Self::touch_user_stats(user_stats);
user_stats
.octets_from_client
.fetch_add(bytes, Ordering::Relaxed);
}
#[inline]
pub(crate) fn add_user_octets_to_handle(&self, user_stats: &UserStats, bytes: u64) {
if !self.telemetry_user_enabled() {
return;
}
Self::touch_user_stats(user_stats);
user_stats
.octets_to_client
.fetch_add(bytes, Ordering::Relaxed);
}
#[inline]
pub(crate) fn increment_user_msgs_from_handle(&self, user_stats: &UserStats) {
if !self.telemetry_user_enabled() {
return;
}
Self::touch_user_stats(user_stats);
user_stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
}
#[inline]
pub(crate) fn increment_user_msgs_to_handle(&self, user_stats: &UserStats) {
if !self.telemetry_user_enabled() {
return;
}
Self::touch_user_stats(user_stats);
user_stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
}
/// Charges already committed bytes in a post-I/O path.
///
/// This helper is intentionally separate from `quota_try_reserve` to avoid
/// mixing reserve and post-charge on a single I/O event.
#[inline]
pub(crate) fn quota_charge_post_write(&self, user_stats: &UserStats, bytes: u64) -> u64 {
Self::touch_user_stats(user_stats);
user_stats
.quota_used
.fetch_add(bytes, Ordering::Relaxed)
.saturating_add(bytes)
}
fn maybe_cleanup_user_stats(&self) {
const USER_STATS_CLEANUP_INTERVAL_SECS: u64 = 60;
const USER_STATS_IDLE_TTL_SECS: u64 = 24 * 60 * 60;
@ -704,7 +816,8 @@ impl Stats {
}
pub fn increment_me_d2c_data_frames_total(&self) {
if self.telemetry_me_allows_normal() {
self.me_d2c_data_frames_total.fetch_add(1, Ordering::Relaxed);
self.me_d2c_data_frames_total
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_me_d2c_ack_frames_total(&self) {
@ -1114,6 +1227,18 @@ impl Stats {
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn add_quota_write_fail_bytes_total(&self, bytes: u64) {
if self.telemetry_core_enabled() {
self.quota_write_fail_bytes_total
.fetch_add(bytes, Ordering::Relaxed);
}
}
pub fn increment_quota_write_fail_events_total(&self) {
if self.telemetry_core_enabled() {
self.quota_write_fail_events_total
.fetch_add(1, Ordering::Relaxed);
}
}
pub fn increment_me_endpoint_quarantine_total(&self) {
if self.telemetry_me_allows_normal() {
self.me_endpoint_quarantine_total
@ -1588,7 +1713,8 @@ impl Stats {
self.me_d2c_batch_bytes_bucket_1k_4k.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_bytes_bucket_4k_16k(&self) -> u64 {
self.me_d2c_batch_bytes_bucket_4k_16k.load(Ordering::Relaxed)
self.me_d2c_batch_bytes_bucket_4k_16k
.load(Ordering::Relaxed)
}
pub fn get_me_d2c_batch_bytes_bucket_16k_64k(&self) -> u64 {
self.me_d2c_batch_bytes_bucket_16k_64k
@ -1764,19 +1890,19 @@ impl Stats {
self.ip_reservation_rollback_quota_limit_total
.load(Ordering::Relaxed)
}
pub fn get_quota_write_fail_bytes_total(&self) -> u64 {
self.quota_write_fail_bytes_total.load(Ordering::Relaxed)
}
pub fn get_quota_write_fail_events_total(&self) -> u64 {
self.quota_write_fail_events_total.load(Ordering::Relaxed)
}
pub fn increment_user_connects(&self, user: &str) {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.connects.fetch_add(1, Ordering::Relaxed);
return;
}
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
let stats = self.get_or_create_user_stats_handle(user);
Self::touch_user_stats(stats.as_ref());
stats.connects.fetch_add(1, Ordering::Relaxed);
}
@ -1784,14 +1910,8 @@ impl Stats {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.curr_connects.fetch_add(1, Ordering::Relaxed);
return;
}
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
let stats = self.get_or_create_user_stats_handle(user);
Self::touch_user_stats(stats.as_ref());
stats.curr_connects.fetch_add(1, Ordering::Relaxed);
}
@ -1800,9 +1920,8 @@ impl Stats {
return true;
}
self.maybe_cleanup_user_stats();
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
let stats = self.get_or_create_user_stats_handle(user);
Self::touch_user_stats(stats.as_ref());
let counter = &stats.curr_connects;
let mut current = counter.load(Ordering::Relaxed);
@ -1827,7 +1946,7 @@ impl Stats {
pub fn decrement_user_curr_connects(&self, user: &str) {
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
Self::touch_user_stats(stats.value().as_ref());
let counter = &stats.curr_connects;
let mut current = counter.load(Ordering::Relaxed);
loop {
@ -1858,60 +1977,32 @@ impl Stats {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
return;
}
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
stats.octets_from_client.fetch_add(bytes, Ordering::Relaxed);
let stats = self.get_or_create_user_stats_handle(user);
self.add_user_octets_from_handle(stats.as_ref(), bytes);
}
pub fn add_user_octets_to(&self, user: &str, bytes: u64) {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
return;
}
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
stats.octets_to_client.fetch_add(bytes, Ordering::Relaxed);
let stats = self.get_or_create_user_stats_handle(user);
self.add_user_octets_to_handle(stats.as_ref(), bytes);
}
pub fn increment_user_msgs_from(&self, user: &str) {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
return;
}
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
stats.msgs_from_client.fetch_add(1, Ordering::Relaxed);
let stats = self.get_or_create_user_stats_handle(user);
self.increment_user_msgs_from_handle(stats.as_ref());
}
pub fn increment_user_msgs_to(&self, user: &str) {
if !self.telemetry_user_enabled() {
return;
}
self.maybe_cleanup_user_stats();
if let Some(stats) = self.user_stats.get(user) {
Self::touch_user_stats(stats.value());
stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
return;
}
let stats = self.user_stats.entry(user.to_string()).or_default();
Self::touch_user_stats(stats.value());
stats.msgs_to_client.fetch_add(1, Ordering::Relaxed);
let stats = self.get_or_create_user_stats_handle(user);
self.increment_user_msgs_to_handle(stats.as_ref());
}
pub fn get_user_total_octets(&self, user: &str) -> u64 {
@ -1924,6 +2015,13 @@ impl Stats {
.unwrap_or(0)
}
pub fn get_user_quota_used(&self, user: &str) -> u64 {
self.user_stats
.get(user)
.map(|s| s.quota_used.load(Ordering::Relaxed))
.unwrap_or(0)
}
pub fn get_handshake_timeouts(&self) -> u64 {
self.handshake_timeouts.load(Ordering::Relaxed)
}
@ -1989,7 +2087,7 @@ impl Stats {
.load(Ordering::Relaxed)
}
pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, UserStats> {
pub fn iter_user_stats(&self) -> dashmap::iter::Iter<'_, String, Arc<UserStats>> {
self.user_stats.iter()
}
@ -2137,6 +2235,22 @@ impl ReplayChecker {
found
}
fn check_only_internal(
&self,
data: &[u8],
shards: &[Mutex<ReplayShard>],
window: Duration,
) -> bool {
self.checks.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
let mut shard = shards[idx].lock();
let found = shard.check(data, Instant::now(), window);
if found {
self.hits.fetch_add(1, Ordering::Relaxed);
}
found
}
fn add_only(&self, data: &[u8], shards: &[Mutex<ReplayShard>], window: Duration) {
self.additions.fetch_add(1, Ordering::Relaxed);
let idx = self.get_shard_idx(data);
@ -2160,7 +2274,7 @@ impl ReplayChecker {
self.add_only(data, &self.handshake_shards, self.window)
}
pub fn check_tls_digest(&self, data: &[u8]) -> bool {
self.check_and_add_tls_digest(data)
self.check_only_internal(data, &self.tls_shards, self.tls_window)
}
pub fn add_tls_digest(&self, data: &[u8]) {
self.add_only(data, &self.tls_shards, self.tls_window)
@ -2264,6 +2378,7 @@ mod tests {
use super::*;
use crate::config::MeTelemetryLevel;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
#[test]
fn test_stats_shared_counters() {
@ -2431,6 +2546,137 @@ mod tests {
}
assert_eq!(checker.stats().total_entries, 500);
}
#[test]
fn test_quota_reserve_under_contention_hits_limit_exactly() {
let user_stats = Arc::new(UserStats::default());
let successes = Arc::new(AtomicU64::new(0));
let limit = 8_192u64;
let mut workers = Vec::new();
for _ in 0..8 {
let user_stats = user_stats.clone();
let successes = successes.clone();
workers.push(std::thread::spawn(move || {
loop {
match user_stats.quota_try_reserve(1, limit) {
Ok(_) => {
successes.fetch_add(1, Ordering::Relaxed);
}
Err(QuotaReserveError::Contended) => {
std::hint::spin_loop();
}
Err(QuotaReserveError::LimitExceeded) => {
break;
}
}
}
}));
}
for worker in workers {
worker.join().expect("worker thread must finish");
}
assert_eq!(
successes.load(Ordering::Relaxed),
limit,
"successful reservations must stop exactly at limit"
);
assert_eq!(user_stats.quota_used(), limit);
}
#[test]
fn test_quota_reserve_200x_1k_reaches_100k_without_overshoot() {
let user_stats = Arc::new(UserStats::default());
let successes = Arc::new(AtomicU64::new(0));
let failures = Arc::new(AtomicU64::new(0));
let attempts = 200usize;
let reserve_bytes = 1_024u64;
let limit = 100 * 1_024u64;
let mut workers = Vec::with_capacity(attempts);
for _ in 0..attempts {
let user_stats = user_stats.clone();
let successes = successes.clone();
let failures = failures.clone();
workers.push(std::thread::spawn(move || {
loop {
match user_stats.quota_try_reserve(reserve_bytes, limit) {
Ok(_) => {
successes.fetch_add(1, Ordering::Relaxed);
return;
}
Err(QuotaReserveError::LimitExceeded) => {
failures.fetch_add(1, Ordering::Relaxed);
return;
}
Err(QuotaReserveError::Contended) => {
std::hint::spin_loop();
}
}
}
}));
}
for worker in workers {
worker.join().expect("reservation worker must finish");
}
assert_eq!(
successes.load(Ordering::Relaxed),
100,
"exactly 100 reservations of 1 KiB must fit into a 100 KiB quota"
);
assert_eq!(
failures.load(Ordering::Relaxed),
100,
"remaining workers must fail once quota is fully reserved"
);
assert_eq!(user_stats.quota_used(), limit);
}
#[test]
fn test_quota_used_is_authoritative_and_independent_from_octets_telemetry() {
let stats = Stats::new();
let user = "quota-authoritative-user";
let user_stats = stats.get_or_create_user_stats_handle(user);
stats.add_user_octets_to_handle(&user_stats, 5);
assert_eq!(stats.get_user_total_octets(user), 5);
assert_eq!(stats.get_user_quota_used(user), 0);
stats.quota_charge_post_write(&user_stats, 7);
assert_eq!(stats.get_user_total_octets(user), 5);
assert_eq!(stats.get_user_quota_used(user), 7);
}
#[test]
fn test_cached_handle_survives_map_cleanup_until_last_drop() {
let stats = Stats::new();
let user = "quota-handle-lifetime-user";
let user_stats = stats.get_or_create_user_stats_handle(user);
let weak = Arc::downgrade(&user_stats);
stats.user_stats.remove(user);
assert!(
stats.user_stats.get(user).is_none(),
"map cleanup should remove idle entry"
);
assert!(
weak.upgrade().is_some(),
"cached handle must keep user stats object alive after map removal"
);
stats.quota_charge_post_write(user_stats.as_ref(), 3);
assert_eq!(user_stats.quota_used(), 3);
drop(user_stats);
assert!(
weak.upgrade().is_none(),
"user stats object must be dropped after the last cached handle is released"
);
}
}
#[cfg(test)]

View File

@ -14,7 +14,10 @@ fn padding_rounding_equivalent_for_extensive_safe_domain() {
let old = old_padding_round_up_to_4(len).expect("old expression must be safe");
let new = new_padding_round_up_to_4(len).expect("new expression must be safe");
assert_eq!(old, new, "mismatch for len={len}");
assert!(new >= len, "rounded length must not shrink: len={len}, out={new}");
assert!(
new >= len,
"rounded length must not shrink: len={len}, out={new}"
);
assert_eq!(new % 4, 0, "rounded length must stay 4-byte aligned");
}
}

View File

@ -44,7 +44,10 @@ async fn encapsulation_repeated_queue_poison_recovery_preserves_forward_progress
let ip_primary = ip_from_idx(10_001);
let ip_alt = ip_from_idx(10_002);
tracker.check_and_add("encap-poison", ip_primary).await.unwrap();
tracker
.check_and_add("encap-poison", ip_primary)
.await
.unwrap();
for _ in 0..128 {
let queue = tracker.cleanup_queue_mutex_for_tests();

Some files were not shown because too many files have changed in this diff Show More